diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py index f7164eab3..efe928fd1 100644 --- a/surfsense_backend/app/routes/mcp_oauth_route.py +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -107,8 +107,8 @@ async def connect_mcp_service( metadata = await discover_oauth_metadata( svc.mcp_url, origin_override=svc.oauth_discovery_origin, ) - auth_endpoint = metadata.get("authorization_endpoint") - token_endpoint = metadata.get("token_endpoint") + auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") + token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") if not auth_endpoint or not token_endpoint: @@ -165,7 +165,7 @@ async def connect_mcp_service( "state": state, } if svc.scopes: - auth_params["scope"] = " ".join(svc.scopes) + auth_params[svc.scope_param] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -253,17 +253,27 @@ async def mcp_oauth_callback( ) access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + expires_in = token_json.get("expires_in") + scope = token_json.get("scope") + + if not access_token and "authed_user" in token_json: + authed = token_json["authed_user"] + access_token = authed.get("access_token") + refresh_token = refresh_token or authed.get("refresh_token") + scope = scope or authed.get("scope") + expires_in = expires_in or authed.get("expires_in") + if not access_token: raise HTTPException( status_code=400, detail=f"No access token received from {svc.name}.", ) - refresh_token = token_json.get("refresh_token") expires_at = None - if token_json.get("expires_in"): + if expires_in: expires_at = datetime.now(UTC) + timedelta( - seconds=int(token_json["expires_in"]) + seconds=int(expires_in) ) connector_config = { @@ -280,7 +290,7 @@ async def mcp_oauth_callback( "access_token": enc.encrypt_token(access_token), "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None, "expires_at": expires_at.isoformat() if expires_at else None, - "scope": token_json.get("scope"), + "scope": scope, }, "_token_encrypted": True, } @@ -415,8 +425,8 @@ async def reauth_mcp_service( metadata = await discover_oauth_metadata( svc.mcp_url, origin_override=svc.oauth_discovery_origin, ) - auth_endpoint = metadata.get("authorization_endpoint") - token_endpoint = metadata.get("token_endpoint") + auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") + token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") registration_endpoint = metadata.get("registration_endpoint") if not auth_endpoint or not token_endpoint: @@ -478,7 +488,7 @@ async def reauth_mcp_service( "state": state, } if svc.scopes: - auth_params["scope"] = " ".join(svc.scopes) + auth_params[svc.scope_param] = " ".join(svc.scopes) auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py index 4d87ceb40..df6c6bb18 100644 --- a/surfsense_backend/app/services/mcp_oauth/registry.py +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -21,6 +21,9 @@ class MCPServiceConfig: client_id_env: str | None = None client_secret_env: str | None = None scopes: list[str] = field(default_factory=list) + scope_param: str = "scope" + auth_endpoint_override: str | None = None + token_endpoint_override: str | None = None MCP_SERVICES: dict[str, MCPServiceConfig] = { @@ -46,6 +49,9 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = { supports_dcr=False, client_id_env="SLACK_CLIENT_ID", client_secret_env="SLACK_CLIENT_SECRET", + scope_param="user_scope", + auth_endpoint_override="https://slack.com/oauth/v2/authorize", + token_endpoint_override="https://slack.com/api/oauth.v2.access", scopes=[ "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", "search:read.files", "search:read.users",