mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-08 07:12:39 +02:00
Implements Notion MCP integration core: - OAuth 2.0 discovery (RFC 9470 + 8414), dynamic client registration, PKCE token exchange, and refresh with rotation - NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback to direct API on known serialization errors - Response parser translating MCP text responses into dicts matching NotionHistoryConnector output format - has_mcp_notion_connector() helper for connector gating
253 lines
9 KiB
Python
253 lines
9 KiB
Python
"""Notion MCP Adapter.

Connects to Notion's hosted MCP server at ``https://mcp.notion.com/mcp``
and exposes the same method signatures as ``NotionHistoryConnector``'s
write operations (``create_page``, ``update_page``, ``delete_page``), so
tool factories can switch between the two with a one-line change.

Includes an optional fallback to ``NotionHistoryConnector`` when the MCP
server returns known serialization errors (GitHub issues #215, #216).
"""
|
|
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from mcp import ClientSession
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import config
|
|
from app.db import SearchSourceConnector
|
|
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
|
|
from app.utils.oauth_security import TokenEncryption
|
|
|
|
from .response_parser import (
|
|
extract_text_from_mcp_response,
|
|
is_mcp_serialization_error,
|
|
parse_create_page_response,
|
|
parse_delete_page_response,
|
|
parse_fetch_page_response,
|
|
parse_health_check_response,
|
|
parse_update_page_response,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Endpoint of Notion's hosted MCP server, reached via the streamable-HTTP
# MCP client transport.
NOTION_MCP_URL = "https://mcp.notion.com/mcp"
|
|
|
|
|
|
class NotionMCPAdapter:
    """Routes Notion operations through the hosted MCP server.

    Drop-in replacement for ``NotionHistoryConnector`` write methods.
    Returns the same dict structure so KB sync works unchanged.

    The write methods (``create_page``, ``update_page``, ``delete_page``)
    convert MCP and fallback exceptions into
    ``{"status": "error", "message": ...}``. ``fetch_page`` lets exceptions
    propagate. ``health_check`` converts them into an error dict.
    """

    # Refresh tokens this many seconds before their recorded expiry, so a
    # token does not lapse between the expiry check and the MCP request.
    _TOKEN_EXPIRY_SKEW_SECONDS = 60

    def __init__(self, session: AsyncSession, connector_id: int):
        """Bind the adapter to a DB session and a connector row.

        Args:
            session: Async SQLAlchemy session. Used to load the connector row
                and passed to the token-refresh routine.
            connector_id: Primary key of the MCP-mode Notion connector.
        """
        self._session = session
        self._connector_id = connector_id
        # Last token returned by ``_get_valid_token``; reused by ``_fallback``.
        self._access_token: str | None = None

    async def _get_valid_token(self) -> str:
        """Return a usable (decrypted) MCP access token, refreshing it if expired.

        Raises:
            ValueError: If the connector is missing or not in MCP mode, if it
                has no access token (before or after refresh), or if the token
                is encrypted while ``SECRET_KEY`` is not configured.
        """
        result = await self._session.execute(
            select(SearchSourceConnector).where(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        if not cfg.get("mcp_mode"):
            raise ValueError(
                f"Connector {self._connector_id} is not an MCP connector"
            )
        if not cfg.get("access_token"):
            raise ValueError("No access token in MCP connector config")

        is_encrypted = bool(cfg.get("_token_encrypted", False))

        # Check expiry before decrypting, so a stale token is never decrypted
        # only to be discarded.
        if self._is_token_expired(cfg.get("expires_at")):
            # Function-level import. NOTE(review): presumably avoids an import
            # cycle with the routes module; confirm.
            from app.routes.notion_mcp_connector_route import (
                refresh_notion_mcp_token,
            )

            connector = await refresh_notion_mcp_token(self._session, connector)
            cfg = connector.config or {}
            # Use the encryption flag written by the refresh. Keep the previous
            # flag if the refreshed config does not carry one.
            is_encrypted = bool(cfg.get("_token_encrypted", is_encrypted))
            if not cfg.get("access_token"):
                raise ValueError(
                    "No access token in MCP connector config after refresh"
                )

        access_token = self._decrypt_token(cfg["access_token"], is_encrypted)
        self._access_token = access_token
        return access_token

    @classmethod
    def _is_token_expired(cls, expires_at_str: str | None) -> bool:
        """Return True if ``expires_at`` is past or within the refresh skew window.

        ``expires_at_str`` is parsed with ``datetime.fromisoformat``. A missing
        or empty value means "no recorded expiry", so no refresh happens.
        Naive timestamps are interpreted as UTC.
        """
        if not expires_at_str:
            return False
        expires_at = datetime.fromisoformat(expires_at_str)
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=UTC)
        remaining = (expires_at - datetime.now(UTC)).total_seconds()
        return remaining <= cls._TOKEN_EXPIRY_SKEW_SECONDS

    @staticmethod
    def _decrypt_token(token: str, is_encrypted: bool) -> str:
        """Return ``token``, decrypted with ``SECRET_KEY`` when flagged as encrypted.

        Raises:
            ValueError: If the token is encrypted but ``SECRET_KEY`` is unset.
                The ciphertext is not a usable bearer token.
        """
        if not is_encrypted:
            return token
        if not config.SECRET_KEY:
            raise ValueError(
                "MCP access token is encrypted but SECRET_KEY is not configured"
            )
        return TokenEncryption(config.SECRET_KEY).decrypt_token(token)

    async def _call_mcp_tool(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> str:
        """Connect to the Notion MCP server, call ``tool_name``, and return raw text.

        Each call opens a new streamable-HTTP connection and ``ClientSession``
        and closes both on exit. Token, transport and tool exceptions
        propagate to the caller.
        """
        token = await self._get_valid_token()
        headers = {"Authorization": f"Bearer {token}"}

        async with (
            streamablehttp_client(NOTION_MCP_URL, headers=headers) as (read, write, _),
            ClientSession(read, write) as session,
        ):
            await session.initialize()
            response = await session.call_tool(tool_name, arguments=arguments)
            return extract_text_from_mcp_response(response)

    @staticmethod
    def _exception_text(exc: BaseException) -> str:
        """Return ``str(exc)`` joined with the messages of any nested sub-exceptions.

        For an ordinary exception this is exactly ``str(exc)``.

        NOTE(review): the MCP client's async transport may surface errors
        wrapped in an exception group whose own ``str()`` omits the underlying
        message. Flattening keeps serialization-error detection working in
        that case; confirm against the installed ``mcp`` SDK.
        """
        parts = [str(exc)]
        if isinstance(exc, BaseExceptionGroup):
            parts.extend(
                NotionMCPAdapter._exception_text(sub) for sub in exc.exceptions
            )
        return " | ".join(part for part in parts if part)

    async def _call_with_fallback(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        parser,
        fallback_method: str | None = None,
        fallback_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Call an MCP tool, parse its response, and fall back on serialization errors.

        Args:
            tool_name: MCP tool to invoke.
            arguments: Tool arguments.
            parser: Callable mapping the tool's raw text to a result dict.
            fallback_method: Name of the ``NotionHistoryConnector`` method to
                use when a known serialization error is detected. If None,
                no fallback is attempted.
            fallback_kwargs: Keyword arguments for ``fallback_method``.

        Returns:
            The parsed result or the fallback's result. Returns
            ``{"status": "error", "message": ...}`` if the MCP call fails for
            another reason or the fallback itself raises.
        """
        # Keep the try minimal: only the MCP round-trip and parsing. The
        # fallback must not run inside it, or its failures would be
        # misreported as MCP failures (and could trigger a second fallback).
        try:
            raw_text = await self._call_mcp_tool(tool_name, arguments)
            result = parser(raw_text)
        except Exception as e:
            error_text = self._exception_text(e)
            if is_mcp_serialization_error(error_text) and fallback_method:
                logger.warning(
                    "MCP tool '%s' raised serialization error, falling back: %s",
                    tool_name,
                    error_text,
                )
                return await self._run_fallback(
                    tool_name, fallback_method, fallback_kwargs
                )

            logger.error("MCP tool '%s' failed: %s", tool_name, e, exc_info=True)
            return {"status": "error", "message": f"MCP call failed: {e!s}"}

        if result.get("mcp_serialization_error") and fallback_method:
            logger.warning(
                "MCP tool '%s' hit serialization bug, falling back to direct API",
                tool_name,
            )
            return await self._run_fallback(
                tool_name, fallback_method, fallback_kwargs
            )

        return result

    async def _run_fallback(
        self,
        tool_name: str,
        method_name: str,
        kwargs: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Run ``_fallback`` and convert any exception it raises into an error dict."""
        try:
            return await self._fallback(method_name, kwargs or {})
        except Exception as e:
            logger.error(
                "Direct API fallback '%s' for MCP tool '%s' failed: %s",
                method_name,
                tool_name,
                e,
                exc_info=True,
            )
            return {
                "status": "error",
                "message": f"Fallback to direct API failed: {e!s}",
            }

    async def _fallback(
        self, method_name: str, kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Fall back to ``NotionHistoryConnector.<method_name>(**kwargs)``.

        Uses the MCP access token cached by ``_get_valid_token``, fetching one
        if none is cached, and hands it directly to the connector. This
        bypasses the connector's config-based token loading.

        NOTE(review): assumes the token issued for the MCP server is also
        accepted by Notion's direct API; confirm.
        """
        # Function-level import. NOTE(review): presumably avoids an import
        # cycle; confirm.
        from app.connectors.notion_history import NotionHistoryConnector

        token = self._access_token or await self._get_valid_token()

        connector = NotionHistoryConnector(
            session=self._session,
            connector_id=self._connector_id,
        )
        # Inject the credentials so the connector does not load them from
        # config. NOTE(review): ``_using_legacy_token`` presumably marks the
        # token as static (no OAuth refresh); confirm in NotionHistoryConnector.
        connector._credentials = NotionAuthCredentialsBase(access_token=token)
        connector._using_legacy_token = True

        method = getattr(connector, method_name)
        return await method(**kwargs)

    # ------------------------------------------------------------------
    # Public API — same signatures as NotionHistoryConnector
    # ------------------------------------------------------------------

    async def create_page(
        self,
        title: str,
        content: str,
        parent_page_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a page via the ``notion-create-pages`` MCP tool.

        On a known serialization error, falls back to
        ``NotionHistoryConnector.create_page`` with the same arguments.

        Returns:
            The dict from ``parse_create_page_response`` or from the fallback,
            or ``{"status": "error", "message": ...}`` on failure.
        """
        page: dict[str, Any] = {"title": title, "content": content}
        if parent_page_id:
            # NOTE(review): an ID is sent under the ``parent_page_url`` key;
            # confirm the tool accepts bare page IDs here.
            page["parent_page_url"] = parent_page_id
        arguments: dict[str, Any] = {"pages": [page]}

        return await self._call_with_fallback(
            tool_name="notion-create-pages",
            arguments=arguments,
            parser=parse_create_page_response,
            fallback_method="create_page",
            fallback_kwargs={
                "title": title,
                "content": content,
                "parent_page_id": parent_page_id,
            },
        )

    async def update_page(
        self,
        page_id: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Replace a page's content via ``notion-update-page`` (``replace_content``).

        ``new_str`` is only sent when ``content`` is non-empty.
        NOTE(review): with ``content`` None or empty, the command is sent
        without ``new_str``; confirm the tool's behavior in that case.
        Falls back to ``NotionHistoryConnector.update_page`` on a known
        serialization error.
        """
        arguments: dict[str, Any] = {
            "page_id": page_id,
            "command": "replace_content",
        }
        if content:
            arguments["new_str"] = content

        return await self._call_with_fallback(
            tool_name="notion-update-page",
            arguments=arguments,
            parser=parse_update_page_response,
            fallback_method="update_page",
            fallback_kwargs={"page_id": page_id, "content": content},
        )

    async def delete_page(self, page_id: str) -> dict[str, Any]:
        """Archive a page via ``notion-update-page`` (``update_properties``, ``archived``).

        NOTE(review): confirm that the tool honors a top-level ``archived``
        flag for ``update_properties``.
        Falls back to ``NotionHistoryConnector.delete_page`` on a known
        serialization error.
        """
        arguments: dict[str, Any] = {
            "page_id": page_id,
            "command": "update_properties",
            "archived": True,
        }

        return await self._call_with_fallback(
            tool_name="notion-update-page",
            arguments=arguments,
            parser=parse_delete_page_response,
            fallback_method="delete_page",
            fallback_kwargs={"page_id": page_id},
        )

    async def fetch_page(self, page_url_or_id: str) -> dict[str, Any]:
        """Fetch page content via ``notion-fetch``.

        Unlike the write methods, there is no fallback and exceptions
        propagate to the caller.
        """
        raw_text = await self._call_mcp_tool(
            "notion-fetch", {"url": page_url_or_id}
        )
        return parse_fetch_page_response(raw_text)

    async def health_check(self) -> dict[str, Any]:
        """Check the MCP connection via ``notion-get-self``.

        Returns:
            The dict from ``parse_health_check_response``, or
            ``{"status": "error", "message": str(e)}`` if anything fails.
        """
        try:
            raw_text = await self._call_mcp_tool("notion-get-self", {})
            return parse_health_check_response(raw_text)
        except Exception as e:
            logger.warning("Notion MCP health check failed: %s", e)
            return {"status": "error", "message": str(e)}