mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 21:32:39 +02:00
feat(notion-mcp): add OAuth + PKCE service layer and MCP adapter
Implements Notion MCP integration core: - OAuth 2.0 discovery (RFC 9470 + 8414), dynamic client registration, PKCE token exchange, and refresh with rotation - NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback to direct API on known serialization errors - Response parser translating MCP text responses into dicts matching NotionHistoryConnector output format - has_mcp_notion_connector() helper for connector gating
This commit is contained in:
parent
2b2453e015
commit
d6e605fd50
4 changed files with 790 additions and 0 deletions
27
surfsense_backend/app/services/notion_mcp/__init__.py
Normal file
27
surfsense_backend/app/services/notion_mcp/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Notion MCP integration.
|
||||
|
||||
Routes Notion operations through Notion's hosted MCP server
|
||||
at https://mcp.notion.com/mcp instead of direct API calls.
|
||||
"""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
|
||||
async def has_mcp_notion_connector(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> bool:
|
||||
"""Check whether the search space has at least one MCP-mode Notion connector."""
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector.id, SearchSourceConnector.config).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR,
|
||||
)
|
||||
)
|
||||
for _, config in result.all():
|
||||
if isinstance(config, dict) and config.get("mcp_mode"):
|
||||
return True
|
||||
return False
|
||||
253
surfsense_backend/app/services/notion_mcp/adapter.py
Normal file
253
surfsense_backend/app/services/notion_mcp/adapter.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""Notion MCP Adapter.
|
||||
|
||||
Connects to Notion's hosted MCP server at ``https://mcp.notion.com/mcp``
|
||||
and exposes the same method signatures as ``NotionHistoryConnector``'s
|
||||
write operations so that tool factories can swap with a one-line change.
|
||||
|
||||
Includes an optional fallback to ``NotionHistoryConnector`` when the MCP
|
||||
server returns known serialization errors (GitHub issues #215, #216).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector
|
||||
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
from .response_parser import (
|
||||
extract_text_from_mcp_response,
|
||||
is_mcp_serialization_error,
|
||||
parse_create_page_response,
|
||||
parse_delete_page_response,
|
||||
parse_fetch_page_response,
|
||||
parse_health_check_response,
|
||||
parse_update_page_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOTION_MCP_URL = "https://mcp.notion.com/mcp"
|
||||
|
||||
|
||||
class NotionMCPAdapter:
|
||||
"""Routes Notion operations through the hosted MCP server.
|
||||
|
||||
Drop-in replacement for ``NotionHistoryConnector`` write methods.
|
||||
Returns the same dict structure so KB sync works unchanged.
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession, connector_id: int):
|
||||
self._session = session
|
||||
self._connector_id = connector_id
|
||||
self._access_token: str | None = None
|
||||
|
||||
async def _get_valid_token(self) -> str:
|
||||
"""Get a valid MCP access token, refreshing if expired."""
|
||||
result = await self._session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == self._connector_id
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
raise ValueError(f"Connector {self._connector_id} not found")
|
||||
|
||||
cfg = connector.config or {}
|
||||
|
||||
if not cfg.get("mcp_mode"):
|
||||
raise ValueError(
|
||||
f"Connector {self._connector_id} is not an MCP connector"
|
||||
)
|
||||
|
||||
access_token = cfg.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError("No access token in MCP connector config")
|
||||
|
||||
is_encrypted = cfg.get("_token_encrypted", False)
|
||||
if is_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
access_token = token_encryption.decrypt_token(access_token)
|
||||
|
||||
expires_at_str = cfg.get("expires_at")
|
||||
if expires_at_str:
|
||||
expires_at = datetime.fromisoformat(expires_at_str)
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
if expires_at <= datetime.now(UTC):
|
||||
from app.routes.notion_mcp_connector_route import refresh_notion_mcp_token
|
||||
|
||||
connector = await refresh_notion_mcp_token(self._session, connector)
|
||||
cfg = connector.config or {}
|
||||
access_token = cfg.get("access_token", "")
|
||||
if is_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
access_token = token_encryption.decrypt_token(access_token)
|
||||
|
||||
self._access_token = access_token
|
||||
return access_token
|
||||
|
||||
async def _call_mcp_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> str:
|
||||
"""Connect to Notion MCP server and call a tool. Returns raw text."""
|
||||
token = await self._get_valid_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with (
|
||||
streamablehttp_client(NOTION_MCP_URL, headers=headers) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(tool_name, arguments=arguments)
|
||||
return extract_text_from_mcp_response(response)
|
||||
|
||||
async def _call_with_fallback(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
parser,
|
||||
fallback_method: str | None = None,
|
||||
fallback_kwargs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Call MCP tool, parse response, and fall back on serialization errors."""
|
||||
try:
|
||||
raw_text = await self._call_mcp_tool(tool_name, arguments)
|
||||
result = parser(raw_text)
|
||||
|
||||
if result.get("mcp_serialization_error") and fallback_method:
|
||||
logger.warning(
|
||||
"MCP tool '%s' hit serialization bug, falling back to direct API",
|
||||
tool_name,
|
||||
)
|
||||
return await self._fallback(fallback_method, fallback_kwargs or {})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
if is_mcp_serialization_error(error_str) and fallback_method:
|
||||
logger.warning(
|
||||
"MCP tool '%s' raised serialization error, falling back: %s",
|
||||
tool_name,
|
||||
error_str,
|
||||
)
|
||||
return await self._fallback(fallback_method, fallback_kwargs or {})
|
||||
|
||||
logger.error("MCP tool '%s' failed: %s", tool_name, e, exc_info=True)
|
||||
return {"status": "error", "message": f"MCP call failed: {e!s}"}
|
||||
|
||||
async def _fallback(
|
||||
self, method_name: str, kwargs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Fall back to NotionHistoryConnector for the given method.
|
||||
|
||||
Uses the already-refreshed MCP access token directly with the
|
||||
Notion SDK, bypassing the connector's config-based token loading.
|
||||
"""
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
|
||||
|
||||
token = self._access_token
|
||||
if not token:
|
||||
token = await self._get_valid_token()
|
||||
|
||||
connector = NotionHistoryConnector(
|
||||
session=self._session,
|
||||
connector_id=self._connector_id,
|
||||
)
|
||||
connector._credentials = NotionAuthCredentialsBase(access_token=token)
|
||||
connector._using_legacy_token = True
|
||||
|
||||
method = getattr(connector, method_name)
|
||||
return await method(**kwargs)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — same signatures as NotionHistoryConnector
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def create_page(
|
||||
self,
|
||||
title: str,
|
||||
content: str,
|
||||
parent_page_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
arguments: dict[str, Any] = {
|
||||
"pages": [
|
||||
{
|
||||
"title": title,
|
||||
"content": content,
|
||||
}
|
||||
]
|
||||
}
|
||||
if parent_page_id:
|
||||
arguments["pages"][0]["parent_page_url"] = parent_page_id
|
||||
|
||||
return await self._call_with_fallback(
|
||||
tool_name="notion-create-pages",
|
||||
arguments=arguments,
|
||||
parser=parse_create_page_response,
|
||||
fallback_method="create_page",
|
||||
fallback_kwargs={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"parent_page_id": parent_page_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def update_page(
|
||||
self,
|
||||
page_id: str,
|
||||
content: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
arguments: dict[str, Any] = {
|
||||
"page_id": page_id,
|
||||
"command": "replace_content",
|
||||
}
|
||||
if content:
|
||||
arguments["new_str"] = content
|
||||
|
||||
return await self._call_with_fallback(
|
||||
tool_name="notion-update-page",
|
||||
arguments=arguments,
|
||||
parser=parse_update_page_response,
|
||||
fallback_method="update_page",
|
||||
fallback_kwargs={"page_id": page_id, "content": content},
|
||||
)
|
||||
|
||||
async def delete_page(self, page_id: str) -> dict[str, Any]:
|
||||
arguments: dict[str, Any] = {
|
||||
"page_id": page_id,
|
||||
"command": "update_properties",
|
||||
"archived": True,
|
||||
}
|
||||
|
||||
return await self._call_with_fallback(
|
||||
tool_name="notion-update-page",
|
||||
arguments=arguments,
|
||||
parser=parse_delete_page_response,
|
||||
fallback_method="delete_page",
|
||||
fallback_kwargs={"page_id": page_id},
|
||||
)
|
||||
|
||||
async def fetch_page(self, page_url_or_id: str) -> dict[str, Any]:
|
||||
"""Fetch page content via ``notion-fetch``."""
|
||||
raw_text = await self._call_mcp_tool(
|
||||
"notion-fetch", {"url": page_url_or_id}
|
||||
)
|
||||
return parse_fetch_page_response(raw_text)
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""Check MCP connection via ``notion-get-self``."""
|
||||
try:
|
||||
raw_text = await self._call_mcp_tool("notion-get-self", {})
|
||||
return parse_health_check_response(raw_text)
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
298
surfsense_backend/app/services/notion_mcp/oauth.py
Normal file
298
surfsense_backend/app/services/notion_mcp/oauth.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
"""OAuth 2.0 + PKCE utilities for Notion's remote MCP server.
|
||||
|
||||
Implements the flow described in the official guide:
|
||||
https://developers.notion.com/guides/mcp/build-mcp-client
|
||||
|
||||
Steps:
|
||||
1. Discover OAuth metadata (RFC 9470 → RFC 8414)
|
||||
2. Dynamic client registration (RFC 7591)
|
||||
3. Build authorization URL with PKCE code_challenge
|
||||
4. Exchange authorization code + code_verifier for tokens
|
||||
5. Refresh access tokens (with refresh-token rotation)
|
||||
|
||||
All functions are stateless — callers (route handlers) manage storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOTION_MCP_SERVER_URL = "https://mcp.notion.com/mcp"
|
||||
_HTTP_TIMEOUT = 30.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OAuthMetadata:
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
registration_endpoint: str | None
|
||||
code_challenge_methods_supported: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClientCredentials:
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
client_id_issued_at: int | None = None
|
||||
client_secret_expires_at: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenSet:
|
||||
access_token: str
|
||||
refresh_token: str | None
|
||||
token_type: str
|
||||
expires_in: int | None
|
||||
expires_at: datetime | None
|
||||
scope: str | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 1 — OAuth discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def discover_oauth_metadata(
|
||||
mcp_server_url: str = NOTION_MCP_SERVER_URL,
|
||||
) -> OAuthMetadata:
|
||||
"""Discover OAuth endpoints via RFC 9470 + RFC 8414.
|
||||
|
||||
1. Fetch protected-resource metadata to find the authorization server.
|
||||
2. Fetch authorization-server metadata to get OAuth endpoints.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(mcp_server_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
# RFC 9470 — Protected Resource Metadata
|
||||
# URL format: {origin}/.well-known/oauth-protected-resource{path}
|
||||
pr_url = f"{origin}/.well-known/oauth-protected-resource{path}"
|
||||
pr_resp = await client.get(pr_url)
|
||||
pr_resp.raise_for_status()
|
||||
pr_data = pr_resp.json()
|
||||
|
||||
auth_servers = pr_data.get("authorization_servers", [])
|
||||
if not auth_servers:
|
||||
raise ValueError("No authorization_servers in protected resource metadata")
|
||||
auth_server_url = auth_servers[0]
|
||||
|
||||
# RFC 8414 — Authorization Server Metadata
|
||||
as_url = f"{auth_server_url}/.well-known/oauth-authorization-server"
|
||||
as_resp = await client.get(as_url)
|
||||
as_resp.raise_for_status()
|
||||
as_data = as_resp.json()
|
||||
|
||||
if not as_data.get("authorization_endpoint") or not as_data.get("token_endpoint"):
|
||||
raise ValueError("Missing required OAuth endpoints in server metadata")
|
||||
|
||||
return OAuthMetadata(
|
||||
issuer=as_data.get("issuer", auth_server_url),
|
||||
authorization_endpoint=as_data["authorization_endpoint"],
|
||||
token_endpoint=as_data["token_endpoint"],
|
||||
registration_endpoint=as_data.get("registration_endpoint"),
|
||||
code_challenge_methods_supported=as_data.get(
|
||||
"code_challenge_methods_supported", []
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 — Dynamic client registration (RFC 7591)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def register_client(
|
||||
metadata: OAuthMetadata,
|
||||
redirect_uri: str,
|
||||
client_name: str = "SurfSense",
|
||||
) -> ClientCredentials:
|
||||
"""Dynamically register an OAuth client with the Notion MCP server."""
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Server does not support dynamic client registration")
|
||||
|
||||
payload = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
resp = await client.post(
|
||||
metadata.registration_endpoint,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
)
|
||||
if not resp.is_success:
|
||||
logger.error(
|
||||
"Dynamic client registration failed (%s): %s",
|
||||
resp.status_code,
|
||||
resp.text,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
return ClientCredentials(
|
||||
client_id=data["client_id"],
|
||||
client_secret=data.get("client_secret"),
|
||||
client_id_issued_at=data.get("client_id_issued_at"),
|
||||
client_secret_expires_at=data.get("client_secret_expires_at"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 3 — Build authorization URL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_authorization_url(
|
||||
metadata: OAuthMetadata,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_challenge: str,
|
||||
state: str,
|
||||
) -> str:
|
||||
"""Build the OAuth authorization URL with PKCE parameters."""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
"prompt": "consent",
|
||||
}
|
||||
return f"{metadata.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 4 — Exchange authorization code for tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
code: str,
|
||||
code_verifier: str,
|
||||
metadata: OAuthMetadata,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
client_secret: str | None = None,
|
||||
) -> TokenSet:
|
||||
"""Exchange an authorization code + PKCE verifier for tokens."""
|
||||
form_data: dict[str, Any] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
if client_secret:
|
||||
form_data["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
resp = await client.post(
|
||||
metadata.token_endpoint,
|
||||
data=form_data,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
if not resp.is_success:
|
||||
body = resp.text
|
||||
raise ValueError(f"Token exchange failed ({resp.status_code}): {body}")
|
||||
tokens = resp.json()
|
||||
|
||||
if not tokens.get("access_token"):
|
||||
raise ValueError("No access_token in token response")
|
||||
|
||||
expires_at = None
|
||||
if tokens.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))
|
||||
|
||||
return TokenSet(
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
token_type=tokens.get("token_type", "Bearer"),
|
||||
expires_in=tokens.get("expires_in"),
|
||||
expires_at=expires_at,
|
||||
scope=tokens.get("scope"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 5 — Refresh access token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def refresh_access_token(
|
||||
refresh_token: str,
|
||||
metadata: OAuthMetadata,
|
||||
client_id: str,
|
||||
client_secret: str | None = None,
|
||||
) -> TokenSet:
|
||||
"""Refresh an access token.
|
||||
|
||||
Notion MCP uses refresh-token rotation: each refresh returns a new
|
||||
refresh_token and invalidates the old one. Callers MUST persist the
|
||||
new refresh_token atomically with the new access_token.
|
||||
"""
|
||||
form_data: dict[str, Any] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}
|
||||
if client_secret:
|
||||
form_data["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
resp = await client.post(
|
||||
metadata.token_endpoint,
|
||||
data=form_data,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if not resp.is_success:
|
||||
body = resp.text
|
||||
try:
|
||||
error_data = resp.json()
|
||||
error_code = error_data.get("error", "")
|
||||
if error_code == "invalid_grant":
|
||||
raise ValueError("REAUTH_REQUIRED")
|
||||
except ValueError:
|
||||
if "REAUTH_REQUIRED" in str(resp.text) or resp.status_code == 401:
|
||||
raise
|
||||
raise ValueError(f"Token refresh failed ({resp.status_code}): {body}")
|
||||
|
||||
tokens = resp.json()
|
||||
|
||||
if not tokens.get("access_token"):
|
||||
raise ValueError("No access_token in refresh response")
|
||||
|
||||
expires_at = None
|
||||
if tokens.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))
|
||||
|
||||
return TokenSet(
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
token_type=tokens.get("token_type", "Bearer"),
|
||||
expires_in=tokens.get("expires_in"),
|
||||
expires_at=expires_at,
|
||||
scope=tokens.get("scope"),
|
||||
)
|
||||
212
surfsense_backend/app/services/notion_mcp/response_parser.py
Normal file
212
surfsense_backend/app/services/notion_mcp/response_parser.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
"""Parse Notion MCP tool responses into structured dicts.
|
||||
|
||||
The Notion MCP server returns responses as MCP TextContent where the
|
||||
``text`` field contains JSON-stringified Notion API response data.
|
||||
See: https://deepwiki.com/makenotion/notion-mcp-server/4.3-request-and-response-handling
|
||||
|
||||
This module extracts that JSON and normalises it into the same dict
|
||||
format that ``NotionHistoryConnector`` methods return, so downstream
|
||||
code (KB sync, tool factories) works unchanged.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MCP_SERIALIZATION_ERROR_MARKERS = [
|
||||
"Expected array, received string",
|
||||
"Expected object, received string",
|
||||
"should be defined, instead was `undefined`",
|
||||
]
|
||||
|
||||
|
||||
def is_mcp_serialization_error(text: str) -> bool:
|
||||
"""Return True if the MCP error text matches a known serialization bug."""
|
||||
return any(marker in text for marker in MCP_SERIALIZATION_ERROR_MARKERS)
|
||||
|
||||
|
||||
def extract_text_from_mcp_response(response) -> str:
|
||||
"""Pull the concatenated text out of an MCP ``CallToolResult``.
|
||||
|
||||
Args:
|
||||
response: The ``CallToolResult`` returned by ``session.call_tool()``.
|
||||
|
||||
Returns:
|
||||
Concatenated text content from the response.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
parts.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
parts.append(str(content.data))
|
||||
else:
|
||||
parts.append(str(content))
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _try_parse_json(text: str) -> dict[str, Any] | None:
|
||||
"""Attempt to parse *text* as JSON, returning None on failure."""
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _extract_page_title(page_data: dict[str, Any]) -> str:
|
||||
"""Best-effort extraction of the page title from a Notion page object."""
|
||||
props = page_data.get("properties", {})
|
||||
for prop in props.values():
|
||||
if prop.get("type") == "title":
|
||||
title_parts = prop.get("title", [])
|
||||
if title_parts:
|
||||
return " ".join(t.get("plain_text", "") for t in title_parts)
|
||||
return page_data.get("id", "Untitled")
|
||||
|
||||
|
||||
def parse_create_page_response(raw_text: str) -> dict[str, Any]:
|
||||
"""Parse a ``notion-create-pages`` MCP response.
|
||||
|
||||
Returns a dict compatible with ``NotionHistoryConnector.create_page()``:
|
||||
``{status, page_id, url, title, message}``
|
||||
"""
|
||||
data = _try_parse_json(raw_text)
|
||||
|
||||
if data is None:
|
||||
if is_mcp_serialization_error(raw_text):
|
||||
return {
|
||||
"status": "mcp_error",
|
||||
"message": raw_text,
|
||||
"mcp_serialization_error": True,
|
||||
}
|
||||
return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
|
||||
|
||||
if data.get("status") == "error" or "error" in data:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": data.get("message", data.get("error", str(data))),
|
||||
}
|
||||
|
||||
page_id = data.get("id", "")
|
||||
url = data.get("url", "")
|
||||
title = _extract_page_title(data)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"url": url,
|
||||
"title": title,
|
||||
"message": f"Created Notion page '{title}'",
|
||||
}
|
||||
|
||||
|
||||
def parse_update_page_response(raw_text: str) -> dict[str, Any]:
|
||||
"""Parse a ``notion-update-page`` MCP response.
|
||||
|
||||
Returns a dict compatible with ``NotionHistoryConnector.update_page()``:
|
||||
``{status, page_id, url, title, message}``
|
||||
"""
|
||||
data = _try_parse_json(raw_text)
|
||||
|
||||
if data is None:
|
||||
if is_mcp_serialization_error(raw_text):
|
||||
return {
|
||||
"status": "mcp_error",
|
||||
"message": raw_text,
|
||||
"mcp_serialization_error": True,
|
||||
}
|
||||
return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
|
||||
|
||||
if data.get("status") == "error" or "error" in data:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": data.get("message", data.get("error", str(data))),
|
||||
}
|
||||
|
||||
page_id = data.get("id", "")
|
||||
url = data.get("url", "")
|
||||
title = _extract_page_title(data)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"url": url,
|
||||
"title": title,
|
||||
"message": f"Updated Notion page '{title}' (content appended)",
|
||||
}
|
||||
|
||||
|
||||
def parse_delete_page_response(raw_text: str) -> dict[str, Any]:
|
||||
"""Parse an archive (delete) MCP response.
|
||||
|
||||
The Notion API responds to ``pages.update(archived=True)`` with
|
||||
the archived page object.
|
||||
|
||||
Returns a dict compatible with ``NotionHistoryConnector.delete_page()``:
|
||||
``{status, page_id, message}``
|
||||
"""
|
||||
data = _try_parse_json(raw_text)
|
||||
|
||||
if data is None:
|
||||
if is_mcp_serialization_error(raw_text):
|
||||
return {
|
||||
"status": "mcp_error",
|
||||
"message": raw_text,
|
||||
"mcp_serialization_error": True,
|
||||
}
|
||||
return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
|
||||
|
||||
if data.get("status") == "error" or "error" in data:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": data.get("message", data.get("error", str(data))),
|
||||
}
|
||||
|
||||
page_id = data.get("id", "")
|
||||
title = _extract_page_title(data)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"page_id": page_id,
|
||||
"message": f"Deleted Notion page '{title}'",
|
||||
}
|
||||
|
||||
|
||||
def parse_fetch_page_response(raw_text: str) -> dict[str, Any]:
|
||||
"""Parse a ``notion-fetch`` MCP response.
|
||||
|
||||
Returns the raw parsed dict (Notion page/block data) or an error dict.
|
||||
"""
|
||||
data = _try_parse_json(raw_text)
|
||||
|
||||
if data is None:
|
||||
return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
|
||||
|
||||
if data.get("status") == "error" or "error" in data:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": data.get("message", data.get("error", str(data))),
|
||||
}
|
||||
|
||||
return {"status": "success", "data": data}
|
||||
|
||||
|
||||
def parse_health_check_response(raw_text: str) -> dict[str, Any]:
|
||||
"""Parse a ``notion-get-self`` MCP response for health checking."""
|
||||
data = _try_parse_json(raw_text)
|
||||
|
||||
if data is None:
|
||||
return {"status": "error", "message": raw_text[:500]}
|
||||
|
||||
if data.get("status") == "error" or "error" in data:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": data.get("message", data.get("error", str(data))),
|
||||
}
|
||||
|
||||
return {"status": "success", "data": data}
|
||||
Loading…
Add table
Add a link
Reference in a new issue