SurfSense/surfsense_backend/app/routes/mcp_oauth_route.py

601 lines
21 KiB
Python

"""Generic MCP OAuth 2.1 route for services with official MCP servers.
Handles the full flow: discovery → DCR → PKCE authorization → token exchange
→ MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
and Airtable.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import generate_unique_connector_name
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which the MCP OAuth endpoints defined in this module are registered.
router = APIRouter()
async def _fetch_account_metadata(
    service_key: str, access_token: str, token_json: dict[str, Any],
) -> dict[str, Any]:
    """Fetch display-friendly account metadata after a successful token exchange.

    DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
    call their standard REST/GraphQL APIs — metadata discovery for those
    happens at runtime through MCP tools instead.

    Pre-configured services (Slack, Airtable) use standard OAuth tokens that
    *can* call their APIs, so we extract metadata here.

    Failures are logged but never block connector creation.

    Args:
        service_key: Key of the service in the ``MCP_SERVICES`` registry.
        access_token: Plain-text access token from the token exchange.
        token_json: Raw JSON body of the token response.

    Returns:
        A (possibly empty) dict of metadata. Unknown services and services
        that support DCR always yield ``{}``.
    """
    from app.services.mcp_oauth.registry import MCP_SERVICES

    svc = MCP_SERVICES.get(service_key)
    if not svc or svc.supports_dcr:
        return {}

    import httpx

    meta: dict[str, Any] = {}
    try:
        if service_key == "slack":
            # ``team`` can be present but null; ``or {}`` keeps the lookups
            # below from failing on ``None`` and discarding all metadata.
            team_info = token_json.get("team") or {}
            meta["team_id"] = team_info.get("id") or ""
            # TODO: oauth.v2.user.access only returns team.id, not
            # team.name. To populate team_name, add "team:read" scope
            # and call GET /api/team.info here.
            meta["team_name"] = team_info.get("name") or ""
            if meta["team_name"]:
                meta["display_name"] = meta["team_name"]
            elif meta["team_id"]:
                meta["display_name"] = f"Slack ({meta['team_id']})"
        elif service_key == "airtable":
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    "https://api.airtable.com/v0/meta/whoami",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            if resp.status_code == 200:
                whoami = resp.json()
                meta["user_id"] = whoami.get("id", "")
                meta["user_email"] = whoami.get("email", "")
                # Fall back to a generic label when the e-mail is missing
                # *or* empty/null, not only when the key is absent.
                meta["display_name"] = whoami.get("email") or "Airtable"
            else:
                logger.warning(
                    "Airtable whoami API returned %d (non-blocking)", resp.status_code,
                )
    except Exception:
        # Metadata is cosmetic: log and return whatever was collected so far.
        logger.warning(
            "Failed to fetch account metadata for %s (non-blocking)",
            service_key,
            exc_info=True,
        )
    return meta
# Lazily-created module singletons. Both start as None and are populated on
# first use by _get_state_manager() / _get_token_encryption().
_state_manager: OAuthStateManager | None = None
_token_encryption: TokenEncryption | None = None
def _get_state_manager() -> OAuthStateManager:
    """Return the shared ``OAuthStateManager``, creating it on first call.

    Raises:
        HTTPException: 500 if ``config.SECRET_KEY`` is not set when the
            manager has to be created.
    """
    global _state_manager
    if _state_manager is not None:
        return _state_manager

    secret_key = config.SECRET_KEY
    if not secret_key:
        raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")

    _state_manager = OAuthStateManager(secret_key)
    return _state_manager
def _get_token_encryption() -> TokenEncryption:
    """Return the shared ``TokenEncryption`` helper, creating it on first call.

    Raises:
        HTTPException: 500 if ``config.SECRET_KEY`` is not set when the
            helper has to be created.
    """
    global _token_encryption
    if _token_encryption is not None:
        return _token_encryption

    secret_key = config.SECRET_KEY
    if not secret_key:
        raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")

    _token_encryption = TokenEncryption(secret_key)
    return _token_encryption
def _build_redirect_uri(service: str) -> str:
    """Build the OAuth callback URL for ``service``.

    Uses ``config.BACKEND_URL`` (trailing slashes removed) as the origin and
    falls back to ``http://localhost:8000`` when it is unset or empty.
    """
    backend_root = (config.BACKEND_URL or "http://localhost:8000").rstrip("/")
    return f"{backend_root}/api/v1/auth/mcp/{service}/connector/callback"
def _frontend_redirect(
    space_id: int | None,
    *,
    success: bool = False,
    connector_id: int | None = None,
    error: str | None = None,
    service: str = "mcp",
) -> RedirectResponse:
    """Build the redirect that sends the browser back to the frontend.

    Args:
        space_id: Search space whose connector callback page is the target.
            When falsy, the redirect goes to the plain dashboard instead.
        success: Whether the flow completed; takes precedence over ``error``.
        connector_id: Optional connector id appended as ``connectorId``.
        error: Error code reported to the callback page when not successful.
        service: Service key used to build the ``connector`` query value.

    Returns:
        A ``RedirectResponse`` to the connector callback page (with a
        success or error query string) or to the dashboard as a fallback.
    """
    dashboard_url = f"{config.NEXT_FRONTEND_URL}/dashboard"
    if not space_id:
        return RedirectResponse(url=dashboard_url)

    callback_url = f"{dashboard_url}/{space_id}/connectors/callback"

    # ``service`` and ``error`` ultimately come from the request path, so the
    # query string is built with urlencode rather than raw interpolation to
    # stop stray ``&``/``#``/whitespace from altering the redirect target.
    if success:
        params: dict[str, Any] = {
            "success": "true",
            "connector": f"{service}-mcp-connector",
        }
        if connector_id:
            params["connectorId"] = connector_id
        return RedirectResponse(url=f"{callback_url}?{urlencode(params)}")

    if error:
        return RedirectResponse(
            url=f"{callback_url}?{urlencode({'error': error})}"
        )

    return RedirectResponse(url=dashboard_url)
# ---------------------------------------------------------------------------
# /add — start MCP OAuth flow
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/add")
async def connect_mcp_service(
    service: str,
    space_id: int,
    user: User = Depends(current_active_user),
):
    """Start the MCP OAuth flow for ``service`` and return the authorization URL.

    Steps: look up the service, discover its OAuth metadata, obtain client
    credentials (dynamic client registration when available, otherwise the
    configured ones), create a PKCE pair, pack everything the callback needs
    into a signed state value, and build the provider authorization URL.

    Returns:
        ``{"auth_url": <url>}``.

    Raises:
        HTTPException: 404 for an unknown service; 502 for incomplete
            metadata, a DCR response without ``client_id``, or no usable
            credentials; 500 for missing configuration or any other failure.
    """
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")

    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        oauth_meta = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        # Per-service overrides win over whatever discovery reported.
        authorize_url = svc.auth_endpoint_override or oauth_meta.get("authorization_endpoint")
        token_url = svc.token_endpoint_override or oauth_meta.get("token_endpoint")
        dcr_url = oauth_meta.get("registration_endpoint")
        if not (authorize_url and token_url):
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )

        callback_uri = _build_redirect_uri(service)

        if svc.supports_dcr and dcr_url:
            # Dynamic client registration: credentials are minted per flow.
            registration = await register_client(dcr_url, callback_uri)
            client_id = registration.get("client_id")
            client_secret = registration.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            # Pre-configured credentials read from application config.
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )

        verifier, challenge = generate_pkce_pair()
        cipher = _get_token_encryption()
        state_manager = _get_state_manager()
        # The client secret travels inside the state value, so it is
        # encrypted first; an empty secret is passed through as "".
        sealed_secret = cipher.encrypt_token(client_secret) if client_secret else ""
        state = state_manager.generate_secure_state(
            space_id,
            user.id,
            service=service,
            code_verifier=verifier,
            mcp_client_id=client_id,
            mcp_client_secret=sealed_secret,
            mcp_token_endpoint=token_url,
            mcp_url=svc.mcp_url,
        )

        query: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": callback_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            query[svc.scope_param] = " ".join(svc.scopes)
        auth_url = f"{authorize_url}?{urlencode(query)}"

        logger.info(
            "Generated %s MCP OAuth URL for user %s, space %s",
            svc.name, user.id, space_id,
        )
        return {"auth_url": auth_url}
    except HTTPException:
        # Already carries the intended status/detail; pass it through.
        raise
    except Exception as exc:
        logger.exception("Failed to initiate %s MCP OAuth: %s", service, exc)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
        ) from exc
# ---------------------------------------------------------------------------
# /callback — handle OAuth redirect
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/callback")
async def mcp_oauth_callback(
    service: str,
    code: str | None = None,
    error: str | None = None,
    state: str | None = None,
    session: AsyncSession = Depends(get_async_session),
):
    """Handle the provider redirect that ends an MCP OAuth flow.

    On a provider error the user is redirected to the frontend with an
    error code. Otherwise the state value is validated, the authorization
    code is exchanged for tokens, tokens are encrypted into the connector
    config, and either an existing connector is updated (re-auth, when the
    state carries ``connector_id``) or a new one is created.

    Returns:
        A ``RedirectResponse`` to the frontend.

    Raises:
        HTTPException: 400 for a missing code/state, malformed state
            content, a state/path service mismatch, or a token response
            without an access token; 404 for an unknown service or a
            missing connector during re-auth; 409 when creating the
            connector violates a database constraint; 500 otherwise.
    """
    if error:
        logger.warning("%s MCP OAuth error: %s", service, error)
        space_id = None
        if state:
            try:
                data = _get_state_manager().validate_state(state)
                space_id = data.get("space_id")
            except Exception:
                # Best effort only: without a usable state the user simply
                # lands on the dashboard. Log instead of hiding the cause.
                logger.debug(
                    "Could not recover space_id from state after %s MCP OAuth error",
                    service,
                    exc_info=True,
                )
        return _frontend_redirect(
            space_id, error=f"{service}_mcp_oauth_denied", service=service,
        )

    if not code:
        raise HTTPException(status_code=400, detail="Missing authorization code")
    if not state:
        raise HTTPException(status_code=400, detail="Missing state parameter")

    data = _get_state_manager().validate_state(state)
    # A state that validates but lacks the expected fields is a client
    # error, so report it as 400 instead of letting it surface as a 500.
    try:
        user_id = UUID(data["user_id"])
        space_id = data["space_id"]
    except (KeyError, TypeError, ValueError, AttributeError) as exc:
        raise HTTPException(status_code=400, detail="Invalid state parameter") from exc

    svc_key = data.get("service", service)
    if svc_key != service:
        raise HTTPException(status_code=400, detail="State/path service mismatch")

    from app.services.mcp_oauth.registry import get_service

    svc = get_service(svc_key)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")

    try:
        from app.services.mcp_oauth.discovery import exchange_code_for_tokens

        enc = _get_token_encryption()

        required_fields = ("mcp_client_id", "mcp_token_endpoint", "code_verifier", "mcp_url")
        if any(field not in data for field in required_fields):
            raise HTTPException(status_code=400, detail="Invalid state parameter")

        client_id = data["mcp_client_id"]
        client_secret = (
            enc.decrypt_token(data["mcp_client_secret"])
            if data.get("mcp_client_secret")
            else ""
        )
        token_endpoint = data["mcp_token_endpoint"]
        code_verifier = data["code_verifier"]
        mcp_url = data["mcp_url"]

        redirect_uri = _build_redirect_uri(service)
        token_json = await exchange_code_for_tokens(
            token_endpoint=token_endpoint,
            code=code,
            redirect_uri=redirect_uri,
            client_id=client_id,
            client_secret=client_secret,
            code_verifier=code_verifier,
        )

        access_token = token_json.get("access_token")
        refresh_token = token_json.get("refresh_token")
        expires_in = token_json.get("expires_in")
        scope = token_json.get("scope")

        # Some responses nest the user token under "authed_user"; use it
        # when no top-level access token is present.
        if not access_token and "authed_user" in token_json:
            authed = token_json["authed_user"]
            access_token = authed.get("access_token")
            refresh_token = refresh_token or authed.get("refresh_token")
            scope = scope or authed.get("scope")
            expires_in = expires_in or authed.get("expires_in")

        if not access_token:
            raise HTTPException(
                status_code=400,
                detail=f"No access token received from {svc.name}.",
            )

        expires_at = None
        if expires_in:
            expires_at = datetime.now(UTC) + timedelta(
                seconds=int(expires_in)
            )

        connector_config = {
            "server_config": {
                "transport": "streamable-http",
                "url": mcp_url,
            },
            "mcp_service": svc_key,
            "mcp_oauth": {
                "client_id": client_id,
                "client_secret": enc.encrypt_token(client_secret) if client_secret else "",
                "token_endpoint": token_endpoint,
                "access_token": enc.encrypt_token(access_token),
                "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
                "expires_at": expires_at.isoformat() if expires_at else None,
                "scope": scope,
            },
            "_token_encrypted": True,
        }

        account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
        if account_meta:
            # Only allow-listed keys are copied into the stored config.
            _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
                               "workspace_id", "workspace_name", "organization_name",
                               "organization_url_key", "cloud_id", "site_name", "base_url"}
            for k, v in account_meta.items():
                if k in _SAFE_META_KEYS:
                    connector_config[k] = v
            logger.info(
                "Stored account metadata for %s: display_name=%s",
                svc_key, account_meta.get("display_name", ""),
            )

        # ---- Re-auth path ----
        db_connector_type = SearchSourceConnectorType(svc.connector_type)
        reauth_connector_id = data.get("connector_id")
        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == db_connector_type,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found during re-auth",
                )

            db_connector.config = connector_config
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            _invalidate_cache(space_id)

            logger.info(
                "Re-authenticated %s MCP connector %s for user %s",
                svc.name, db_connector.id, user_id,
            )
            reauth_return_url = data.get("return_url")
            # Only follow a relative path; "//host" is rejected.
            if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return _frontend_redirect(
                space_id, success=True, connector_id=db_connector.id, service=service,
            )

        # ---- New connector path ----
        naming_identifier = account_meta.get("display_name")
        connector_name = await generate_unique_connector_name(
            session,
            db_connector_type,
            space_id,
            user_id,
            naming_identifier,
        )

        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=db_connector_type,
            is_indexable=False,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
        )
        session.add(new_connector)

        try:
            await session.commit()
        except IntegrityError as e:
            await session.rollback()
            raise HTTPException(
                status_code=409, detail="A connector for this service already exists.",
            ) from e

        _invalidate_cache(space_id)

        logger.info(
            "Created %s MCP connector %s for user %s in space %s",
            svc.name, new_connector.id, user_id, space_id,
        )
        return _frontend_redirect(
            space_id, success=True, connector_id=new_connector.id, service=service,
        )
    except HTTPException:
        # Already carries the intended status/detail; pass it through.
        raise
    except Exception as e:
        logger.error(
            "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to complete {service} MCP OAuth.",
        ) from e
# ---------------------------------------------------------------------------
# /reauth — re-authenticate an existing MCP connector
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/reauth")
async def reauth_mcp_service(
    service: str,
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Start an OAuth re-authentication flow for an existing MCP connector.

    Verifies that the connector exists for this user, search space and
    service type, then runs the same discovery / client-credential / PKCE
    steps as the initial connect flow. The connector id (and an optional
    relative ``return_url``) are carried in the state value so the callback
    can update the connector in place.

    Returns:
        ``{"auth_url": <url>}``.

    Raises:
        HTTPException: 404 for an unknown service or a connector that is
            missing / not owned by the user; 502 for incomplete metadata,
            a DCR response without ``client_id``, or no usable credentials;
            500 for missing configuration or any other failure.
    """
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")

    db_connector_type = SearchSourceConnectorType(svc.connector_type)
    result = await session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.user_id == user.id,
            SearchSourceConnector.search_space_id == space_id,
            SearchSourceConnector.connector_type == db_connector_type,
        )
    )
    if not result.scalars().first():
        raise HTTPException(
            status_code=404, detail="Connector not found or access denied",
        )

    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        metadata = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        # Per-service overrides win over whatever discovery reported.
        auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
        token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
        registration_endpoint = metadata.get("registration_endpoint")

        if not auth_endpoint or not token_endpoint:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )

        redirect_uri = _build_redirect_uri(service)

        if svc.supports_dcr and registration_endpoint:
            dcr = await register_client(registration_endpoint, redirect_uri)
            client_id = dcr.get("client_id")
            client_secret = dcr.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )

        verifier, challenge = generate_pkce_pair()
        enc = _get_token_encryption()

        extra: dict[str, Any] = {
            "service": service,
            "code_verifier": verifier,
            "mcp_client_id": client_id,
            # The secret travels inside the state value, so encrypt it first.
            "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
            "mcp_token_endpoint": token_endpoint,
            "mcp_url": svc.mcp_url,
            "connector_id": connector_id,
        }
        # Accept only a relative path and reject "//host", mirroring the
        # check the callback applies before redirecting to this value.
        if return_url and return_url.startswith("/") and not return_url.startswith("//"):
            extra["return_url"] = return_url

        state = _get_state_manager().generate_secure_state(
            space_id, user.id, **extra,
        )

        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            auth_params[svc.scope_param] = " ".join(svc.scopes)

        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"

        logger.info(
            "Initiating %s MCP re-auth for user %s, connector %s",
            svc.name, user.id, connector_id,
        )
        return {"auth_url": auth_url}
    except HTTPException:
        # Already carries the intended status/detail; pass it through.
        raise
    except Exception as e:
        logger.error(
            "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initiate {service} MCP re-auth.",
        ) from e
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _invalidate_cache(space_id: int) -> None:
    """Best-effort invalidation of the cached MCP tools for ``space_id``.

    Any failure, including a failed import of the cache module, is logged
    at debug level and otherwise ignored.
    """
    try:
        # Imported inside the ``try`` so an import failure is swallowed too.
        from app.agents.new_chat.tools import mcp_tool

        mcp_tool.invalidate_mcp_tools_cache(space_id)
    except Exception:
        logger.debug("MCP cache invalidation skipped", exc_info=True)