fix: harden MCP OAuth and connector edge cases

This commit is contained in:
CREDO23 2026-04-22 20:54:42 +02:00
parent 01153b0d7e
commit 0eae96bffb
4 changed files with 25 additions and 12 deletions

View file

@ -130,8 +130,8 @@ def request_approval(
try:
decision_type, edited_params = _parse_decision(approval)
except ValueError:
logger.warning("No approval decision received for %s", tool_name)
return HITLResult(rejected=False, decision_type="error", params=params)
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
return HITLResult(rejected=True, decision_type="error", params=params)
logger.info("User decision for %s: %s", tool_name, decision_type)

View file

@ -447,7 +447,7 @@ def _get_token_enc() -> TokenEncryption:
def _inject_oauth_headers(
cfg: dict[str, Any],
server_config: dict[str, Any],
) -> dict[str, Any]:
) -> dict[str, Any] | None:
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
The DB never stores plaintext tokens in ``server_config.headers``. This
@ -469,11 +469,11 @@ def _inject_oauth_headers(
}
return result
except Exception:
logger.warning(
"Failed to decrypt MCP OAuth token for runtime injection",
logger.error(
"Failed to decrypt MCP OAuth token — connector will be skipped",
exc_info=True,
)
return server_config
return None
async def _maybe_refresh_mcp_oauth_token(
@ -666,7 +666,6 @@ async def load_mcp_tools(
try:
cfg = connector.config or {}
server_config = cfg.get("server_config", {})
trusted_tools = cfg.get("trusted_tools", [])
if not server_config or not isinstance(server_config, dict):
logger.warning(
@ -685,6 +684,14 @@ async def load_mcp_tools(
# Re-read cfg after potential refresh (connector was reloaded from DB).
cfg = connector.config or {}
server_config = _inject_oauth_headers(cfg, server_config)
if server_config is None:
logger.warning(
"Skipping MCP connector %d — OAuth token decryption failed",
connector.id,
)
continue
trusted_tools = cfg.get("trusted_tools", [])
ct = (
connector.connector_type.value
@ -692,7 +699,6 @@ async def load_mcp_tools(
else str(connector.connector_type)
)
# Resolve the allowlist from the service registry (if any).
svc_cfg = get_service_by_connector_type(ct)
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()