fix: reactive 401 recovery for live MCP connectors and unified reauth endpoints

This commit is contained in:
CREDO23 2026-04-23 08:27:11 +02:00
parent 16f47578d7
commit e3172dc282
4 changed files with 396 additions and 178 deletions

View file

@ -194,6 +194,31 @@ async def _create_mcp_tool_from_definition_http(
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
async def _do_mcp_call(
call_headers: dict[str, str],
call_kwargs: dict[str, Any],
) -> str:
"""Execute a single MCP HTTP call with the given headers."""
async with (
streamablehttp_client(url, headers=call_headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
)
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
return "\n".join(result) if result else ""
async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport."""
logger.debug("MCP HTTP tool '%s' called", exposed_name)
@ -218,31 +243,46 @@ async def _create_mcp_tool_from_definition_http(
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try:
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
result_str = await _do_mcp_call(headers, call_kwargs)
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
return result_str
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}"
logger.warning(
"MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s",
exposed_name, connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
return (
f"Error: MCP tool '{exposed_name}' authentication expired. "
"Please re-authenticate the connector in your settings."
)
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
try:
result_str = await _do_mcp_call(fresh_headers, call_kwargs)
logger.info(
"MCP HTTP tool '%s' succeeded after 401 recovery",
exposed_name,
)
return result_str
except Exception as e:
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
except Exception as retry_err:
logger.exception(
"MCP HTTP tool '%s' still failing after token refresh: %s",
exposed_name, retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return (
f"Error: MCP tool '{exposed_name}' authentication expired. "
"Please re-authenticate the connector in your settings."
)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}"
tool = StructuredTool(
name=exposed_name,
@ -365,66 +405,98 @@ async def _load_http_mcp_tools(
allowed_set = set(allowed_tools) if allowed_tools else None
try:
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
"""Connect, initialize, and list tools from the MCP server."""
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.list_tools()
tool_definitions = []
for tool in response.tools:
tool_definitions.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
return [
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
for tool in response.tools
]
total_discovered = len(tool_definitions)
try:
tool_definitions = await _discover(headers)
except Exception as first_err:
if not _is_auth_error(first_err) or connector_id is None:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url, connector_id, first_err,
)
return tools
if allowed_set:
tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition_http(
tool_def,
url,
headers,
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
tools.append(tool)
except Exception as e:
logger.exception(
"Failed to create HTTP tool '%s' from connector %d: %s",
tool_def.get("name"), connector_id, e,
)
except Exception as e:
logger.exception(
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url, connector_id, e,
logger.warning(
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
connector_id,
)
fresh_headers = await _force_refresh_and_get_headers(connector_id)
if fresh_headers is None:
await _mark_connector_auth_expired(connector_id)
logger.error(
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
connector_id,
)
return tools
try:
tool_definitions = await _discover(fresh_headers)
headers = fresh_headers
logger.info(
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
connector_id,
)
except Exception as retry_err:
logger.exception(
"HTTP MCP discovery for connector %d still failing after refresh: %s",
connector_id, retry_err,
)
if _is_auth_error(retry_err):
await _mark_connector_auth_expired(connector_id)
return tools
total_discovered = len(tool_definitions)
if allowed_set:
tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition_http(
tool_def,
url,
headers,
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
tools.append(tool)
except Exception as e:
logger.exception(
"Failed to create HTTP tool '%s' from connector %d: %s",
tool_def.get("name"), connector_id, e,
)
return tools
@ -476,6 +548,91 @@ def _inject_oauth_headers(
return None
async def _refresh_connector_token(
session: AsyncSession,
connector: "SearchSourceConnector",
) -> str | None:
"""Refresh the OAuth token for an MCP connector and persist the result.
This is the shared core used by both proactive (pre-expiry) and reactive
(401 recovery) refresh paths. It handles:
- Decrypting the current refresh token / client secret
- Calling the token endpoint
- Encrypting and persisting the new tokens
- Clearing ``auth_expired`` if it was set
- Invalidating the MCP tools cache
Returns the **plaintext** new access token on success, or ``None`` on
failure (no refresh token, IdP error, etc.).
"""
from datetime import UTC, datetime, timedelta
from sqlalchemy.orm.attributes import flag_modified
from app.services.mcp_oauth.discovery import refresh_access_token
cfg = connector.config or {}
mcp_oauth = cfg.get("mcp_oauth", {})
refresh_token = mcp_oauth.get("refresh_token")
if not refresh_token:
logger.warning(
"MCP connector %s: no refresh_token available",
connector.id,
)
return None
enc = _get_token_enc()
decrypted_refresh = enc.decrypt_token(refresh_token)
decrypted_secret = (
enc.decrypt_token(mcp_oauth["client_secret"])
if mcp_oauth.get("client_secret")
else ""
)
token_json = await refresh_access_token(
token_endpoint=mcp_oauth["token_endpoint"],
refresh_token=decrypted_refresh,
client_id=mcp_oauth["client_id"],
client_secret=decrypted_secret,
)
new_access = token_json.get("access_token")
if not new_access:
logger.warning(
"MCP connector %s: token refresh returned no access_token",
connector.id,
)
return None
new_expires_at = None
if token_json.get("expires_in"):
new_expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
updated_oauth = dict(mcp_oauth)
updated_oauth["access_token"] = enc.encrypt_token(new_access)
if token_json.get("refresh_token"):
updated_oauth["refresh_token"] = enc.encrypt_token(
token_json["refresh_token"]
)
updated_oauth["expires_at"] = (
new_expires_at.isoformat() if new_expires_at else None
)
updated_cfg = {**cfg, "mcp_oauth": updated_oauth}
updated_cfg.pop("auth_expired", None)
connector.config = updated_cfg
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
invalidate_mcp_tools_cache(connector.search_space_id)
return new_access
async def _maybe_refresh_mcp_oauth_token(
session: AsyncSession,
connector: "SearchSourceConnector",
@ -504,73 +661,13 @@ async def _maybe_refresh_mcp_oauth_token(
except (ValueError, TypeError):
return server_config
refresh_token = mcp_oauth.get("refresh_token")
if not refresh_token:
logger.warning(
"MCP connector %s token expired but no refresh_token available",
connector.id,
)
return server_config
try:
from app.services.mcp_oauth.discovery import refresh_access_token
enc = _get_token_enc()
decrypted_refresh = enc.decrypt_token(refresh_token)
decrypted_secret = (
enc.decrypt_token(mcp_oauth["client_secret"])
if mcp_oauth.get("client_secret")
else ""
)
token_json = await refresh_access_token(
token_endpoint=mcp_oauth["token_endpoint"],
refresh_token=decrypted_refresh,
client_id=mcp_oauth["client_id"],
client_secret=decrypted_secret,
)
new_access = token_json.get("access_token")
new_access = await _refresh_connector_token(session, connector)
if not new_access:
logger.warning(
"MCP connector %s token refresh returned no access_token",
connector.id,
)
return server_config
new_expires_at = None
if token_json.get("expires_in"):
new_expires_at = datetime.now(UTC) + timedelta(
seconds=int(token_json["expires_in"])
)
logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id)
updated_oauth = dict(mcp_oauth)
updated_oauth["access_token"] = enc.encrypt_token(new_access)
if token_json.get("refresh_token"):
updated_oauth["refresh_token"] = enc.encrypt_token(
token_json["refresh_token"]
)
updated_oauth["expires_at"] = (
new_expires_at.isoformat() if new_expires_at else None
)
from sqlalchemy.orm.attributes import flag_modified
connector.config = {
**cfg,
"server_config": server_config,
"mcp_oauth": updated_oauth,
}
flag_modified(connector, "config")
await session.commit()
await session.refresh(connector)
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
# Invalidate cache so next call picks up the new token.
invalidate_mcp_tools_cache(connector.search_space_id)
# Return server_config with the fresh token injected for immediate use.
refreshed_config = dict(server_config)
refreshed_config["headers"] = {
**server_config.get("headers", {}),
@ -587,6 +684,117 @@ async def _maybe_refresh_mcp_oauth_token(
return server_config
# ---------------------------------------------------------------------------
# Reactive 401 handling helpers
# ---------------------------------------------------------------------------
def _is_auth_error(exc: Exception) -> bool:
"""Check if an exception indicates an HTTP 401 authentication failure."""
try:
import httpx
if isinstance(exc, httpx.HTTPStatusError):
return exc.response.status_code == 401
except ImportError:
pass
err_str = str(exc).lower()
return "401" in err_str or "unauthorized" in err_str
async def _force_refresh_and_get_headers(
connector_id: int,
) -> dict[str, str] | None:
"""Force-refresh OAuth token for a connector and return fresh HTTP headers.
Opens a **new** DB session so this can be called from inside tool closures
that don't have access to the original session.
Returns ``None`` when the connector is not OAuth-backed, has no
refresh token, or the refresh itself fails.
"""
from app.db import async_session_maker
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if not connector:
return None
cfg = connector.config or {}
if not cfg.get("mcp_oauth"):
return None
server_config = cfg.get("server_config", {})
new_access = await _refresh_connector_token(session, connector)
if not new_access:
return None
logger.info(
"Force-refreshed MCP OAuth token for connector %s (401 recovery)",
connector_id,
)
return {
**server_config.get("headers", {}),
"Authorization": f"Bearer {new_access}",
}
except Exception:
logger.warning(
"Failed to force-refresh MCP OAuth token for connector %s",
connector_id,
exc_info=True,
)
return None
async def _mark_connector_auth_expired(connector_id: int) -> None:
"""Set ``config.auth_expired = True`` so the frontend shows re-auth UI."""
from app.db import async_session_maker
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if not connector:
return
cfg = dict(connector.config or {})
if cfg.get("auth_expired"):
return
cfg["auth_expired"] = True
connector.config = cfg
from sqlalchemy.orm.attributes import flag_modified
flag_modified(connector, "config")
await session.commit()
logger.info(
"Marked MCP connector %s as auth_expired after unrecoverable 401",
connector_id,
)
invalidate_mcp_tools_cache(connector.search_space_id)
except Exception:
logger.warning(
"Failed to mark connector %s as auth_expired",
connector_id,
exc_info=True,
)
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.