fix: harden MCP OAuth and connector edge cases

This commit is contained in:
CREDO23 2026-04-22 20:54:42 +02:00
parent 01153b0d7e
commit 0eae96bffb
4 changed files with 25 additions and 12 deletions

View file

@ -130,8 +130,8 @@ def request_approval(
try:
decision_type, edited_params = _parse_decision(approval)
except ValueError:
logger.warning("No approval decision received for %s", tool_name)
return HITLResult(rejected=False, decision_type="error", params=params)
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
return HITLResult(rejected=True, decision_type="error", params=params)
logger.info("User decision for %s: %s", tool_name, decision_type)

View file

@ -447,7 +447,7 @@ def _get_token_enc() -> TokenEncryption:
def _inject_oauth_headers(
cfg: dict[str, Any],
server_config: dict[str, Any],
) -> dict[str, Any]:
) -> dict[str, Any] | None:
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
The DB never stores plaintext tokens in ``server_config.headers``. This
@ -469,11 +469,11 @@ def _inject_oauth_headers(
}
return result
except Exception:
logger.warning(
"Failed to decrypt MCP OAuth token for runtime injection",
logger.error(
"Failed to decrypt MCP OAuth token — connector will be skipped",
exc_info=True,
)
return server_config
return None
async def _maybe_refresh_mcp_oauth_token(
@ -666,7 +666,6 @@ async def load_mcp_tools(
try:
cfg = connector.config or {}
server_config = cfg.get("server_config", {})
trusted_tools = cfg.get("trusted_tools", [])
if not server_config or not isinstance(server_config, dict):
logger.warning(
@ -685,6 +684,14 @@ async def load_mcp_tools(
# Re-read cfg after potential refresh (connector was reloaded from DB).
cfg = connector.config or {}
server_config = _inject_oauth_headers(cfg, server_config)
if server_config is None:
logger.warning(
"Skipping MCP connector %d — OAuth token decryption failed",
connector.id,
)
continue
trusted_tools = cfg.get("trusted_tools", [])
ct = (
connector.connector_type.value
@ -692,7 +699,6 @@ async def load_mcp_tools(
else str(connector.connector_type)
)
# Resolve the allowlist from the service registry (if any).
svc_cfg = get_service_by_connector_type(ct)
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()

View file

@ -361,7 +361,12 @@ async def mcp_oauth_callback(
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
if account_meta:
connector_config.update(account_meta)
_SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
"workspace_id", "workspace_name", "organization_name",
"organization_url_key", "cloud_id", "site_name", "base_url"}
for k, v in account_meta.items():
if k in _SAFE_META_KEYS:
connector_config[k] = v
logger.info(
"Stored account metadata for %s: display_name=%s",
svc_key, account_meta.get("display_name", ""),

View file

@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
"""Get a friendly display name for a connector type."""
return BASE_NAME_FOR_TYPE.get(
connector_type, connector_type.replace("_", " ").title()
connector_type, connector_type.value.replace("_", " ").title()
)
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
base = get_base_name_for_type(connector_type)
if identifier:
return f"{base} - {identifier}"
name = f"{base} - {identifier}"
return await ensure_unique_connector_name(
session, name, search_space_id, user_id,
)
# Fallback: use counter for uniqueness
count = await count_connectors_of_type(
session, connector_type, search_space_id, user_id
)