mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-29 19:35:20 +02:00
add automatic token refresh for MCP OAuth connectors
This commit is contained in:
parent
81711c9e5b
commit
9b78fbfe15
1 changed files with 121 additions and 3 deletions
|
|
@ -377,6 +377,118 @@ async def _load_http_mcp_tools(
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_refresh_mcp_oauth_token(
|
||||||
|
session: AsyncSession,
|
||||||
|
connector: "SearchSourceConnector",
|
||||||
|
cfg: dict[str, Any],
|
||||||
|
server_config: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Refresh the access token for an MCP OAuth connector if it is about to expire.
|
||||||
|
|
||||||
|
Returns the (possibly updated) ``server_config``.
|
||||||
|
"""
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||||
|
expires_at_str = mcp_oauth.get("expires_at")
|
||||||
|
if not expires_at_str:
|
||||||
|
return server_config
|
||||||
|
|
||||||
|
try:
|
||||||
|
expires_at = datetime.fromisoformat(expires_at_str)
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
from datetime import timezone
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS):
|
||||||
|
return server_config
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return server_config
|
||||||
|
|
||||||
|
refresh_token = mcp_oauth.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %s token expired but no refresh_token available",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return server_config
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.config import config as app_config
|
||||||
|
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
|
enc = TokenEncryption(app_config.SECRET_KEY)
|
||||||
|
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||||
|
decrypted_secret = (
|
||||||
|
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||||
|
if mcp_oauth.get("client_secret")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
token_json = await refresh_access_token(
|
||||||
|
token_endpoint=mcp_oauth["token_endpoint"],
|
||||||
|
refresh_token=decrypted_refresh,
|
||||||
|
client_id=mcp_oauth["client_id"],
|
||||||
|
client_secret=decrypted_secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_access = token_json.get("access_token")
|
||||||
|
if not new_access:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %s token refresh returned no access_token",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return server_config
|
||||||
|
|
||||||
|
new_expires_at = None
|
||||||
|
if token_json.get("expires_in"):
|
||||||
|
new_expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
seconds=int(token_json["expires_in"])
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_oauth = dict(mcp_oauth)
|
||||||
|
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||||
|
if token_json.get("refresh_token"):
|
||||||
|
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||||
|
token_json["refresh_token"]
|
||||||
|
)
|
||||||
|
updated_oauth["expires_at"] = (
|
||||||
|
new_expires_at.isoformat() if new_expires_at else None
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_server_config = dict(server_config)
|
||||||
|
updated_server_config["headers"] = {
|
||||||
|
**server_config.get("headers", {}),
|
||||||
|
"Authorization": f"Bearer {new_access}",
|
||||||
|
}
|
||||||
|
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
connector.config = {
|
||||||
|
**cfg,
|
||||||
|
"server_config": updated_server_config,
|
||||||
|
"mcp_oauth": updated_oauth,
|
||||||
|
}
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(connector)
|
||||||
|
|
||||||
|
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
||||||
|
return updated_server_config
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to refresh MCP OAuth token for connector %s",
|
||||||
|
connector.id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return server_config
|
||||||
|
|
||||||
|
|
||||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
"""Invalidate cached MCP tools.
|
"""Invalidate cached MCP tools.
|
||||||
|
|
||||||
|
|
@ -429,9 +541,9 @@ async def load_mcp_tools(
|
||||||
tools: list[StructuredTool] = []
|
tools: list[StructuredTool] = []
|
||||||
for connector in result.scalars():
|
for connector in result.scalars():
|
||||||
try:
|
try:
|
||||||
config = connector.config or {}
|
cfg = connector.config or {}
|
||||||
server_config = config.get("server_config", {})
|
server_config = cfg.get("server_config", {})
|
||||||
trusted_tools = config.get("trusted_tools", [])
|
trusted_tools = cfg.get("trusted_tools", [])
|
||||||
|
|
||||||
if not server_config or not isinstance(server_config, dict):
|
if not server_config or not isinstance(server_config, dict):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -439,6 +551,12 @@ async def load_mcp_tools(
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Refresh OAuth token for MCP OAuth connectors before connecting
|
||||||
|
if cfg.get("mcp_oauth"):
|
||||||
|
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||||
|
session, connector, cfg, server_config,
|
||||||
|
)
|
||||||
|
|
||||||
transport = server_config.get("transport", "stdio")
|
transport = server_config.get("transport", "stdio")
|
||||||
|
|
||||||
if transport in ("streamable-http", "http", "sse"):
|
if transport in ("streamable-http", "http", "sse"):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue