mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-28 02:23:53 +02:00
121 lines
3.6 KiB
Python
121 lines
3.6 KiB
Python
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import logging
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def discover_oauth_metadata(
|
|
mcp_url: str,
|
|
*,
|
|
origin_override: str | None = None,
|
|
timeout: float = 15.0,
|
|
) -> dict:
|
|
"""Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint.
|
|
|
|
Per the MCP spec the discovery document lives at the *origin* of the
|
|
MCP server URL. ``origin_override`` can be used when the OAuth server
|
|
lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``,
|
|
OAuth at ``airtable.com``).
|
|
"""
|
|
if origin_override:
|
|
origin = origin_override.rstrip("/")
|
|
else:
|
|
parsed = urlparse(mcp_url)
|
|
origin = f"{parsed.scheme}://{parsed.netloc}"
|
|
discovery_url = f"{origin}/.well-known/oauth-authorization-server"
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
resp = await client.get(discovery_url, timeout=timeout)
|
|
resp.raise_for_status()
|
|
return resp.json()
|
|
|
|
|
|
async def register_client(
|
|
registration_endpoint: str,
|
|
redirect_uri: str,
|
|
*,
|
|
client_name: str = "SurfSense",
|
|
timeout: float = 15.0,
|
|
) -> dict:
|
|
"""Perform Dynamic Client Registration (RFC 7591)."""
|
|
payload = {
|
|
"client_name": client_name,
|
|
"redirect_uris": [redirect_uri],
|
|
"grant_types": ["authorization_code", "refresh_token"],
|
|
"response_types": ["code"],
|
|
"token_endpoint_auth_method": "client_secret_basic",
|
|
}
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
resp = await client.post(
|
|
registration_endpoint, json=payload, timeout=timeout,
|
|
)
|
|
resp.raise_for_status()
|
|
return resp.json()
|
|
|
|
|
|
async def exchange_code_for_tokens(
|
|
token_endpoint: str,
|
|
code: str,
|
|
redirect_uri: str,
|
|
client_id: str,
|
|
client_secret: str,
|
|
code_verifier: str,
|
|
*,
|
|
timeout: float = 30.0,
|
|
) -> dict:
|
|
"""Exchange an authorization code for access + refresh tokens."""
|
|
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
resp = await client.post(
|
|
token_endpoint,
|
|
data={
|
|
"grant_type": "authorization_code",
|
|
"code": code,
|
|
"redirect_uri": redirect_uri,
|
|
"code_verifier": code_verifier,
|
|
},
|
|
headers={
|
|
"Content-Type": "application/x-www-form-urlencoded",
|
|
"Authorization": f"Basic {creds}",
|
|
},
|
|
timeout=timeout,
|
|
)
|
|
resp.raise_for_status()
|
|
return resp.json()
|
|
|
|
|
|
async def refresh_access_token(
|
|
token_endpoint: str,
|
|
refresh_token: str,
|
|
client_id: str,
|
|
client_secret: str,
|
|
*,
|
|
timeout: float = 30.0,
|
|
) -> dict:
|
|
"""Refresh an expired access token."""
|
|
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
resp = await client.post(
|
|
token_endpoint,
|
|
data={
|
|
"grant_type": "refresh_token",
|
|
"refresh_token": refresh_token,
|
|
},
|
|
headers={
|
|
"Content-Type": "application/x-www-form-urlencoded",
|
|
"Authorization": f"Basic {creds}",
|
|
},
|
|
timeout=timeout,
|
|
)
|
|
resp.raise_for_status()
|
|
return resp.json()
|