feat(notion-mcp): add OAuth + PKCE service layer and MCP adapter

Implements Notion MCP integration core:
- OAuth 2.0 discovery (RFC 9728 + RFC 8414), dynamic client registration,
  PKCE token exchange, and refresh with rotation
- NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback
  to direct API on known serialization errors
- Response parser translating MCP text responses into dicts matching
  NotionHistoryConnector output format
- has_mcp_notion_connector() helper for connector gating
This commit is contained in:
CREDO23 2026-04-20 20:59:17 +02:00
parent 2b2453e015
commit d6e605fd50
4 changed files with 790 additions and 0 deletions

View file

@ -0,0 +1,27 @@
"""Notion MCP integration.
Routes Notion operations through Notion's hosted MCP server
at https://mcp.notion.com/mcp instead of direct API calls.
"""
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSourceConnector, SearchSourceConnectorType
async def has_mcp_notion_connector(
    session: AsyncSession,
    search_space_id: int,
) -> bool:
    """Report whether a search space owns an MCP-mode Notion connector.

    Loads the ``id`` and ``config`` of every ``NOTION_CONNECTOR`` row in the
    given search space and returns True as soon as one of them has a dict
    config with a truthy ``mcp_mode`` entry.

    Args:
        session: Async database session used to run the query.
        search_space_id: Search space whose connectors are inspected.

    Returns:
        True if at least one matching connector is in MCP mode, else False.
    """
    notion_connectors = select(
        SearchSourceConnector.id, SearchSourceConnector.config
    ).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.NOTION_CONNECTOR,
    )
    rows = (await session.execute(notion_connectors)).all()
    # Configs that are not dicts (e.g. NULL) can never be MCP-mode.
    return any(
        isinstance(connector_cfg, dict) and bool(connector_cfg.get("mcp_mode"))
        for _, connector_cfg in rows
    )

View file

