mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
add automatic token refresh for MCP OAuth connectors
This commit is contained in:
parent
81711c9e5b
commit
9b78fbfe15
1 changed files with 121 additions and 3 deletions
|
|
@ -377,6 +377,118 @@ async def _load_http_mcp_tools(
|
|||
return tools
|
||||
|
||||
|
||||
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
|
||||
|
||||
|
||||
async def _maybe_refresh_mcp_oauth_token(
|
||||
session: AsyncSession,
|
||||
connector: "SearchSourceConnector",
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Refresh the access token for an MCP OAuth connector if it is about to expire.
|
||||
|
||||
Returns the (possibly updated) ``server_config``.
|
||||
"""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
expires_at_str = mcp_oauth.get("expires_at")
|
||||
if not expires_at_str:
|
||||
return server_config
|
||||
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(expires_at_str)
|
||||
if expires_at.tzinfo is None:
|
||||
from datetime import timezone
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS):
|
||||
return server_config
|
||||
except (ValueError, TypeError):
|
||||
return server_config
|
||||
|
||||
refresh_token = mcp_oauth.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"MCP connector %s token expired but no refresh_token available",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
try:
|
||||
from app.config import config as app_config
|
||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
enc = TokenEncryption(app_config.SECRET_KEY)
|
||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||
decrypted_secret = (
|
||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||
if mcp_oauth.get("client_secret")
|
||||
else ""
|
||||
)
|
||||
|
||||
token_json = await refresh_access_token(
|
||||
token_endpoint=mcp_oauth["token_endpoint"],
|
||||
refresh_token=decrypted_refresh,
|
||||
client_id=mcp_oauth["client_id"],
|
||||
client_secret=decrypted_secret,
|
||||
)
|
||||
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
logger.warning(
|
||||
"MCP connector %s token refresh returned no access_token",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
new_expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
new_expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_oauth = dict(mcp_oauth)
|
||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||
if token_json.get("refresh_token"):
|
||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||
token_json["refresh_token"]
|
||||
)
|
||||
updated_oauth["expires_at"] = (
|
||||
new_expires_at.isoformat() if new_expires_at else None
|
||||
)
|
||||
|
||||
updated_server_config = dict(server_config)
|
||||
updated_server_config["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {new_access}",
|
||||
}
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
connector.config = {
|
||||
**cfg,
|
||||
"server_config": updated_server_config,
|
||||
"mcp_oauth": updated_oauth,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
||||
return updated_server_config
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh MCP OAuth token for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return server_config
|
||||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||
"""Invalidate cached MCP tools.
|
||||
|
||||
|
|
@ -429,9 +541,9 @@ async def load_mcp_tools(
|
|||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
try:
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
trusted_tools = config.get("trusted_tools", [])
|
||||
cfg = connector.config or {}
|
||||
server_config = cfg.get("server_config", {})
|
||||
trusted_tools = cfg.get("trusted_tools", [])
|
||||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
|
|
@ -439,6 +551,12 @@ async def load_mcp_tools(
|
|||
)
|
||||
continue
|
||||
|
||||
# Refresh OAuth token for MCP OAuth connectors before connecting
|
||||
if cfg.get("mcp_oauth"):
|
||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||
session, connector, cfg, server_config,
|
||||
)
|
||||
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue