SurfSense/surfsense_backend/app/services/notion_mcp/oauth.py
CREDO23 d6e605fd50 feat(notion-mcp): add OAuth + PKCE service layer and MCP adapter
Implements Notion MCP integration core:
- OAuth 2.0 discovery (RFC 9470 + 8414), dynamic client registration,
  PKCE token exchange, and refresh with rotation
- NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback
  to direct API on known serialization errors
- Response parser translating MCP text responses into dicts matching
  NotionHistoryConnector output format
- has_mcp_notion_connector() helper for connector gating
2026-04-20 20:59:17 +02:00

298 lines
9.6 KiB
Python

"""OAuth 2.0 + PKCE utilities for Notion's remote MCP server.
Implements the flow described in the official guide:
https://developers.notion.com/guides/mcp/build-mcp-client
Steps:
1. Discover OAuth metadata (RFC 9728 → RFC 8414)
2. Dynamic client registration (RFC 7591)
3. Build authorization URL with PKCE code_challenge
4. Exchange authorization code + code_verifier for tokens
5. Refresh access tokens (with refresh-token rotation)
All functions are stateless — callers (route handlers) manage storage.
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
import httpx
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Notion's hosted MCP endpoint; the default `mcp_server_url` for discovery.
NOTION_MCP_SERVER_URL = "https://mcp.notion.com/mcp"
# Passed as `timeout=` to the httpx.AsyncClient instances created below
# (httpx interprets a bare float timeout as seconds).
_HTTP_TIMEOUT = 30.0
@dataclass(frozen=True)
class OAuthMetadata:
    """Immutable subset of an authorization server's metadata document.

    Built by ``discover_oauth_metadata`` from the JSON served at the
    ``/.well-known/oauth-authorization-server`` location.
    """

    # "issuer" from the metadata; discovery falls back to the authorization
    # server URL when the document omits it.
    issuer: str
    # Where the user is sent to approve access (required by discovery).
    authorization_endpoint: str
    # Where codes / refresh tokens are exchanged (required by discovery).
    token_endpoint: str
    # None when the server metadata has no "registration_endpoint";
    # register_client() refuses to run in that case.
    registration_endpoint: str | None
    # Empty list when the server metadata does not advertise any methods.
    code_challenge_methods_supported: list[str]
@dataclass(frozen=True)
class ClientCredentials:
    """Immutable result of a dynamic client registration.

    ``register_client`` fills every field straight from the registration
    response; optional fields are None when the response omits them.
    """

    # The only field the registration response is required to contain.
    client_id: str
    # None when the response carries no "client_secret".
    client_secret: str | None = None
    # Stored exactly as returned by the server.
    # NOTE(review): RFC 7591 defines these two as seconds since the Unix
    # epoch — confirm against the server's actual responses.
    client_id_issued_at: int | None = None
    client_secret_expires_at: int | None = None
@dataclass(frozen=True)
class TokenSet:
    """Immutable bundle of tokens returned by the token endpoint.

    Produced by both ``exchange_code_for_tokens`` and
    ``refresh_access_token``.
    """

    # Always present; the token functions raise if the response lacks it.
    access_token: str
    # None when the token response did not include a refresh token.
    refresh_token: str | None
    # "token_type" from the response, defaulting to "Bearer" when absent.
    token_type: str
    # "expires_in" exactly as returned by the server (not normalised).
    expires_in: int | None
    # Computed locally as now(UTC) + expires_in; None when the response had
    # no (or a falsy) "expires_in".
    expires_at: datetime | None
    # "scope" from the response, or None when absent.
    scope: str | None
# ---------------------------------------------------------------------------
# Step 1 — OAuth discovery
# ---------------------------------------------------------------------------
async def discover_oauth_metadata(
    mcp_server_url: str = NOTION_MCP_SERVER_URL,
) -> OAuthMetadata:
    """Discover the OAuth endpoints protecting an MCP server.

    Two metadata documents are fetched in sequence:

    1. Protected-resource metadata (RFC 9728) for ``mcp_server_url``, which
       lists the authorization server(s) guarding the resource.
    2. Authorization-server metadata (RFC 8414) for the first listed server,
       which carries the actual OAuth endpoints.

    Args:
        mcp_server_url: URL of the MCP server to inspect.

    Returns:
        The endpoints and PKCE capabilities advertised by the server.

    Raises:
        ValueError: If no authorization server is listed, or the server
            metadata lacks an authorization or token endpoint.
        httpx.HTTPStatusError: If either metadata request returns an error
            status (propagated from ``raise_for_status``).
    """
    from urllib.parse import urlparse

    parsed = urlparse(mcp_server_url)
    origin = f"{parsed.scheme}://{parsed.netloc}"
    path = parsed.path.rstrip("/")

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
        # RFC 9728 — Protected Resource Metadata.
        # The well-known segment goes between the origin and the resource
        # path: {origin}/.well-known/oauth-protected-resource{path}
        pr_url = f"{origin}/.well-known/oauth-protected-resource{path}"
        pr_resp = await client.get(pr_url)
        pr_resp.raise_for_status()
        pr_data = pr_resp.json()

        auth_servers = pr_data.get("authorization_servers", [])
        if not auth_servers:
            raise ValueError("No authorization_servers in protected resource metadata")
        # Only the first advertised authorization server is used.
        auth_server_url = auth_servers[0]

        # RFC 8414 — Authorization Server Metadata.
        # A terminating "/" on the issuer must be removed before the
        # well-known suffix is added; otherwise an issuer such as
        # "https://host/" yields "https://host//.well-known/...".
        # NOTE: for issuers that carry a path, the suffix is appended after
        # the path (no path insertion is attempted).
        as_base = str(auth_server_url).rstrip("/")
        as_url = f"{as_base}/.well-known/oauth-authorization-server"
        as_resp = await client.get(as_url)
        as_resp.raise_for_status()
        as_data = as_resp.json()

    if not as_data.get("authorization_endpoint") or not as_data.get("token_endpoint"):
        raise ValueError("Missing required OAuth endpoints in server metadata")

    return OAuthMetadata(
        # Fall back to the URL we were pointed at when "issuer" is absent.
        issuer=as_data.get("issuer", auth_server_url),
        authorization_endpoint=as_data["authorization_endpoint"],
        token_endpoint=as_data["token_endpoint"],
        registration_endpoint=as_data.get("registration_endpoint"),
        code_challenge_methods_supported=as_data.get(
            "code_challenge_methods_supported", []
        ),
    )
# ---------------------------------------------------------------------------
# Step 2 — Dynamic client registration (RFC 7591)
# ---------------------------------------------------------------------------
async def register_client(
    metadata: OAuthMetadata,
    redirect_uri: str,
    client_name: str = "SurfSense",
) -> ClientCredentials:
    """Register a new OAuth client through dynamic client registration.

    The client is registered as a public client (``"none"`` token-endpoint
    authentication) for the authorization-code and refresh-token grants.

    Args:
        metadata: Discovered server metadata; its ``registration_endpoint``
            receives the request.
        redirect_uri: The single redirect URI to register.
        client_name: Display name sent to the server.

    Returns:
        The credentials issued by the server.

    Raises:
        ValueError: If the server advertises no registration endpoint.
        httpx.HTTPStatusError: If the registration request is rejected.
    """
    endpoint = metadata.registration_endpoint
    if not endpoint:
        raise ValueError("Server does not support dynamic client registration")

    json_headers = {"Content-Type": "application/json", "Accept": "application/json"}
    registration_request = {
        "client_name": client_name,
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as http:
        response = await http.post(
            endpoint,
            json=registration_request,
            headers=json_headers,
        )

    # Log the server's explanation before raise_for_status() turns the
    # failure into an exception that does not include the body.
    if not response.is_success:
        logger.error(
            "Dynamic client registration failed (%s): %s",
            response.status_code,
            response.text,
        )
    response.raise_for_status()

    registered = response.json()
    optional_fields = {
        field_name: registered.get(field_name)
        for field_name in (
            "client_secret",
            "client_id_issued_at",
            "client_secret_expires_at",
        )
    }
    return ClientCredentials(client_id=registered["client_id"], **optional_fields)
# ---------------------------------------------------------------------------
# Step 3 — Build authorization URL
# ---------------------------------------------------------------------------
def build_authorization_url(
    metadata: OAuthMetadata,
    client_id: str,
    redirect_uri: str,
    code_challenge: str,
    state: str,
) -> str:
    """Build the OAuth authorization URL with PKCE parameters.

    Args:
        metadata: Discovered server metadata; only its
            ``authorization_endpoint`` is read.
        client_id: Identifier of the registered client.
        redirect_uri: Where the server should send the user back.
        code_challenge: PKCE challenge; always sent with method ``S256``.
        state: Opaque value echoed back by the server on redirect.

    Returns:
        The authorization endpoint with the request parameters URL-encoded
        into its query string.
    """
    from urllib.parse import urlencode

    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "state": state,
        "prompt": "consent",
    }

    endpoint = metadata.authorization_endpoint
    # The endpoint may already carry a query component, which must be kept
    # when adding parameters. Blindly appending "?" would produce a second
    # "?" and corrupt both the existing and the new parameters.
    if "?" not in endpoint:
        separator = "?"
    elif endpoint.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return f"{endpoint}{separator}{urlencode(params)}"
# ---------------------------------------------------------------------------
# Step 4 — Exchange authorization code for tokens
# ---------------------------------------------------------------------------
async def exchange_code_for_tokens(
    code: str,
    code_verifier: str,
    metadata: OAuthMetadata,
    client_id: str,
    redirect_uri: str,
    client_secret: str | None = None,
) -> TokenSet:
    """Trade an authorization code and its PKCE verifier for tokens.

    Args:
        code: Authorization code received on the redirect.
        code_verifier: PKCE verifier matching the challenge sent earlier.
        metadata: Discovered server metadata; its ``token_endpoint`` is used.
        client_id: Identifier of the registered client.
        redirect_uri: Redirect URI sent with the grant request.
        client_secret: Included in the form only when truthy.

    Returns:
        The tokens issued by the server, with ``expires_at`` derived from
        ``expires_in`` when the server provides one.

    Raises:
        ValueError: If the endpoint answers with a non-success status, or
            the response carries no access token.
    """
    request_headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    }
    grant: dict[str, Any] = {
        "grant_type": "authorization_code",
        "code": code,
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "code_verifier": code_verifier,
        # Public clients have no secret; only send the field when we hold one.
        **({"client_secret": client_secret} if client_secret else {}),
    }

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as http:
        response = await http.post(
            metadata.token_endpoint,
            data=grant,
            headers=request_headers,
        )

    if not response.is_success:
        raise ValueError(
            f"Token exchange failed ({response.status_code}): {response.text}"
        )

    payload = response.json()
    if not payload.get("access_token"):
        raise ValueError("No access_token in token response")

    # Turn the relative lifetime into an absolute, timezone-aware expiry.
    lifetime = payload.get("expires_in")
    expiry = datetime.now(UTC) + timedelta(seconds=int(lifetime)) if lifetime else None

    return TokenSet(
        access_token=payload["access_token"],
        refresh_token=payload.get("refresh_token"),
        token_type=payload.get("token_type", "Bearer"),
        expires_in=lifetime,
        expires_at=expiry,
        scope=payload.get("scope"),
    )
# ---------------------------------------------------------------------------
# Step 5 — Refresh access token
# ---------------------------------------------------------------------------
async def refresh_access_token(
    refresh_token: str,
    metadata: OAuthMetadata,
    client_id: str,
    client_secret: str | None = None,
) -> TokenSet:
    """Refresh an access token.

    Notion MCP uses refresh-token rotation: each refresh returns a new
    refresh_token and invalidates the old one. Callers MUST persist the
    new refresh_token atomically with the new access_token.

    Args:
        refresh_token: The currently stored refresh token.
        metadata: Discovered server metadata; its ``token_endpoint`` is used.
        client_id: Identifier of the registered client.
        client_secret: Included in the form only when truthy.

    Returns:
        The new token set. ``refresh_token`` is None if the response did not
        include one.

    Raises:
        ValueError: With the exact message ``"REAUTH_REQUIRED"`` when the
            server reports ``invalid_grant`` or answers 401, meaning the
            user has to go through authorization again. With a
            ``"Token refresh failed (...)"`` message for any other
            non-success status, and with ``"No access_token in refresh
            response"`` when a successful response lacks the token.
    """
    form_data: dict[str, Any] = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
    }
    if client_secret:
        form_data["client_secret"] = client_secret

    async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
        resp = await client.post(
            metadata.token_endpoint,
            data=form_data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
        )

    if not resp.is_success:
        body = resp.text

        # Parse the error body separately from the decision below. The
        # sentinel must not be raised inside a try block that also catches
        # ValueError, or it is swallowed and replaced by the generic error
        # (invalid_grant normally arrives with status 400, not 401).
        try:
            error_data = resp.json()
        except ValueError:
            # Error body is not JSON; classify on the status code alone.
            error_data = None
        error_code = error_data.get("error") if isinstance(error_data, dict) else None

        # invalid_grant: the refresh token is expired, revoked or already
        # rotated. 401: the client itself was rejected. Neither can be
        # recovered without a fresh authorization.
        if error_code == "invalid_grant" or resp.status_code == 401:
            raise ValueError("REAUTH_REQUIRED")
        raise ValueError(f"Token refresh failed ({resp.status_code}): {body}")

    tokens = resp.json()
    if not tokens.get("access_token"):
        raise ValueError("No access_token in refresh response")

    # Turn the relative lifetime into an absolute, timezone-aware expiry.
    expires_at = None
    if tokens.get("expires_in"):
        expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))

    return TokenSet(
        access_token=tokens["access_token"],
        refresh_token=tokens.get("refresh_token"),
        token_type=tokens.get("token_type", "Bearer"),
        expires_in=tokens.get("expires_in"),
        expires_at=expires_at,
        scope=tokens.get("scope"),
    )