SurfSense/surfsense_backend/app/services/notion_mcp/oauth.py

299 lines
9.6 KiB
Python
Raw Normal View History

"""OAuth 2.0 + PKCE utilities for Notion's remote MCP server.
Implements the flow described in the official guide:
https://developers.notion.com/guides/mcp/build-mcp-client
Steps:
1. Discover OAuth metadata (RFC 9470 RFC 8414)
2. Dynamic client registration (RFC 7591)
3. Build authorization URL with PKCE code_challenge
4. Exchange authorization code + code_verifier for tokens
5. Refresh access tokens (with refresh-token rotation)
All functions are stateless callers (route handlers) manage storage.
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
import httpx
logger = logging.getLogger(__name__)
NOTION_MCP_SERVER_URL = "https://mcp.notion.com/mcp"
_HTTP_TIMEOUT = 30.0
@dataclass(frozen=True)
class OAuthMetadata:
issuer: str
authorization_endpoint: str
token_endpoint: str
registration_endpoint: str | None
code_challenge_methods_supported: list[str]
@dataclass(frozen=True)
class ClientCredentials:
client_id: str
client_secret: str | None = None
client_id_issued_at: int | None = None
client_secret_expires_at: int | None = None
@dataclass(frozen=True)
class TokenSet:
access_token: str
refresh_token: str | None
token_type: str
expires_in: int | None
expires_at: datetime | None
scope: str | None
# ---------------------------------------------------------------------------
# Step 1 — OAuth discovery
# ---------------------------------------------------------------------------
async def discover_oauth_metadata(
mcp_server_url: str = NOTION_MCP_SERVER_URL,
) -> OAuthMetadata:
"""Discover OAuth endpoints via RFC 9470 + RFC 8414.
1. Fetch protected-resource metadata to find the authorization server.
2. Fetch authorization-server metadata to get OAuth endpoints.
"""
from urllib.parse import urlparse
parsed = urlparse(mcp_server_url)
origin = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
# RFC 9470 — Protected Resource Metadata
# URL format: {origin}/.well-known/oauth-protected-resource{path}
pr_url = f"{origin}/.well-known/oauth-protected-resource{path}"
pr_resp = await client.get(pr_url)
pr_resp.raise_for_status()
pr_data = pr_resp.json()
auth_servers = pr_data.get("authorization_servers", [])
if not auth_servers:
raise ValueError("No authorization_servers in protected resource metadata")
auth_server_url = auth_servers[0]
# RFC 8414 — Authorization Server Metadata
as_url = f"{auth_server_url}/.well-known/oauth-authorization-server"
as_resp = await client.get(as_url)
as_resp.raise_for_status()
as_data = as_resp.json()
if not as_data.get("authorization_endpoint") or not as_data.get("token_endpoint"):
raise ValueError("Missing required OAuth endpoints in server metadata")
return OAuthMetadata(
issuer=as_data.get("issuer", auth_server_url),
authorization_endpoint=as_data["authorization_endpoint"],
token_endpoint=as_data["token_endpoint"],
registration_endpoint=as_data.get("registration_endpoint"),
code_challenge_methods_supported=as_data.get(
"code_challenge_methods_supported", []
),
)
# ---------------------------------------------------------------------------
# Step 2 — Dynamic client registration (RFC 7591)
# ---------------------------------------------------------------------------
async def register_client(
metadata: OAuthMetadata,
redirect_uri: str,
client_name: str = "SurfSense",
) -> ClientCredentials:
"""Dynamically register an OAuth client with the Notion MCP server."""
if not metadata.registration_endpoint:
raise ValueError("Server does not support dynamic client registration")
payload = {
"client_name": client_name,
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "none",
}
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
resp = await client.post(
metadata.registration_endpoint,
json=payload,
headers={"Content-Type": "application/json", "Accept": "application/json"},
)
if not resp.is_success:
logger.error(
"Dynamic client registration failed (%s): %s",
resp.status_code,
resp.text,
)
resp.raise_for_status()
data = resp.json()
return ClientCredentials(
client_id=data["client_id"],
client_secret=data.get("client_secret"),
client_id_issued_at=data.get("client_id_issued_at"),
client_secret_expires_at=data.get("client_secret_expires_at"),
)
# ---------------------------------------------------------------------------
# Step 3 — Build authorization URL
# ---------------------------------------------------------------------------
def build_authorization_url(
metadata: OAuthMetadata,
client_id: str,
redirect_uri: str,
code_challenge: str,
state: str,
) -> str:
"""Build the OAuth authorization URL with PKCE parameters."""
from urllib.parse import urlencode
params = {
"response_type": "code",
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"state": state,
"prompt": "consent",
}
return f"{metadata.authorization_endpoint}?{urlencode(params)}"
# ---------------------------------------------------------------------------
# Step 4 — Exchange authorization code for tokens
# ---------------------------------------------------------------------------
async def exchange_code_for_tokens(
code: str,
code_verifier: str,
metadata: OAuthMetadata,
client_id: str,
redirect_uri: str,
client_secret: str | None = None,
) -> TokenSet:
"""Exchange an authorization code + PKCE verifier for tokens."""
form_data: dict[str, Any] = {
"grant_type": "authorization_code",
"code": code,
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
}
if client_secret:
form_data["client_secret"] = client_secret
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
resp = await client.post(
metadata.token_endpoint,
data=form_data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
},
)
if not resp.is_success:
body = resp.text
raise ValueError(f"Token exchange failed ({resp.status_code}): {body}")
tokens = resp.json()
if not tokens.get("access_token"):
raise ValueError("No access_token in token response")
expires_at = None
if tokens.get("expires_in"):
expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))
return TokenSet(
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
token_type=tokens.get("token_type", "Bearer"),
expires_in=tokens.get("expires_in"),
expires_at=expires_at,
scope=tokens.get("scope"),
)
# ---------------------------------------------------------------------------
# Step 5 — Refresh access token
# ---------------------------------------------------------------------------
async def refresh_access_token(
refresh_token: str,
metadata: OAuthMetadata,
client_id: str,
client_secret: str | None = None,
) -> TokenSet:
"""Refresh an access token.
Notion MCP uses refresh-token rotation: each refresh returns a new
refresh_token and invalidates the old one. Callers MUST persist the
new refresh_token atomically with the new access_token.
"""
form_data: dict[str, Any] = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
if client_secret:
form_data["client_secret"] = client_secret
async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
resp = await client.post(
metadata.token_endpoint,
data=form_data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
},
)
if not resp.is_success:
body = resp.text
try:
error_data = resp.json()
error_code = error_data.get("error", "")
if error_code == "invalid_grant":
raise ValueError("REAUTH_REQUIRED")
except ValueError:
if "REAUTH_REQUIRED" in str(resp.text) or resp.status_code == 401:
raise
raise ValueError(f"Token refresh failed ({resp.status_code}): {body}")
tokens = resp.json()
if not tokens.get("access_token"):
raise ValueError("No access_token in refresh response")
expires_at = None
if tokens.get("expires_in"):
expires_at = datetime.now(UTC) + timedelta(seconds=int(tokens["expires_in"]))
return TokenSet(
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
token_type=tokens.get("token_type", "Bearer"),
expires_in=tokens.get("expires_in"),
expires_at=expires_at,
scope=tokens.get("scope"),
)