SurfSense/surfsense_backend/app/services/notion_mcp/adapter.py
CREDO23 d6e605fd50 feat(notion-mcp): add OAuth + PKCE service layer and MCP adapter
Implements Notion MCP integration core:
- OAuth 2.0 discovery (RFC 9728 + 8414), dynamic client registration,
  PKCE token exchange, and refresh with rotation
- NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback
  to direct API on known serialization errors
- Response parser translating MCP text responses into dicts matching
  NotionHistoryConnector output format
- has_mcp_notion_connector() helper for connector gating
2026-04-20 20:59:17 +02:00

253 lines
9 KiB
Python

"""Notion MCP Adapter.
Connects to Notion's hosted MCP server at ``https://mcp.notion.com/mcp``
and exposes the same method signatures as ``NotionHistoryConnector``'s
write operations so that tool factories can swap with a one-line change.
Includes an optional fallback to ``NotionHistoryConnector`` when the MCP
server returns known serialization errors (GitHub issues #215, #216).
"""
import logging
from datetime import UTC, datetime
from typing import Any
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import SearchSourceConnector
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption
from .response_parser import (
extract_text_from_mcp_response,
is_mcp_serialization_error,
parse_create_page_response,
parse_delete_page_response,
parse_fetch_page_response,
parse_health_check_response,
parse_update_page_response,
)
logger = logging.getLogger(name=__name__)

# Streamable-HTTP endpoint of Notion's hosted MCP server; every adapter
# call opens a fresh client connection to this URL.
NOTION_MCP_URL: str = "https://mcp.notion.com/mcp"
class NotionMCPAdapter:
    """Routes Notion operations through the hosted MCP server.

    Drop-in replacement for ``NotionHistoryConnector`` write methods.
    Returns the same dict structure so KB sync works unchanged.

    Every call loads the connector's OAuth access token from the database
    (refreshing it when it is expired or about to expire) and opens a new
    MCP session for that single tool call.
    """

    # Refresh the access token this many seconds *before* its recorded
    # expiry, so it cannot lapse between the check and the MCP request
    # (and to tolerate small clock skew).
    _TOKEN_REFRESH_LEEWAY_SECONDS = 60

    def __init__(self, session: AsyncSession, connector_id: int):
        """Bind the adapter to a DB session and a connector row.

        Args:
            session: Async SQLAlchemy session used to load the connector
                (and passed to the token-refresh helper).
            connector_id: Primary key of the ``SearchSourceConnector`` that
                stores the MCP OAuth credentials.
        """
        self._session = session
        self._connector_id = connector_id
        # Last token returned by ``_get_valid_token``; reused by ``_fallback``.
        self._access_token: str | None = None

    @staticmethod
    def _decrypt_if_needed(token: str, is_encrypted: bool) -> str:
        """Return *token* decrypted with ``SECRET_KEY`` when flagged as encrypted.

        If ``SECRET_KEY`` is not configured the stored value is returned
        unchanged (NOTE(review): an encrypted token would then be sent as-is;
        confirm this is the intended behaviour).
        """
        if is_encrypted and config.SECRET_KEY:
            return TokenEncryption(config.SECRET_KEY).decrypt_token(token)
        return token

    @classmethod
    def _needs_refresh(cls, expires_at_str: str | None) -> bool:
        """Return True if the ISO-8601 ``expires_at`` is past or within the leeway.

        A missing value means the expiry is unknown, so no refresh is done.
        Naive timestamps are interpreted as UTC.
        """
        if not expires_at_str:
            return False
        expires_at = datetime.fromisoformat(expires_at_str)
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=UTC)
        remaining = expires_at - datetime.now(UTC)
        return remaining.total_seconds() <= cls._TOKEN_REFRESH_LEEWAY_SECONDS

    async def _get_valid_token(self) -> str:
        """Return a usable MCP access token, refreshing it if (nearly) expired.

        Loads the connector row, calls ``refresh_notion_mcp_token`` when the
        stored ``expires_at`` is within ``_TOKEN_REFRESH_LEEWAY_SECONDS`` of
        now, decrypts the token when the config marks it as encrypted, and
        caches the result on ``self._access_token``.

        Raises:
            ValueError: if the connector does not exist, is not in MCP mode,
                or has no access token (before or after a refresh).
        """
        result = await self._session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        if not cfg.get("mcp_mode"):
            raise ValueError(
                f"Connector {self._connector_id} is not an MCP connector"
            )

        access_token = cfg.get("access_token")
        if not access_token:
            raise ValueError("No access token in MCP connector config")
        is_encrypted = cfg.get("_token_encrypted", False)

        if self._needs_refresh(cfg.get("expires_at")):
            # Function-scope import, as in the original code; presumably
            # avoids an import cycle with the route module.
            from app.routes.notion_mcp_connector_route import (
                refresh_notion_mcp_token,
            )

            connector = await refresh_notion_mcp_token(self._session, connector)
            cfg = connector.config or {}
            access_token = cfg.get("access_token")
            if not access_token:
                # Never send an empty bearer token to the MCP server.
                raise ValueError(
                    "No access token in MCP connector config after refresh"
                )
            # The flag describes the *stored* token, so read it from the
            # refreshed config; keep the previous value if it is absent.
            is_encrypted = cfg.get("_token_encrypted", is_encrypted)

        access_token = self._decrypt_if_needed(access_token, is_encrypted)
        self._access_token = access_token
        return access_token

    async def _call_mcp_tool(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> str:
        """Open an MCP session to Notion, call *tool_name*, and return its text.

        A new streamable-HTTP connection and ``ClientSession`` are created for
        this call and closed when it returns. Exceptions propagate.
        """
        token = await self._get_valid_token()
        headers = {"Authorization": f"Bearer {token}"}
        async with (
            streamablehttp_client(NOTION_MCP_URL, headers=headers) as (read, write, _),
            ClientSession(read, write) as session,
        ):
            await session.initialize()
            response = await session.call_tool(tool_name, arguments=arguments)
            return extract_text_from_mcp_response(response)

    async def _call_with_fallback(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        parser,
        fallback_method: str | None = None,
        fallback_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Call an MCP tool, parse its response, and fall back on serialization bugs.

        The ``NotionHistoryConnector`` fallback is used (only when
        *fallback_method* is given) if the parsed result carries
        ``mcp_serialization_error``, or if the MCP call/parse raises an error
        recognised by ``is_mcp_serialization_error``.

        Args:
            tool_name: Notion MCP tool to invoke.
            arguments: Tool arguments.
            parser: Callable turning the raw MCP text into a result dict.
            fallback_method: ``NotionHistoryConnector`` method name to fall
                back to, or None to disable the fallback.
            fallback_kwargs: Keyword arguments for *fallback_method*.

        Returns:
            The parsed or fallback result dict, or
            ``{"status": "error", "message": ...}`` on failure.
        """
        try:
            raw_text = await self._call_mcp_tool(tool_name, arguments)
            result = parser(raw_text)
        except Exception as e:
            error_str = str(e)
            if is_mcp_serialization_error(error_str) and fallback_method:
                logger.warning(
                    "MCP tool '%s' raised serialization error, falling back: %s",
                    tool_name,
                    error_str,
                )
                return await self._run_fallback(
                    tool_name, fallback_method, fallback_kwargs
                )
            logger.error("MCP tool '%s' failed: %s", tool_name, e, exc_info=True)
            return {"status": "error", "message": f"MCP call failed: {e!s}"}

        # The fallback runs outside the ``try`` so that its own failures are
        # not reported as MCP failures and cannot re-trigger the fallback.
        if result.get("mcp_serialization_error") and fallback_method:
            logger.warning(
                "MCP tool '%s' hit serialization bug, falling back to direct API",
                tool_name,
            )
            return await self._run_fallback(
                tool_name, fallback_method, fallback_kwargs
            )
        return result

    async def _run_fallback(
        self,
        tool_name: str,
        method_name: str,
        kwargs: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Run ``_fallback`` and turn any exception into an error dict."""
        try:
            return await self._fallback(method_name, kwargs or {})
        except Exception as e:
            logger.error(
                "Direct API fallback '%s' for MCP tool '%s' failed: %s",
                method_name,
                tool_name,
                e,
                exc_info=True,
            )
            return {
                "status": "error",
                "message": f"Direct API fallback failed: {e!s}",
            }

    async def _fallback(
        self, method_name: str, kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Fall back to ``NotionHistoryConnector.<method_name>(**kwargs)``.

        Uses the already-refreshed MCP access token directly with the
        Notion SDK, bypassing the connector's config-based token loading by
        setting its private ``_credentials`` / ``_using_legacy_token``
        attributes. NOTE(review): assumes the MCP OAuth token is accepted
        by the direct Notion API; confirm.
        """
        from app.connectors.notion_history import NotionHistoryConnector

        token = self._access_token
        if not token:
            token = await self._get_valid_token()

        connector = NotionHistoryConnector(
            session=self._session,
            connector_id=self._connector_id,
        )
        connector._credentials = NotionAuthCredentialsBase(access_token=token)
        connector._using_legacy_token = True
        method = getattr(connector, method_name)
        return await method(**kwargs)

    # ------------------------------------------------------------------
    # Public API — same signatures as NotionHistoryConnector
    # ------------------------------------------------------------------

    async def create_page(
        self,
        title: str,
        content: str,
        parent_page_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a page via ``notion-create-pages``.

        Args:
            title: Page title.
            content: Page body, forwarded as the tool's ``content`` field.
            parent_page_id: Optional parent, sent as ``parent_page_url``
                (NOTE(review): confirm the tool accepts a bare page ID here).

        Returns:
            The parsed result, the ``NotionHistoryConnector.create_page``
            result on fallback, or an error dict.
        """
        page: dict[str, Any] = {"title": title, "content": content}
        if parent_page_id:
            page["parent_page_url"] = parent_page_id
        arguments: dict[str, Any] = {"pages": [page]}

        return await self._call_with_fallback(
            tool_name="notion-create-pages",
            arguments=arguments,
            parser=parse_create_page_response,
            fallback_method="create_page",
            fallback_kwargs={
                "title": title,
                "content": content,
                "parent_page_id": parent_page_id,
            },
        )

    async def update_page(
        self,
        page_id: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Replace a page's content via ``notion-update-page`` (``replace_content``).

        ``new_str`` is only sent when *content* is truthy. NOTE(review): with
        None/empty content the request has no ``new_str``; confirm the tool
        accepts that.

        Returns:
            The parsed result, the fallback result, or an error dict.
        """
        arguments: dict[str, Any] = {
            "page_id": page_id,
            "command": "replace_content",
        }
        if content:
            arguments["new_str"] = content

        return await self._call_with_fallback(
            tool_name="notion-update-page",
            arguments=arguments,
            parser=parse_update_page_response,
            fallback_method="update_page",
            fallback_kwargs={"page_id": page_id, "content": content},
        )

    async def delete_page(self, page_id: str) -> dict[str, Any]:
        """Archive a page via ``notion-update-page`` (``archived=True``).

        Returns:
            The parsed result, the fallback result, or an error dict.
        """
        arguments: dict[str, Any] = {
            "page_id": page_id,
            "command": "update_properties",
            "archived": True,
        }
        return await self._call_with_fallback(
            tool_name="notion-update-page",
            arguments=arguments,
            parser=parse_delete_page_response,
            fallback_method="delete_page",
            fallback_kwargs={"page_id": page_id},
        )

    async def fetch_page(self, page_url_or_id: str) -> dict[str, Any]:
        """Fetch page content via ``notion-fetch``.

        Unlike the write methods there is no fallback: errors from the MCP
        call or the parser propagate to the caller.
        """
        raw_text = await self._call_mcp_tool(
            "notion-fetch", {"url": page_url_or_id}
        )
        return parse_fetch_page_response(raw_text)

    async def health_check(self) -> dict[str, Any]:
        """Check MCP connectivity via ``notion-get-self``.

        Never raises: any failure is logged and returned as
        ``{"status": "error", "message": ...}``.
        """
        try:
            raw_text = await self._call_mcp_tool("notion-get-self", {})
            return parse_health_check_response(raw_text)
        except Exception as e:
            logger.warning("Notion MCP health check failed: %s", e)
            return {"status": "error", "message": str(e)}