add automatic token refresh for MCP OAuth connectors

This commit is contained in:
CREDO23 2026-04-21 21:20:12 +02:00
parent 81711c9e5b
commit 9b78fbfe15

View file

@ -377,6 +377,118 @@ async def _load_http_mcp_tools(
return tools
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
async def _maybe_refresh_mcp_oauth_token(
session: AsyncSession,
connector: "SearchSourceConnector",
cfg: dict[str, Any],
server_config: dict[str, Any],
) -> dict[str, Any]:
"""Refresh the access token for an MCP OAuth connector if it is about to expire.
Returns the (possibly updated) ``server_config``.
"""
from datetime import UTC, datetime, timedelta
mcp_oauth = cfg.get("mcp_oauth", {})
expires_at_str = mcp_oauth.get("expires_at")
if not expires_at_str:
return server_config
try:
expires_at = datetime.fromisoformat(expires_at_str)
if expires_at.tzinfo is None:
from datetime import timezone
expires_at = expires_at.replace(tzinfo=timezone.utc)
if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS):
return server_config
except (ValueError, TypeError):
return server_config
refresh_token = mcp_oauth.get("refresh_token")
if not refresh_token:
logger.warning(
"MCP connector %s token expired but no refresh_token available",
connector.id,
)
return server_config
try:
from app.config import config as app_config
from app.services.mcp_oauth.discovery import refresh_access_token
from app.utils.oauth_security import TokenEncryption
enc = TokenEncryption(app_config.SECRET_KEY)
decrypted_refresh = enc.decrypt_token(refresh_token)
decrypted_secret = (
enc.decrypt_token(mcp_oauth["client_secret"])
if mcp_oauth.get("client_secret")
else ""
)
token_json = await refresh_access_token(
token_endpoint=mcp_oauth["token_endpoint"],
refresh_token=decrypted_refresh,
client_id=mcp_oauth["client_id"],
client_secret=decrypted_secret,
)
new_access = token_json.get("access_token")
if not new_access:
logger.warning(
"MCP connector %s token refresh returned no access_token",
connector.id,
)
return server_config
new_expires_at = None
if token_json.get("expires_in"):
new_expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
updated_oauth = dict(mcp_oauth)
updated_oauth["access_token"] = enc.encrypt_token(new_access)
if token_json.get("refresh_token"):
updated_oauth["refresh_token"] = enc.encrypt_token(
token_json["refresh_token"]
)
updated_oauth["expires_at"] = (
new_expires_at.isoformat() if new_expires_at else None
)
updated_server_config = dict(server_config)
updated_server_config["headers"] = {
**server_config.get("headers", {}),
"Authorization": f"Bearer {new_access}",
}
from sqlalchemy.orm.attributes import flag_modified
connector.config = {
**cfg,
"server_config": updated_server_config,
"mcp_oauth": updated_oauth,
}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
return updated_server_config
except Exception:
logger.warning(
"Failed to refresh MCP OAuth token for connector %s",
connector.id,
exc_info=True,
)
return server_config
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
@ -429,9 +541,9 @@ async def load_mcp_tools(
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
config = connector.config or {}
server_config = config.get("server_config", {})
trusted_tools = config.get("trusted_tools", [])
cfg = connector.config or {}
server_config = cfg.get("server_config", {})
trusted_tools = cfg.get("trusted_tools", [])
if not server_config or not isinstance(server_config, dict):
logger.warning(
@ -439,6 +551,12 @@ async def load_mcp_tools(
)
continue
# Refresh OAuth token for MCP OAuth connectors before connecting
if cfg.get("mcp_oauth"):
server_config = await _maybe_refresh_mcp_oauth_token(
session, connector, cfg, server_config,
)
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):