@ -0,0 +1,253 @@
"""Notion MCP Adapter.
Connects to Notion's hosted MCP server at ``https://mcp.notion.com/mcp``
and exposes the same method signatures as ``NotionHistoryConnector``'s
write operations so that tool factories can swap with a one-line change.
Includes an optional fallback to ``NotionHistoryConnector`` when the MCP
server returns known serialization errors (GitHub issues #215, #216).
"""
import logging
from datetime import UTC, datetime
from typing import Any
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import SearchSourceConnector
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption
from .response_parser import (
extract_text_from_mcp_response,
is_mcp_serialization_error,
parse_create_page_response,
parse_delete_page_response,
parse_fetch_page_response,
parse_health_check_response,
parse_update_page_response,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Endpoint of Notion's hosted MCP server; opened with streamablehttp_client.
NOTION_MCP_URL = "https://mcp.notion.com/mcp"
class NotionMCPAdapter:
    """Route Notion operations through Notion's hosted MCP server.

    Exposes ``create_page`` / ``update_page`` / ``delete_page`` with the same
    signatures as ``NotionHistoryConnector`` and returns dicts shaped by the
    ``response_parser`` helpers. When the MCP server reports one of the known
    serialization errors, write operations fall back to
    ``NotionHistoryConnector`` using the MCP access token.

    Every public method returns a dict with a ``status`` key; failures are
    reported as ``{"status": "error", "message": ...}`` rather than raised.
    """

    def __init__(self, session: AsyncSession, connector_id: int):
        """Store the DB session and the id of the connector to act for.

        Args:
            session: Async database session used to load the connector row.
            connector_id: Primary key of the ``SearchSourceConnector``.
        """
        self._session = session
        self._connector_id = connector_id
        # Plaintext token last resolved by _get_valid_token(); reused by
        # _fallback() so it does not have to hit the database again.
        self._access_token: str | None = None

    @staticmethod
    def _decrypt_if_needed(token: str, cfg: dict[str, Any]) -> str:
        """Decrypt *token* when *cfg* marks it encrypted and a key is set.

        The token is returned unchanged when ``_token_encrypted`` is falsy or
        when ``config.SECRET_KEY`` is not configured.
        """
        if cfg.get("_token_encrypted", False) and config.SECRET_KEY:
            return TokenEncryption(config.SECRET_KEY).decrypt_token(token)
        return token

    async def _get_valid_token(self) -> str:
        """Return a usable MCP access token, refreshing it if it has expired.

        Returns:
            The plaintext access token (also cached on ``self._access_token``).

        Raises:
            ValueError: If the connector does not exist, is not in MCP mode,
                or has no access token (before or after a refresh).
        """
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        if not cfg.get("mcp_mode"):
            raise ValueError(
                f"Connector {self._connector_id} is not an MCP connector"
            )

        access_token = cfg.get("access_token")
        if not access_token:
            raise ValueError("No access token in MCP connector config")
        access_token = self._decrypt_if_needed(access_token, cfg)

        expires_at_str = cfg.get("expires_at")
        if expires_at_str:
            expires_at = datetime.fromisoformat(expires_at_str)
            # Naive timestamps are interpreted as UTC.
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=UTC)
            if expires_at <= datetime.now(UTC):
                # Imported lazily, presumably to avoid a circular import
                # between the route module and this adapter.
                from app.routes.notion_mcp_connector_route import (
                    refresh_notion_mcp_token,
                )

                connector = await refresh_notion_mcp_token(self._session, connector)
                # Re-read everything from the refreshed config: both the
                # token and its encryption flag may have changed.
                cfg = connector.config or {}
                access_token = cfg.get("access_token")
                if not access_token:
                    raise ValueError(
                        "No access token in MCP connector config after refresh"
                    )
                access_token = self._decrypt_if_needed(access_token, cfg)

        self._access_token = access_token
        return access_token

    async def _call_mcp_tool(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> str:
        """Open an MCP session, call one tool, and return its text output.

        A new streamable-HTTP connection is opened for every call.

        Args:
            tool_name: Name of the MCP tool to invoke.
            arguments: Arguments passed to the tool.

        Returns:
            The concatenated text content of the tool result.
        """
        token = await self._get_valid_token()
        headers = {"Authorization": f"Bearer {token}"}
        async with (
            streamablehttp_client(NOTION_MCP_URL, headers=headers) as (read, write, _),
            ClientSession(read, write) as session,
        ):
            await session.initialize()
            response = await session.call_tool(tool_name, arguments=arguments)
            return extract_text_from_mcp_response(response)

    async def _call_with_fallback(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        parser,
        fallback_method: str | None = None,
        fallback_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Call an MCP tool, parse the reply, and fall back on known bugs.

        The fallback runs only when *fallback_method* is given and either the
        parsed result is flagged ``mcp_serialization_error`` or the MCP call
        raised an exception whose text matches a known serialization error.

        Args:
            tool_name: MCP tool to call.
            arguments: Arguments for the MCP tool.
            parser: Callable turning the raw text into a result dict.
            fallback_method: ``NotionHistoryConnector`` method name to call
                on a serialization error, or None to disable the fallback.
            fallback_kwargs: Keyword arguments for the fallback method.

        Returns:
            The parsed MCP result, the fallback result, or an error dict.
        """
        try:
            raw_text = await self._call_mcp_tool(tool_name, arguments)
            result = parser(raw_text)
        except Exception as e:
            error_str = str(e)
            if not (is_mcp_serialization_error(error_str) and fallback_method):
                logger.error("MCP tool '%s' failed: %s", tool_name, e, exc_info=True)
                return {"status": "error", "message": f"MCP call failed: {e!s}"}
            logger.warning(
                "MCP tool '%s' raised serialization error, falling back: %s",
                tool_name,
                error_str,
            )
        else:
            if not (result.get("mcp_serialization_error") and fallback_method):
                return result
            logger.warning(
                "MCP tool '%s' hit serialization bug, falling back to direct API",
                tool_name,
            )

        # The fallback runs outside the MCP try-block so that its own failures
        # are neither reported as MCP failures nor able to trigger a second
        # fallback attempt.
        try:
            return await self._fallback(fallback_method, fallback_kwargs or {})
        except Exception as e:
            logger.error(
                "Direct API fallback '%s' failed: %s",
                fallback_method,
                e,
                exc_info=True,
            )
            return {"status": "error", "message": f"Direct API fallback failed: {e!s}"}

    async def _fallback(
        self, method_name: str, kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Run *method_name* on a ``NotionHistoryConnector`` instead of MCP.

        The connector's private ``_credentials`` and ``_using_legacy_token``
        attributes are set directly so that it uses the MCP access token
        rather than loading a token from the connector config.

        Args:
            method_name: Name of the connector coroutine method to call.
            kwargs: Keyword arguments forwarded to that method.

        Returns:
            Whatever the connector method returns.
        """
        # Imported lazily, presumably to avoid a circular import.
        from app.connectors.notion_history import NotionHistoryConnector

        token = self._access_token
        if not token:
            token = await self._get_valid_token()

        connector = NotionHistoryConnector(
            session=self._session,
            connector_id=self._connector_id,
        )
        connector._credentials = NotionAuthCredentialsBase(access_token=token)
        connector._using_legacy_token = True

        method = getattr(connector, method_name)
        return await method(**kwargs)

    # ------------------------------------------------------------------
    # Public API — same signatures as NotionHistoryConnector
    # ------------------------------------------------------------------

    async def create_page(
        self,
        title: str,
        content: str,
        parent_page_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a page via ``notion-create-pages``.

        Args:
            title: Title of the new page.
            content: Body content of the new page.
            parent_page_id: Optional parent page; sent to the MCP tool as
                ``parent_page_url``.
                NOTE(review): confirm the tool accepts a bare id there.

        Returns:
            The dict produced by ``parse_create_page_response`` or by the
            direct-API fallback.
        """
        page: dict[str, Any] = {"title": title, "content": content}
        if parent_page_id:
            page["parent_page_url"] = parent_page_id

        return await self._call_with_fallback(
            tool_name="notion-create-pages",
            arguments={"pages": [page]},
            parser=parse_create_page_response,
            fallback_method="create_page",
            fallback_kwargs={
                "title": title,
                "content": content,
                "parent_page_id": parent_page_id,
            },
        )

    async def update_page(
        self,
        page_id: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Update a page via ``notion-update-page`` (``replace_content``).

        Args:
            page_id: Id of the page to update.
            content: New content; ``new_str`` is only sent when this is
                non-empty.

        Returns:
            The dict produced by ``parse_update_page_response`` or by the
            direct-API fallback.
        """
        arguments: dict[str, Any] = {
            "page_id": page_id,
            "command": "replace_content",
        }
        if content:
            arguments["new_str"] = content

        return await self._call_with_fallback(
            tool_name="notion-update-page",
            arguments=arguments,
            parser=parse_update_page_response,
            fallback_method="update_page",
            fallback_kwargs={"page_id": page_id, "content": content},
        )

    async def delete_page(self, page_id: str) -> dict[str, Any]:
        """Archive a page via ``notion-update-page`` with ``archived=True``.

        Args:
            page_id: Id of the page to archive.

        Returns:
            The dict produced by ``parse_delete_page_response`` or by the
            direct-API fallback.
        """
        arguments: dict[str, Any] = {
            "page_id": page_id,
            "command": "update_properties",
            "archived": True,
        }
        return await self._call_with_fallback(
            tool_name="notion-update-page",
            arguments=arguments,
            parser=parse_delete_page_response,
            fallback_method="delete_page",
            fallback_kwargs={"page_id": page_id},
        )

    async def fetch_page(self, page_url_or_id: str) -> dict[str, Any]:
        """Fetch page content via ``notion-fetch``.

        Connection and token errors are returned as an error dict, matching
        the other public methods, instead of being raised.

        Args:
            page_url_or_id: Page URL or id, sent to the tool as ``url``.

        Returns:
            The dict produced by ``parse_fetch_page_response`` or an error
            dict when the MCP call itself fails.
        """
        try:
            raw_text = await self._call_mcp_tool(
                "notion-fetch", {"url": page_url_or_id}
            )
        except Exception as e:
            logger.error("MCP tool '%s' failed: %s", "notion-fetch", e, exc_info=True)
            return {"status": "error", "message": f"MCP call failed: {e!s}"}
        return parse_fetch_page_response(raw_text)

    async def health_check(self) -> dict[str, Any]:
        """Check the MCP connection via ``notion-get-self``.

        Returns:
            The dict produced by ``parse_health_check_response`` or an error
            dict carrying the exception text.
        """
        try:
            raw_text = await self._call_mcp_tool("notion-get-self", {})
            return parse_health_check_response(raw_text)
        except Exception as e:
            return {"status": "error", "message": str(e)}

View file

@ -0,0 +1,298 @@
"""OAuth 2.0 + PKCE utilities for Notion's remote MCP server.
Implements the flow described in the official guide:
https://developers.notion.com/guides/mcp/build-mcp-client
Steps:
1. Discover OAuth metadata (RFC 9728 → RFC 8414)
2. Dynamic client registration (RFC 7591)
3. Build authorization URL with PKCE code_challenge
4. Exchange authorization code + code_verifier for tokens
5. Refresh access tokens (with refresh-token rotation)
All functions are stateless — callers (route handlers) manage storage.
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
import httpx
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default MCP server URL from which OAuth discovery starts.
NOTION_MCP_SERVER_URL = "https://mcp.notion.com/mcp"
# Timeout handed to every httpx.AsyncClient created in this module.
_HTTP_TIMEOUT = 30.0
@dataclass(frozen=True)
class OAuthMetadata:
    """Immutable subset of the authorization-server metadata document.

    Built by ``discover_oauth_metadata`` from the JSON served by the
    authorization server.
    """

    # ``issuer`` from the metadata, or the authorization server URL if absent.
    issuer: str
    # Endpoint the user is redirected to for consent.
    authorization_endpoint: str
    # Endpoint used for code exchange and token refresh.
    token_endpoint: str
    # Dynamic client registration endpoint; None when not advertised.
    registration_endpoint: str | None
    # Advertised PKCE challenge methods; empty list when not advertised.
    code_challenge_methods_supported: list[str]
@dataclass(frozen=True)
class ClientCredentials:
    """Immutable result of dynamic client registration.

    Built by ``register_client`` from the registration response; every field
    except ``client_id`` is optional in that response.
    """

    # Identifier issued by the registration endpoint.
    client_id: str
    # Secret issued alongside the id, if any.
    client_secret: str | None = None
    # ``client_id_issued_at`` copied verbatim from the response, if present.
    client_id_issued_at: int | None = None
    # ``client_secret_expires_at`` copied verbatim from the response, if present.
    client_secret_expires_at: int | None = None
@dataclass(frozen=True)
class TokenSet:
    """Immutable token bundle returned by the token endpoint.

    Built by ``exchange_code_for_tokens`` and ``refresh_access_token``.
    """

    # Bearer token used to call the MCP server.
    access_token: str
    # Refresh token from the response; None when the response omits it.
    refresh_token: str | None
    # ``token_type`` from the response; the builders default it to "Bearer".
    token_type: str
    # ``expires_in`` copied verbatim from the response, if present.
    expires_in: int | None
    # UTC time computed as now + expires_in when built; None without expires_in.
    expires_at: datetime | None
    # ``scope`` from the response, if present.
    scope: str | None
# ---------------------------------------------------------------------------
# Step 1 — OAuth discovery
# ---------------------------------------------------------------------------
async def discover_oauth_metadata(
    mcp_server_url: str = NOTION_MCP_SERVER_URL,
) -> OAuthMetadata:
    """Discover OAuth endpoints via RFC 9728 + RFC 8414.

    1. Fetch protected-resource metadata (RFC 9728) to find the
       authorization server.
    2. Fetch authorization-server metadata (RFC 8414) to get the OAuth
       endpoints.

    For an authorization server whose issuer URL has a path, the RFC 8414
    location (well-known segment inserted before the path) is tried first and
    the "append to the issuer URL" location is tried second. For an issuer
    without a path both forms are identical and a single request is made.

    Args:
        mcp_server_url: URL of the MCP server acting as protected resource.

    Returns:
        The discovered ``OAuthMetadata``.

    Raises:
        ValueError: If no authorization server is advertised or the server
            metadata lacks the authorization or token endpoint.
        httpx.HTTPStatusError: If a metadata request returns an error status.
    """
    from urllib.parse import urlparse

    well_known_as = "/.well-known/oauth-authorization-server"

    parsed = urlparse(mcp_server_url)
    origin = f"{parsed.scheme}://{parsed.netloc}"
    path = parsed.path.rstrip("/")

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
        # RFC 9728 — Protected Resource Metadata
        # URL format: {origin}/.well-known/oauth-protected-resource{path}
        pr_url = f"{origin}/.well-known/oauth-protected-resource{path}"
        pr_resp = await client.get(pr_url)
        pr_resp.raise_for_status()
        pr_data = pr_resp.json()

        auth_servers = pr_data.get("authorization_servers", [])
        if not auth_servers:
            raise ValueError("No authorization_servers in protected resource metadata")
        auth_server_url = auth_servers[0]

        # RFC 8414 — Authorization Server Metadata
        # A trailing slash on the issuer must not yield "//.well-known/...".
        issuer_base = str(auth_server_url).rstrip("/")
        issuer = urlparse(issuer_base)
        candidates = [
            f"{issuer.scheme}://{issuer.netloc}{well_known_as}{issuer.path}"
        ]
        appended_form = f"{issuer_base}{well_known_as}"
        if appended_form not in candidates:
            candidates.append(appended_form)

        as_resp = None
        for as_url in candidates:
            as_resp = await client.get(as_url)
            if as_resp.is_success:
                break
        # Surfaces the error of the last candidate when none succeeded.
        as_resp.raise_for_status()
        as_data = as_resp.json()

    if not as_data.get("authorization_endpoint") or not as_data.get("token_endpoint"):
        raise ValueError("Missing required OAuth endpoints in server metadata")

    return OAuthMetadata(
        issuer=as_data.get("issuer", auth_server_url),
        authorization_endpoint=as_data["authorization_endpoint"],
        token_endpoint=as_data["token_endpoint"],
        registration_endpoint=as_data.get("registration_endpoint"),
        code_challenge_methods_supported=as_data.get(
            "code_challenge_methods_supported", []
        ),
    )
# ---------------------------------------------------------------------------
# Step 2 — Dynamic client registration (RFC 7591)
# ---------------------------------------------------------------------------
async def register_client(
    metadata: OAuthMetadata,
    redirect_uri: str,
    client_name: str = "SurfSense",
) -> ClientCredentials:
    """Register a public OAuth client through dynamic client registration.

    Posts a registration request for an authorization-code + refresh-token
    client that authenticates with no secret
    (``token_endpoint_auth_method: none``).

    Args:
        metadata: Discovered OAuth metadata; must carry a registration
            endpoint.
        redirect_uri: The single redirect URI to register.
        client_name: Human-readable client name sent to the server.

    Returns:
        The ``ClientCredentials`` issued by the server.

    Raises:
        ValueError: If the server advertises no registration endpoint.
        httpx.HTTPStatusError: If the registration request is rejected.
    """
    endpoint = metadata.registration_endpoint
    if not endpoint:
        raise ValueError("Server does not support dynamic client registration")

    registration_request = {
        "client_name": client_name,
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }
    json_headers = {"Content-Type": "application/json", "Accept": "application/json"}

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as http:
        resp = await http.post(
            endpoint, json=registration_request, headers=json_headers
        )

    if not resp.is_success:
        # Log the server's explanation before raising; the raised error does
        # not include the response body.
        logger.error(
            "Dynamic client registration failed (%s): %s",
            resp.status_code,
            resp.text,
        )
        resp.raise_for_status()

    registered = resp.json()
    return ClientCredentials(
        client_id=registered["client_id"],
        client_secret=registered.get("client_secret"),
        client_id_issued_at=registered.get("client_id_issued_at"),
        client_secret_expires_at=registered.get("client_secret_expires_at"),
    )
# ---------------------------------------------------------------------------
# Step 3 — Build authorization URL
# ---------------------------------------------------------------------------
def build_authorization_url(
    metadata: OAuthMetadata,
    client_id: str,
    redirect_uri: str,
    code_challenge: str,
    state: str,
) -> str:
    """Build the OAuth authorization URL with PKCE parameters.

    The PKCE challenge method is always ``S256`` and ``prompt=consent`` is
    always sent. If the authorization endpoint already carries a query
    string, that query is retained and the parameters are appended to it.

    Args:
        metadata: Discovered OAuth metadata providing the endpoint.
        client_id: Registered client identifier.
        redirect_uri: Redirect URI registered for the client.
        code_challenge: PKCE code challenge derived from the verifier.
        state: Opaque value echoed back on the redirect.

    Returns:
        The full URL to send the user to.
    """
    from urllib.parse import urlencode

    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "state": state,
        "prompt": "consent",
    }

    endpoint = metadata.authorization_endpoint
    # Pick the separator so an existing query component is extended instead
    # of being followed by a second "?".
    if "?" not in endpoint:
        separator = "?"
    elif endpoint.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return f"{endpoint}{separator}{urlencode(params)}"
# ---------------------------------------------------------------------------
# Step 4 — Exchange authorization code for tokens
# ---------------------------------------------------------------------------
async def exchange_code_for_tokens(
    code: str,
    code_verifier: str,
    metadata: OAuthMetadata,
    client_id: str,
    redirect_uri: str,
    client_secret: str | None = None,
) -> TokenSet:
    """Trade an authorization code and its PKCE verifier for tokens.

    Args:
        code: Authorization code received on the redirect.
        code_verifier: PKCE verifier matching the challenge sent earlier.
        metadata: Discovered OAuth metadata providing the token endpoint.
        client_id: Registered client identifier.
        redirect_uri: Redirect URI used in the authorization request.
        client_secret: Client secret; only sent when truthy.

    Returns:
        The ``TokenSet`` built from the token response. ``expires_at`` is
        computed from ``expires_in`` when the response provides one.

    Raises:
        ValueError: If the token endpoint rejects the request or the response
            contains no access token.
    """
    request_fields: dict[str, Any] = {
        "grant_type": "authorization_code",
        "code": code,
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_verifier": code_verifier,
    }
    if client_secret:
        request_fields["client_secret"] = client_secret

    form_headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    }
    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as http:
        resp = await http.post(
            metadata.token_endpoint, data=request_fields, headers=form_headers
        )

    if not resp.is_success:
        body = resp.text
        raise ValueError(f"Token exchange failed ({resp.status_code}): {body}")

    tokens = resp.json()
    if not tokens.get("access_token"):
        raise ValueError("No access_token in token response")

    lifetime = tokens.get("expires_in")
    # A missing (or zero) lifetime leaves the expiry unknown.
    expires_at = (
        datetime.now(UTC) + timedelta(seconds=int(lifetime)) if lifetime else None
    )

    return TokenSet(
        access_token=tokens["access_token"],
        refresh_token=tokens.get("refresh_token"),
        token_type=tokens.get("token_type", "Bearer"),
        expires_in=lifetime,
        expires_at=expires_at,
        scope=tokens.get("scope"),
    )
# ---------------------------------------------------------------------------
# Step 5 — Refresh access token
# ---------------------------------------------------------------------------
async def refresh_access_token(
    refresh_token: str,
    metadata: OAuthMetadata,
    client_id: str,
    client_secret: str | None = None,
) -> TokenSet:
    """Refresh an access token.

    Notion MCP uses refresh-token rotation: each refresh returns a new
    refresh_token and invalidates the old one. Callers MUST persist the
    new refresh_token atomically with the new access_token.

    Args:
        refresh_token: The current refresh token.
        metadata: Discovered OAuth metadata providing the token endpoint.
        client_id: Registered client identifier.
        client_secret: Client secret; only sent when truthy.

    Returns:
        The ``TokenSet`` built from the token response.

    Raises:
        ValueError: With the exact message ``"REAUTH_REQUIRED"`` when the
            server answers with OAuth error ``invalid_grant`` (the refresh
            token is no longer usable); with a ``"Token refresh failed"``
            message for any other error response; or when the response
            contains no access token.
    """
    form_data: dict[str, Any] = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    if client_secret:
        form_data["client_secret"] = client_secret

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
        resp = await client.post(
            metadata.token_endpoint,
            data=form_data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
        )

    if not resp.is_success:
        body = resp.text
        # Only the JSON decoding is guarded: the REAUTH_REQUIRED error below
        # is itself a ValueError and must not be caught by this handler.
        try:
            error_data = resp.json()
        except ValueError:
            error_data = None
        error_code = error_data.get("error", "") if isinstance(error_data, dict) else ""
        if error_code == "invalid_grant":
            raise ValueError("REAUTH_REQUIRED")
        raise ValueError(f"Token refresh failed ({resp.status_code}): {body}")

    tokens = resp.json()
    if not tokens.get("access_token"):
        raise ValueError("No access_token in refresh response")

    expires_at = None
    if tokens.get("expires_in"):
        expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))

    return TokenSet(
        access_token=tokens["access_token"],
        refresh_token=tokens.get("refresh_token"),
        token_type=tokens.get("token_type", "Bearer"),
        expires_in=tokens.get("expires_in"),
        expires_at=expires_at,
        scope=tokens.get("scope"),
    )

View file

@ -0,0 +1,212 @@
"""Parse Notion MCP tool responses into structured dicts.
The Notion MCP server returns responses as MCP TextContent where the
``text`` field contains JSON-stringified Notion API response data.
See: https://deepwiki.com/makenotion/notion-mcp-server/4.3-request-and-response-handling
This module extracts that JSON and normalises it into the same dict
format that ``NotionHistoryConnector`` methods return, so downstream
code (KB sync, tool factories) works unchanged.
"""
import json
import logging
from typing import Any
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Substrings that identify a known MCP serialization bug in an error text;
# matched by is_mcp_serialization_error().
MCP_SERIALIZATION_ERROR_MARKERS = [
    "Expected array, received string",
    "Expected object, received string",
    "should be defined, instead was `undefined`",
]
def is_mcp_serialization_error(text: str) -> bool:
    """Tell whether *text* contains any known serialization-bug marker.

    Args:
        text: Error text from an MCP response or exception.

    Returns:
        True if one of ``MCP_SERIALIZATION_ERROR_MARKERS`` occurs in *text*.
    """
    for known_marker in MCP_SERIALIZATION_ERROR_MARKERS:
        if known_marker in text:
            return True
    return False
def extract_text_from_mcp_response(response) -> str:
    """Join the content items of an MCP ``CallToolResult`` into one string.

    Each item contributes its ``text`` attribute if it has one, otherwise the
    string form of its ``data`` attribute, otherwise the string form of the
    item itself. Items are separated by newlines.

    Args:
        response: The ``CallToolResult`` returned by ``session.call_tool()``.

    Returns:
        The newline-joined text, or an empty string when there is no content.
    """

    def _item_text(item) -> str:
        if hasattr(item, "text"):
            return item.text
        if hasattr(item, "data"):
            return str(item.data)
        return str(item)

    return "\n".join(_item_text(item) for item in response.content)
def _try_parse_json(text: str) -> dict[str, Any] | None:
"""Attempt to parse *text* as JSON, returning None on failure."""
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except (json.JSONDecodeError, TypeError):
pass
return None
def _extract_page_title(page_data: dict[str, Any]) -> str:
"""Best-effort extraction of the page title from a Notion page object."""
props = page_data.get("properties", {})
for prop in props.values():
if prop.get("type") == "title":
title_parts = prop.get("title", [])
if title_parts:
return " ".join(t.get("plain_text", "") for t in title_parts)
return page_data.get("id", "Untitled")
def parse_create_page_response(raw_text: str) -> dict[str, Any]:
    """Turn a ``notion-create-pages`` reply into a create-page result dict.

    Args:
        raw_text: Text content of the MCP tool result.

    Returns:
        On success ``{status, page_id, url, title, message}``. When the text
        is not a JSON object: an ``mcp_error`` dict flagged with
        ``mcp_serialization_error`` if it matches a known serialization bug,
        otherwise an ``error`` dict quoting the first 500 characters. When
        the JSON reports an error: an ``error`` dict with its message.
    """
    payload = _try_parse_json(raw_text)

    if payload is not None:
        if payload.get("status") == "error" or "error" in payload:
            return {
                "status": "error",
                "message": payload.get("message", payload.get("error", str(payload))),
            }
        page_title = _extract_page_title(payload)
        return {
            "status": "success",
            "page_id": payload.get("id", ""),
            "url": payload.get("url", ""),
            "title": page_title,
            "message": f"Created Notion page '{page_title}'",
        }

    # Not a JSON object: either a known serialization bug or garbage.
    if is_mcp_serialization_error(raw_text):
        return {
            "status": "mcp_error",
            "message": raw_text,
            "mcp_serialization_error": True,
        }
    return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
def parse_update_page_response(raw_text: str) -> dict[str, Any]:
    """Turn a ``notion-update-page`` reply into an update-page result dict.

    Args:
        raw_text: Text content of the MCP tool result.

    Returns:
        On success ``{status, page_id, url, title, message}``. When the text
        is not a JSON object: an ``mcp_error`` dict flagged with
        ``mcp_serialization_error`` if it matches a known serialization bug,
        otherwise an ``error`` dict quoting the first 500 characters. When
        the JSON reports an error: an ``error`` dict with its message.
    """
    payload = _try_parse_json(raw_text)

    if payload is not None:
        if payload.get("status") == "error" or "error" in payload:
            return {
                "status": "error",
                "message": payload.get("message", payload.get("error", str(payload))),
            }
        page_title = _extract_page_title(payload)
        # NOTE(review): the message says "appended"; confirm that this
        # matches what the calling tool actually does to the page content.
        return {
            "status": "success",
            "page_id": payload.get("id", ""),
            "url": payload.get("url", ""),
            "title": page_title,
            "message": f"Updated Notion page '{page_title}' (content appended)",
        }

    # Not a JSON object: either a known serialization bug or garbage.
    if is_mcp_serialization_error(raw_text):
        return {
            "status": "mcp_error",
            "message": raw_text,
            "mcp_serialization_error": True,
        }
    return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
def parse_delete_page_response(raw_text: str) -> dict[str, Any]:
    """Turn an archive (delete) reply into a delete-page result dict.

    The reply is expected to be the archived page object.

    Args:
        raw_text: Text content of the MCP tool result.

    Returns:
        On success ``{status, page_id, message}``. When the text is not a
        JSON object: an ``mcp_error`` dict flagged with
        ``mcp_serialization_error`` if it matches a known serialization bug,
        otherwise an ``error`` dict quoting the first 500 characters. When
        the JSON reports an error: an ``error`` dict with its message.
    """
    payload = _try_parse_json(raw_text)

    if payload is not None:
        if payload.get("status") == "error" or "error" in payload:
            return {
                "status": "error",
                "message": payload.get("message", payload.get("error", str(payload))),
            }
        page_title = _extract_page_title(payload)
        return {
            "status": "success",
            "page_id": payload.get("id", ""),
            "message": f"Deleted Notion page '{page_title}'",
        }

    # Not a JSON object: either a known serialization bug or garbage.
    if is_mcp_serialization_error(raw_text):
        return {
            "status": "mcp_error",
            "message": raw_text,
            "mcp_serialization_error": True,
        }
    return {"status": "error", "message": f"Unexpected MCP response: {raw_text[:500]}"}
def parse_fetch_page_response(raw_text: str) -> dict[str, Any]:
    """Turn a ``notion-fetch`` reply into a result dict.

    Args:
        raw_text: Text content of the MCP tool result.

    Returns:
        ``{"status": "success", "data": <parsed dict>}`` on success, or an
        ``error`` dict when the text is not a JSON object or the JSON itself
        reports an error.
    """
    payload = _try_parse_json(raw_text)

    if payload is None:
        return {
            "status": "error",
            "message": f"Unexpected MCP response: {raw_text[:500]}",
        }

    reports_error = payload.get("status") == "error" or "error" in payload
    if not reports_error:
        return {"status": "success", "data": payload}

    return {
        "status": "error",
        "message": payload.get("message", payload.get("error", str(payload))),
    }
def parse_health_check_response(raw_text: str) -> dict[str, Any]:
    """Turn a ``notion-get-self`` reply into a health-check result dict.

    Args:
        raw_text: Text content of the MCP tool result.

    Returns:
        ``{"status": "success", "data": <parsed dict>}`` on success, or an
        ``error`` dict: with the first 500 characters of the text when it is
        not a JSON object, or with the reported message when the JSON itself
        reports an error.
    """
    payload = _try_parse_json(raw_text)

    if payload is None:
        return {"status": "error", "message": raw_text[:500]}

    reports_error = payload.get("status") == "error" or "error" in payload
    if not reports_error:
        return {"status": "success", "data": payload}

    return {
        "status": "error",
        "message": payload.get("message", payload.get("error", str(payload))),
    }