mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
fix: harden MCP OAuth and connector edge cases
This commit is contained in:
parent
01153b0d7e
commit
0eae96bffb
4 changed files with 25 additions and 12 deletions
|
|
@ -130,8 +130,8 @@ def request_approval(
|
|||
try:
|
||||
decision_type, edited_params = _parse_decision(approval)
|
||||
except ValueError:
|
||||
logger.warning("No approval decision received for %s", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="error", params=params)
|
||||
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
|
||||
return HITLResult(rejected=True, decision_type="error", params=params)
|
||||
|
||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||
|
||||
|
|
|
|||
|
|
@ -447,7 +447,7 @@ def _get_token_enc() -> TokenEncryption:
|
|||
def _inject_oauth_headers(
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
||||
|
||||
The DB never stores plaintext tokens in ``server_config.headers``. This
|
||||
|
|
@ -469,11 +469,11 @@ def _inject_oauth_headers(
|
|||
}
|
||||
return result
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to decrypt MCP OAuth token for runtime injection",
|
||||
logger.error(
|
||||
"Failed to decrypt MCP OAuth token — connector will be skipped",
|
||||
exc_info=True,
|
||||
)
|
||||
return server_config
|
||||
return None
|
||||
|
||||
|
||||
async def _maybe_refresh_mcp_oauth_token(
|
||||
|
|
@ -666,7 +666,6 @@ async def load_mcp_tools(
|
|||
try:
|
||||
cfg = connector.config or {}
|
||||
server_config = cfg.get("server_config", {})
|
||||
trusted_tools = cfg.get("trusted_tools", [])
|
||||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
|
|
@ -685,6 +684,14 @@ async def load_mcp_tools(
|
|||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||
cfg = connector.config or {}
|
||||
server_config = _inject_oauth_headers(cfg, server_config)
|
||||
if server_config is None:
|
||||
logger.warning(
|
||||
"Skipping MCP connector %d — OAuth token decryption failed",
|
||||
connector.id,
|
||||
)
|
||||
continue
|
||||
|
||||
trusted_tools = cfg.get("trusted_tools", [])
|
||||
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
|
|
@ -692,7 +699,6 @@ async def load_mcp_tools(
|
|||
else str(connector.connector_type)
|
||||
)
|
||||
|
||||
# Resolve the allowlist from the service registry (if any).
|
||||
svc_cfg = get_service_by_connector_type(ct)
|
||||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||
|
|
|
|||
|
|
@ -361,7 +361,12 @@ async def mcp_oauth_callback(
|
|||
|
||||
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
|
||||
if account_meta:
|
||||
connector_config.update(account_meta)
|
||||
_SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
|
||||
"workspace_id", "workspace_name", "organization_name",
|
||||
"organization_url_key", "cloud_id", "site_name", "base_url"}
|
||||
for k, v in account_meta.items():
|
||||
if k in _SAFE_META_KEYS:
|
||||
connector_config[k] = v
|
||||
logger.info(
|
||||
"Stored account metadata for %s: display_name=%s",
|
||||
svc_key, account_meta.get("display_name", ""),
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
|
|||
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
||||
"""Get a friendly display name for a connector type."""
|
||||
return BASE_NAME_FOR_TYPE.get(
|
||||
connector_type, connector_type.replace("_", " ").title()
|
||||
connector_type, connector_type.value.replace("_", " ").title()
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
|
|||
base = get_base_name_for_type(connector_type)
|
||||
|
||||
if identifier:
|
||||
return f"{base} - {identifier}"
|
||||
name = f"{base} - {identifier}"
|
||||
return await ensure_unique_connector_name(
|
||||
session, name, search_space_id, user_id,
|
||||
)
|
||||
|
||||
# Fallback: use counter for uniqueness
|
||||
count = await count_connectors_of_type(
|
||||
session, connector_type, search_space_id, user_id
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue