SurfSense/surfsense_backend/app/services/mcp_oauth/discovery.py
2026-04-21 21:38:24 +02:00

121 lines
3.6 KiB
Python

"""MCP OAuth 2.1 client helpers.

Provides authorization-server metadata discovery (the RFC 8414
``/.well-known/oauth-authorization-server`` document), Dynamic Client
Registration (RFC 7591), and token-endpoint requests for both the
authorization-code exchange and the refresh-token grant, using
``httpx.AsyncClient``.
"""
from __future__ import annotations
import base64
import logging
from urllib.parse import urlparse
import httpx
logger = logging.getLogger(__name__)
async def discover_oauth_metadata(
    mcp_url: str,
    *,
    origin_override: str | None = None,
    timeout: float = 15.0,
) -> dict:
    """Fetch OAuth 2.1 authorization-server metadata (RFC 8414).

    Per the MCP spec the discovery document lives at the *origin* of the
    MCP server URL: only the scheme and network location of ``mcp_url`` are
    kept, and any path is discarded. ``origin_override`` can be used when the
    OAuth server lives on a different domain (e.g. Airtable: MCP at
    ``mcp.airtable.com``, OAuth at ``airtable.com``); a trailing slash on it
    is ignored.

    Args:
        mcp_url: URL of the MCP server endpoint.
        origin_override: Base URL to use instead of the origin of ``mcp_url``.
        timeout: Request timeout in seconds.

    Returns:
        The parsed metadata document (a JSON object).

    Raises:
        httpx.HTTPError: On transport failure or a non-2xx response.
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    if origin_override:
        origin = origin_override.rstrip("/")
    else:
        parsed = urlparse(mcp_url)
        origin = f"{parsed.scheme}://{parsed.netloc}"
    discovery_url = f"{origin}/.well-known/oauth-authorization-server"
    logger.debug("Fetching OAuth metadata from %s", discovery_url)

    async with httpx.AsyncClient(follow_redirects=True) as client:
        resp = await client.get(
            discovery_url,
            headers={"Accept": "application/json"},
            timeout=timeout,
        )
        resp.raise_for_status()
        metadata = resp.json()

    # Callers treat the result as a mapping of endpoint names; reject any
    # other JSON type here, using the same exception family (ValueError) that
    # ``resp.json()`` already raises for malformed bodies.
    if not isinstance(metadata, dict):
        raise ValueError(
            f"OAuth metadata at {discovery_url} is not a JSON object "
            f"(got {type(metadata).__name__})"
        )
    return metadata
async def register_client(
    registration_endpoint: str,
    redirect_uri: str,
    *,
    client_name: str = "SurfSense",
    timeout: float = 15.0,
) -> dict:
    """Perform Dynamic Client Registration (RFC 7591).

    Requests a client that may use the ``authorization_code`` and
    ``refresh_token`` grants, receives a ``code`` response type, and
    authenticates to the token endpoint with ``client_secret_basic``.
    RFC 7591 allows the server to reject or replace requested metadata, so
    callers should read the effective values (``client_id``,
    ``client_secret``, ``token_endpoint_auth_method``, ...) from the
    returned document.

    Args:
        registration_endpoint: The server's ``registration_endpoint`` URL.
        redirect_uri: The single redirect URI to register.
        client_name: Human-readable client name shown by the server.
        timeout: Request timeout in seconds.

    Returns:
        The parsed registration response (a JSON object).

    Raises:
        httpx.HTTPError: On transport failure or a non-2xx response.
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    payload = {
        "client_name": client_name,
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "client_secret_basic",
    }
    async with httpx.AsyncClient(follow_redirects=True) as client:
        resp = await client.post(
            registration_endpoint,
            json=payload,
            headers={"Accept": "application/json"},
            timeout=timeout,
        )
        if not resp.is_success:
            # The HTTPStatusError raised below does not include the response
            # body, which is where RFC 7591 puts ``error`` and
            # ``error_description``; log a bounded excerpt so it is not lost.
            logger.warning(
                "OAuth client registration at %s failed with HTTP %s: %s",
                registration_endpoint,
                resp.status_code,
                resp.text[:500],
            )
        resp.raise_for_status()
        client_info = resp.json()

    if not isinstance(client_info, dict):
        raise ValueError(
            f"Client registration response from {registration_endpoint} "
            f"is not a JSON object (got {type(client_info).__name__})"
        )
    return client_info
async def _token_request(
    token_endpoint: str,
    form: dict[str, str],
    client_id: str,
    client_secret: str,
    timeout: float,
) -> dict:
    """POST a form-encoded token request authenticated with HTTP Basic.

    Shared by the authorization-code and refresh-token grants so both send
    identical headers and handle errors the same way.

    Raises:
        httpx.HTTPError: On transport failure or a non-2xx response.
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
    async with httpx.AsyncClient(follow_redirects=True) as client:
        resp = await client.post(
            token_endpoint,
            data=form,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                # Some providers (e.g. GitHub) answer with a form-encoded
                # body unless JSON is explicitly requested.
                "Accept": "application/json",
                "Authorization": f"Basic {creds}",
            },
            timeout=timeout,
        )
        if not resp.is_success:
            # OAuth error details (``error``, ``error_description``) live in
            # the body, which the HTTPStatusError message omits. Error
            # responses carry no tokens, so a bounded excerpt is safe to log.
            logger.warning(
                "OAuth %s request to %s failed with HTTP %s: %s",
                form.get("grant_type"),
                token_endpoint,
                resp.status_code,
                resp.text[:500],
            )
        resp.raise_for_status()
        tokens = resp.json()

    if not isinstance(tokens, dict):
        raise ValueError(
            f"Token response from {token_endpoint} is not a JSON object "
            f"(got {type(tokens).__name__})"
        )
    return tokens


async def exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict:
    """Exchange an authorization code for access + refresh tokens.

    Performs the ``authorization_code`` grant, sending the PKCE
    ``code_verifier`` (RFC 7636) and authenticating the client with HTTP
    Basic credentials.

    Args:
        token_endpoint: The server's ``token_endpoint`` URL.
        code: Authorization code received on the redirect.
        redirect_uri: The redirect URI used in the authorization request.
        client_id: Registered client identifier.
        client_secret: Registered client secret.
        code_verifier: PKCE verifier for this authorization.
        timeout: Request timeout in seconds.

    Returns:
        The parsed token response (a JSON object).

    Raises:
        httpx.HTTPError: On transport failure or a non-2xx response.
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    return await _token_request(
        token_endpoint,
        {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
            "code_verifier": code_verifier,
        },
        client_id,
        client_secret,
        timeout,
    )


async def refresh_access_token(
    token_endpoint: str,
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    timeout: float = 30.0,
) -> dict:
    """Refresh an expired access token.

    Performs the ``refresh_token`` grant, authenticating the client with
    HTTP Basic credentials. Per RFC 6749 section 6 the server MAY omit a new
    ``refresh_token`` from the response; callers should then keep the old one.

    Args:
        token_endpoint: The server's ``token_endpoint`` URL.
        refresh_token: The refresh token previously issued.
        client_id: Registered client identifier.
        client_secret: Registered client secret.
        timeout: Request timeout in seconds.

    Returns:
        The parsed token response (a JSON object).

    Raises:
        httpx.HTTPError: On transport failure or a non-2xx response.
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    return await _token_request(
        token_endpoint,
        {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        },
        client_id,
        client_secret,
        timeout,
    )