mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-08 23:32:40 +02:00
feat(notion-mcp): add OAuth + PKCE service layer and MCP adapter
Implements Notion MCP integration core: - OAuth 2.0 discovery (RFC 9470 + 8414), dynamic client registration, PKCE token exchange, and refresh with rotation - NotionMCPAdapter connecting to mcp.notion.com/mcp with fallback to direct API on known serialization errors - Response parser translating MCP text responses into dicts matching NotionHistoryConnector output format - has_mcp_notion_connector() helper for connector gating
This commit is contained in:
parent
2b2453e015
commit
d6e605fd50
4 changed files with 790 additions and 0 deletions
298
surfsense_backend/app/services/notion_mcp/oauth.py
Normal file
298
surfsense_backend/app/services/notion_mcp/oauth.py
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
"""OAuth 2.0 + PKCE utilities for Notion's remote MCP server.
|
||||
|
||||
Implements the flow described in the official guide:
|
||||
https://developers.notion.com/guides/mcp/build-mcp-client
|
||||
|
||||
Steps:
|
||||
1. Discover OAuth metadata (RFC 9470 → RFC 8414)
|
||||
2. Dynamic client registration (RFC 7591)
|
||||
3. Build authorization URL with PKCE code_challenge
|
||||
4. Exchange authorization code + code_verifier for tokens
|
||||
5. Refresh access tokens (with refresh-token rotation)
|
||||
|
||||
All functions are stateless — callers (route handlers) manage storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOTION_MCP_SERVER_URL = "https://mcp.notion.com/mcp"
|
||||
_HTTP_TIMEOUT = 30.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OAuthMetadata:
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
registration_endpoint: str | None
|
||||
code_challenge_methods_supported: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClientCredentials:
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
client_id_issued_at: int | None = None
|
||||
client_secret_expires_at: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenSet:
|
||||
access_token: str
|
||||
refresh_token: str | None
|
||||
token_type: str
|
||||
expires_in: int | None
|
||||
expires_at: datetime | None
|
||||
scope: str | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 1 — OAuth discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def discover_oauth_metadata(
|
||||
mcp_server_url: str = NOTION_MCP_SERVER_URL,
|
||||
) -> OAuthMetadata:
|
||||
"""Discover OAuth endpoints via RFC 9470 + RFC 8414.
|
||||
|
||||
1. Fetch protected-resource metadata to find the authorization server.
|
||||
2. Fetch authorization-server metadata to get OAuth endpoints.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(mcp_server_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
# RFC 9470 — Protected Resource Metadata
|
||||
# URL format: {origin}/.well-known/oauth-protected-resource{path}
|
||||
pr_url = f"{origin}/.well-known/oauth-protected-resource{path}"
|
||||
pr_resp = await client.get(pr_url)
|
||||
pr_resp.raise_for_status()
|
||||
pr_data = pr_resp.json()
|
||||
|
||||
auth_servers = pr_data.get("authorization_servers", [])
|
||||
if not auth_servers:
|
||||
raise ValueError("No authorization_servers in protected resource metadata")
|
||||
auth_server_url = auth_servers[0]
|
||||
|
||||
# RFC 8414 — Authorization Server Metadata
|
||||
as_url = f"{auth_server_url}/.well-known/oauth-authorization-server"
|
||||
as_resp = await client.get(as_url)
|
||||
as_resp.raise_for_status()
|
||||
as_data = as_resp.json()
|
||||
|
||||
if not as_data.get("authorization_endpoint") or not as_data.get("token_endpoint"):
|
||||
raise ValueError("Missing required OAuth endpoints in server metadata")
|
||||
|
||||
return OAuthMetadata(
|
||||
issuer=as_data.get("issuer", auth_server_url),
|
||||
authorization_endpoint=as_data["authorization_endpoint"],
|
||||
token_endpoint=as_data["token_endpoint"],
|
||||
registration_endpoint=as_data.get("registration_endpoint"),
|
||||
code_challenge_methods_supported=as_data.get(
|
||||
"code_challenge_methods_supported", []
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 — Dynamic client registration (RFC 7591)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def register_client(
|
||||
metadata: OAuthMetadata,
|
||||
redirect_uri: str,
|
||||
client_name: str = "SurfSense",
|
||||
) -> ClientCredentials:
|
||||
"""Dynamically register an OAuth client with the Notion MCP server."""
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Server does not support dynamic client registration")
|
||||
|
||||
payload = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
resp = await client.post(
|
||||
metadata.registration_endpoint,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
)
|
||||
if not resp.is_success:
|
||||
logger.error(
|
||||
"Dynamic client registration failed (%s): %s",
|
||||
resp.status_code,
|
||||
resp.text,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
return ClientCredentials(
|
||||
client_id=data["client_id"],
|
||||
client_secret=data.get("client_secret"),
|
||||
client_id_issued_at=data.get("client_id_issued_at"),
|
||||
client_secret_expires_at=data.get("client_secret_expires_at"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 3 — Build authorization URL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_authorization_url(
|
||||
metadata: OAuthMetadata,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_challenge: str,
|
||||
state: str,
|
||||
) -> str:
|
||||
"""Build the OAuth authorization URL with PKCE parameters."""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
"prompt": "consent",
|
||||
}
|
||||
return f"{metadata.authorization_endpoint}?{urlencode(params)}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 4 — Exchange authorization code for tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
code: str,
|
||||
code_verifier: str,
|
||||
metadata: OAuthMetadata,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
client_secret: str | None = None,
|
||||
) -> TokenSet:
|
||||
"""Exchange an authorization code + PKCE verifier for tokens."""
|
||||
form_data: dict[str, Any] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
if client_secret:
|
||||
form_data["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
resp = await client.post(
|
||||
metadata.token_endpoint,
|
||||
data=form_data,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
if not resp.is_success:
|
||||
body = resp.text
|
||||
raise ValueError(f"Token exchange failed ({resp.status_code}): {body}")
|
||||
tokens = resp.json()
|
||||
|
||||
if not tokens.get("access_token"):
|
||||
raise ValueError("No access_token in token response")
|
||||
|
||||
expires_at = None
|
||||
if tokens.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))
|
||||
|
||||
return TokenSet(
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
token_type=tokens.get("token_type", "Bearer"),
|
||||
expires_in=tokens.get("expires_in"),
|
||||
expires_at=expires_at,
|
||||
scope=tokens.get("scope"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 5 — Refresh access token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def refresh_access_token(
|
||||
refresh_token: str,
|
||||
metadata: OAuthMetadata,
|
||||
client_id: str,
|
||||
client_secret: str | None = None,
|
||||
) -> TokenSet:
|
||||
"""Refresh an access token.
|
||||
|
||||
Notion MCP uses refresh-token rotation: each refresh returns a new
|
||||
refresh_token and invalidates the old one. Callers MUST persist the
|
||||
new refresh_token atomically with the new access_token.
|
||||
"""
|
||||
form_data: dict[str, Any] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}
|
||||
if client_secret:
|
||||
form_data["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
|
||||
resp = await client.post(
|
||||
metadata.token_endpoint,
|
||||
data=form_data,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if not resp.is_success:
|
||||
body = resp.text
|
||||
try:
|
||||
error_data = resp.json()
|
||||
error_code = error_data.get("error", "")
|
||||
if error_code == "invalid_grant":
|
||||
raise ValueError("REAUTH_REQUIRED")
|
||||
except ValueError:
|
||||
if "REAUTH_REQUIRED" in str(resp.text) or resp.status_code == 401:
|
||||
raise
|
||||
raise ValueError(f"Token refresh failed ({resp.status_code}): {body}")
|
||||
|
||||
tokens = resp.json()
|
||||
|
||||
if not tokens.get("access_token"):
|
||||
raise ValueError("No access_token in refresh response")
|
||||
|
||||
expires_at = None
|
||||
if tokens.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))
|
||||
|
||||
return TokenSet(
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
token_type=tokens.get("token_type", "Bearer"),
|
||||
expires_in=tokens.get("expires_in"),
|
||||
expires_at=expires_at,
|
||||
scope=tokens.get("scope"),
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue