diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 9743d049d..cf3e51166 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -377,6 +377,118 @@ async def _load_http_mcp_tools( return tools +_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry + + +async def _maybe_refresh_mcp_oauth_token( + session: AsyncSession, + connector: "SearchSourceConnector", + cfg: dict[str, Any], + server_config: dict[str, Any], +) -> dict[str, Any]: + """Refresh the access token for an MCP OAuth connector if it is about to expire. + + Returns the (possibly updated) ``server_config``. + """ + from datetime import UTC, datetime, timedelta + + mcp_oauth = cfg.get("mcp_oauth", {}) + expires_at_str = mcp_oauth.get("expires_at") + if not expires_at_str: + return server_config + + try: + expires_at = datetime.fromisoformat(expires_at_str) + if expires_at.tzinfo is None: + from datetime import timezone + expires_at = expires_at.replace(tzinfo=timezone.utc) + + if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS): + return server_config + except (ValueError, TypeError): + return server_config + + refresh_token = mcp_oauth.get("refresh_token") + if not refresh_token: + logger.warning( + "MCP connector %s token expired but no refresh_token available", + connector.id, + ) + return server_config + + try: + from app.config import config as app_config + from app.services.mcp_oauth.discovery import refresh_access_token + from app.utils.oauth_security import TokenEncryption + + enc = TokenEncryption(app_config.SECRET_KEY) + decrypted_refresh = enc.decrypt_token(refresh_token) + decrypted_secret = ( + enc.decrypt_token(mcp_oauth["client_secret"]) + if mcp_oauth.get("client_secret") + else "" + ) + + token_json = await refresh_access_token( + token_endpoint=mcp_oauth["token_endpoint"], + refresh_token=decrypted_refresh, + client_id=mcp_oauth["client_id"], + client_secret=decrypted_secret, + ) + + new_access = token_json.get("access_token") + if not new_access: + logger.warning( + "MCP connector %s token refresh returned no access_token", + connector.id, + ) + return server_config + + new_expires_at = None + if token_json.get("expires_in"): + new_expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + updated_oauth = dict(mcp_oauth) + updated_oauth["access_token"] = enc.encrypt_token(new_access) + if token_json.get("refresh_token"): + updated_oauth["refresh_token"] = enc.encrypt_token( + token_json["refresh_token"] + ) + updated_oauth["expires_at"] = ( + new_expires_at.isoformat() if new_expires_at else None + ) + + updated_server_config = dict(server_config) + updated_server_config["headers"] = { + **server_config.get("headers", {}), + "Authorization": f"Bearer {new_access}", + } + + from sqlalchemy.orm.attributes import flag_modified + + connector.config = { + **cfg, + "server_config": updated_server_config, + "mcp_oauth": updated_oauth, + } + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + logger.info("Refreshed MCP OAuth token for connector %s", connector.id) + return updated_server_config + + except Exception: + logger.warning( + "Failed to refresh MCP OAuth token for connector %s", + connector.id, + exc_info=True, + ) + return server_config + + def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: """Invalidate cached MCP tools. @@ -429,9 +541,9 @@ async def load_mcp_tools( tools: list[StructuredTool] = [] for connector in result.scalars(): try: - config = connector.config or {} - server_config = config.get("server_config", {}) - trusted_tools = config.get("trusted_tools", []) + cfg = connector.config or {} + server_config = cfg.get("server_config", {}) + trusted_tools = cfg.get("trusted_tools", []) if not server_config or not isinstance(server_config, dict): logger.warning( @@ -439,6 +551,12 @@ async def load_mcp_tools( ) continue + # Refresh OAuth token for MCP OAuth connectors before connecting + if cfg.get("mcp_oauth"): + server_config = await _maybe_refresh_mcp_oauth_token( + session, connector, cfg, server_config, + ) + transport = server_config.get("transport", "stdio") if transport in ("streamable-http", "http", "sse"):