mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 23:02:39 +02:00
Implements Notion MCP integration core: - OAuth 2.0 discovery (RFC 9470 + 8414), dynamic client registration, PKCE token exchange, and refresh with rotation - NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback to direct API on known serialization errors - Response parser translating MCP text responses into dicts matching NotionHistoryConnector output format - has_mcp_notion_connector() helper for connector gating
298 lines
9.6 KiB
Python
298 lines
9.6 KiB
Python
"""OAuth 2.0 + PKCE utilities for Notion's remote MCP server.
|
|
|
|
Implements the flow described in the official guide:
|
|
https://developers.notion.com/guides/mcp/build-mcp-client
|
|
|
|
Steps:
|
|
1. Discover OAuth metadata (RFC 9728 → RFC 8414)
|
|
2. Dynamic client registration (RFC 7591)
|
|
3. Build authorization URL with PKCE code_challenge
|
|
4. Exchange authorization code + code_verifier for tokens
|
|
5. Refresh access tokens (with refresh-token rotation)
|
|
|
|
All functions are stateless — callers (route handlers) manage storage.
|
|
"""
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
# Logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Notion's remote MCP endpoint; the default target for OAuth discovery.
NOTION_MCP_SERVER_URL: str = "https://mcp.notion.com/mcp"

# Timeout (seconds) handed to each httpx.AsyncClient used by the OAuth calls.
_HTTP_TIMEOUT: float = 30.0
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class OAuthMetadata:
|
|
issuer: str
|
|
authorization_endpoint: str
|
|
token_endpoint: str
|
|
registration_endpoint: str | None
|
|
code_challenge_methods_supported: list[str]
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ClientCredentials:
|
|
client_id: str
|
|
client_secret: str | None = None
|
|
client_id_issued_at: int | None = None
|
|
client_secret_expires_at: int | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class TokenSet:
|
|
access_token: str
|
|
refresh_token: str | None
|
|
token_type: str
|
|
expires_in: int | None
|
|
expires_at: datetime | None
|
|
scope: str | None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Step 1 — OAuth discovery
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _well_known_url(base_url: str, suffix: str) -> str:
    """Return the ``/.well-known/<suffix>`` metadata URL for *base_url*.

    RFC 8414 §3.1 and RFC 9728 §3.1 insert the well-known segment between the
    host and the path component, after removing any trailing "/". So
    ``https://as.example.com/tenant1/`` maps to
    ``https://as.example.com/.well-known/<suffix>/tenant1``, and a path-less
    URL maps to ``https://as.example.com/.well-known/<suffix>``. Query and
    fragment components of *base_url* are not carried over.
    """
    from urllib.parse import urlparse

    parsed = urlparse(base_url)
    path = parsed.path.rstrip("/")
    return f"{parsed.scheme}://{parsed.netloc}/.well-known/{suffix}{path}"


async def discover_oauth_metadata(
    mcp_server_url: str = NOTION_MCP_SERVER_URL,
) -> OAuthMetadata:
    """Discover OAuth endpoints via RFC 9728 + RFC 8414.

    1. Fetch protected-resource metadata (RFC 9728) for *mcp_server_url* to
       find its authorization server.
    2. Fetch that server's authorization-server metadata (RFC 8414) to get the
       OAuth endpoints.

    Args:
        mcp_server_url: URL of the MCP server (the protected resource).

    Returns:
        The discovered endpoints. ``issuer`` falls back to the
        authorization-server URL when the metadata omits it, and
        ``code_challenge_methods_supported`` defaults to an empty list.

    Raises:
        httpx.HTTPStatusError: If either metadata request is not successful.
        ValueError: If no authorization server is advertised, or the server
            metadata lacks ``authorization_endpoint`` / ``token_endpoint``.
    """
    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
        # RFC 9728 — Protected Resource Metadata
        # URL format: {origin}/.well-known/oauth-protected-resource{path}
        pr_url = _well_known_url(mcp_server_url, "oauth-protected-resource")
        pr_resp = await client.get(pr_url)
        pr_resp.raise_for_status()
        pr_data = pr_resp.json()

        auth_servers = pr_data.get("authorization_servers", [])
        if not auth_servers:
            raise ValueError("No authorization_servers in protected resource metadata")
        auth_server_url = auth_servers[0]

        # RFC 8414 — Authorization Server Metadata. The well-known segment goes
        # before the issuer's path (not after it), and a trailing "/" on the
        # issuer must not produce a "//" in the URL.
        as_url = _well_known_url(auth_server_url, "oauth-authorization-server")
        as_resp = await client.get(as_url)
        as_resp.raise_for_status()
        as_data = as_resp.json()

    if not as_data.get("authorization_endpoint") or not as_data.get("token_endpoint"):
        raise ValueError("Missing required OAuth endpoints in server metadata")

    return OAuthMetadata(
        issuer=as_data.get("issuer", auth_server_url),
        authorization_endpoint=as_data["authorization_endpoint"],
        token_endpoint=as_data["token_endpoint"],
        registration_endpoint=as_data.get("registration_endpoint"),
        code_challenge_methods_supported=as_data.get(
            "code_challenge_methods_supported", []
        ),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Step 2 — Dynamic client registration (RFC 7591)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def register_client(
    metadata: OAuthMetadata,
    redirect_uri: str,
    client_name: str = "SurfSense",
) -> ClientCredentials:
    """Dynamically register an OAuth client with the Notion MCP server.

    Sends an RFC 7591 registration request for a public client (no client
    authentication at the token endpoint) that uses the authorization-code and
    refresh-token grants.

    Args:
        metadata: Discovered server metadata; the request is POSTed to its
            ``registration_endpoint``.
        redirect_uri: The single redirect URI to register.
        client_name: Human-readable client name sent with the request.

    Returns:
        The issued ``client_id`` plus whichever optional secret/timestamp
        fields the server returned.

    Raises:
        ValueError: If the metadata has no registration endpoint.
        httpx.HTTPStatusError: If the server rejects the registration
            (the response is logged first).
    """
    endpoint = metadata.registration_endpoint
    if not endpoint:
        raise ValueError("Server does not support dynamic client registration")

    # token_endpoint_auth_method "none" marks this as a public client; PKCE,
    # not a client secret, protects the code exchange.
    registration_request = {
        "client_name": client_name,
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }
    request_headers = {"Content-Type": "application/json", "Accept": "application/json"}

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as http:
        response = await http.post(endpoint, json=registration_request, headers=request_headers)
        if not response.is_success:
            logger.error(
                "Dynamic client registration failed (%s): %s",
                response.status_code,
                response.text,
            )
        response.raise_for_status()
        registration = response.json()

    # client_id is mandatory in the response; the rest are optional.
    issued_id = registration["client_id"]
    optional_fields = {
        name: registration.get(name)
        for name in ("client_secret", "client_id_issued_at", "client_secret_expires_at")
    }
    return ClientCredentials(client_id=issued_id, **optional_fields)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Step 3 — Build authorization URL
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def build_authorization_url(
    metadata: OAuthMetadata,
    client_id: str,
    redirect_uri: str,
    code_challenge: str,
    state: str,
) -> str:
    """Build the OAuth authorization URL with PKCE parameters.

    Args:
        metadata: Discovered server metadata; supplies ``authorization_endpoint``.
        client_id: Client identifier obtained from registration.
        redirect_uri: Redirect URI the server should send the user back to.
        code_challenge: PKCE challenge. It is sent with
            ``code_challenge_method=S256``, so it is expected to be the S256
            transform of the verifier later passed to the token exchange.
        state: Opaque value returned by the server on the callback; callers
            presumably use it to bind the callback to this request.

    Returns:
        The authorization endpoint with the request parameters URL-encoded
        into its query string. If the endpoint already carries a query, that
        query is kept and extended with "&" instead of appending a second "?".
    """
    from urllib.parse import urlencode

    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "state": state,
        "prompt": "consent",
    }

    endpoint = metadata.authorization_endpoint
    # RFC 6749 §3.1: an endpoint's existing query component MUST be retained
    # when adding parameters, so only start a query if none exists yet.
    if "?" not in endpoint:
        separator = "?"
    elif endpoint.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return f"{endpoint}{separator}{urlencode(params)}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Step 4 — Exchange authorization code for tokens
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def exchange_code_for_tokens(
    code: str,
    code_verifier: str,
    metadata: OAuthMetadata,
    client_id: str,
    redirect_uri: str,
    client_secret: str | None = None,
) -> TokenSet:
    """Exchange an authorization code + PKCE verifier for tokens.

    Posts an ``authorization_code`` grant as a form-encoded body to
    ``metadata.token_endpoint``.

    Args:
        code: Authorization code received on the redirect callback.
        code_verifier: PKCE verifier matching the challenge used when
            building the authorization URL.
        metadata: Discovered server metadata; supplies ``token_endpoint``.
        client_id: Registered client identifier.
        redirect_uri: The redirect URI used in the authorization request.
        client_secret: Included in the form only when truthy.

    Returns:
        The parsed token response. ``token_type`` defaults to "Bearer" and
        ``expires_at`` is computed from ``expires_in`` when it is truthy.

    Raises:
        ValueError: On a non-2xx response (message includes status and body)
            or when the response lacks an ``access_token``.
    """
    form_data: dict[str, Any] = dict(
        grant_type="authorization_code",
        code=code,
        client_id=client_id,
        redirect_uri=redirect_uri,
        code_verifier=code_verifier,
    )
    if client_secret:
        form_data["client_secret"] = client_secret

    request_headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    }
    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as http:
        response = await http.post(metadata.token_endpoint, data=form_data, headers=request_headers)
        if not response.is_success:
            raise ValueError(f"Token exchange failed ({response.status_code}): {response.text}")
        payload = response.json()

    access_token = payload.get("access_token")
    if not access_token:
        raise ValueError("No access_token in token response")

    lifetime = payload.get("expires_in")
    # Convert the relative lifetime into an absolute, timezone-aware deadline.
    expires_at = datetime.now(UTC) + timedelta(seconds=int(lifetime)) if lifetime else None

    return TokenSet(
        access_token=access_token,
        refresh_token=payload.get("refresh_token"),
        token_type=payload.get("token_type", "Bearer"),
        expires_in=lifetime,
        expires_at=expires_at,
        scope=payload.get("scope"),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Step 5 — Refresh access token
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def refresh_access_token(
    refresh_token: str,
    metadata: OAuthMetadata,
    client_id: str,
    client_secret: str | None = None,
) -> TokenSet:
    """Refresh an access token.

    Notion MCP uses refresh-token rotation: each refresh returns a new
    refresh_token and invalidates the old one. Callers MUST persist the
    new refresh_token atomically with the new access_token.

    Args:
        refresh_token: The currently stored refresh token.
        metadata: Discovered server metadata; supplies ``token_endpoint``.
        client_id: Registered client identifier.
        client_secret: Included in the form only when truthy.

    Returns:
        The new token set. If the server omits ``refresh_token`` (permitted by
        RFC 6749 §6), the input ``refresh_token`` is returned in its place so
        callers never persist ``None`` over a still-valid token.

    Raises:
        ValueError: ``"REAUTH_REQUIRED"`` when the server answers with the
            ``invalid_grant`` error (the user must re-authorize);
            ``"Token refresh failed (...)"`` for any other unsuccessful
            response; ``"No access_token in refresh response"`` when a
            successful body lacks an access token.
    """
    form_data: dict[str, Any] = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    if client_secret:
        form_data["client_secret"] = client_secret

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
        resp = await client.post(
            metadata.token_endpoint,
            data=form_data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
        )

        if not resp.is_success:
            # RFC 6749 §5.2: error responses are a JSON object with an "error"
            # code and are usually HTTP 400. Parse defensively: the body may be
            # non-JSON or not an object.
            try:
                error_data = resp.json()
            except ValueError:  # JSONDecodeError / UnicodeDecodeError
                error_data = None
            error_code = error_data.get("error") if isinstance(error_data, dict) else None
            if error_code == "invalid_grant":
                # Refresh token expired, revoked, or already rotated away.
                raise ValueError("REAUTH_REQUIRED")
            raise ValueError(f"Token refresh failed ({resp.status_code}): {resp.text}")

        tokens = resp.json()

    if not tokens.get("access_token"):
        raise ValueError("No access_token in refresh response")

    expires_at = None
    if tokens.get("expires_in"):
        expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))

    return TokenSet(
        access_token=tokens["access_token"],
        # Keep the existing refresh token when none was re-issued (RFC 6749 §6).
        refresh_token=tokens.get("refresh_token") or refresh_token,
        token_type=tokens.get("token_type", "Bearer"),
        expires_in=tokens.get("expires_in"),
        expires_at=expires_at,
        scope=tokens.get("scope"),
    )
|