mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
fix: harden MCP OAuth and connector edge cases
This commit is contained in:
parent
01153b0d7e
commit
0eae96bffb
4 changed files with 25 additions and 12 deletions
|
|
@ -130,8 +130,8 @@ def request_approval(
|
||||||
try:
|
try:
|
||||||
decision_type, edited_params = _parse_decision(approval)
|
decision_type, edited_params = _parse_decision(approval)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning("No approval decision received for %s", tool_name)
|
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
|
||||||
return HITLResult(rejected=False, decision_type="error", params=params)
|
return HITLResult(rejected=True, decision_type="error", params=params)
|
||||||
|
|
||||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -447,7 +447,7 @@ def _get_token_enc() -> TokenEncryption:
|
||||||
def _inject_oauth_headers(
|
def _inject_oauth_headers(
|
||||||
cfg: dict[str, Any],
|
cfg: dict[str, Any],
|
||||||
server_config: dict[str, Any],
|
server_config: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any] | None:
|
||||||
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
||||||
|
|
||||||
The DB never stores plaintext tokens in ``server_config.headers``. This
|
The DB never stores plaintext tokens in ``server_config.headers``. This
|
||||||
|
|
@ -469,11 +469,11 @@ def _inject_oauth_headers(
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.error(
|
||||||
"Failed to decrypt MCP OAuth token for runtime injection",
|
"Failed to decrypt MCP OAuth token — connector will be skipped",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
return server_config
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_refresh_mcp_oauth_token(
|
async def _maybe_refresh_mcp_oauth_token(
|
||||||
|
|
@ -666,7 +666,6 @@ async def load_mcp_tools(
|
||||||
try:
|
try:
|
||||||
cfg = connector.config or {}
|
cfg = connector.config or {}
|
||||||
server_config = cfg.get("server_config", {})
|
server_config = cfg.get("server_config", {})
|
||||||
trusted_tools = cfg.get("trusted_tools", [])
|
|
||||||
|
|
||||||
if not server_config or not isinstance(server_config, dict):
|
if not server_config or not isinstance(server_config, dict):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -685,6 +684,14 @@ async def load_mcp_tools(
|
||||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||||
cfg = connector.config or {}
|
cfg = connector.config or {}
|
||||||
server_config = _inject_oauth_headers(cfg, server_config)
|
server_config = _inject_oauth_headers(cfg, server_config)
|
||||||
|
if server_config is None:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping MCP connector %d — OAuth token decryption failed",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
trusted_tools = cfg.get("trusted_tools", [])
|
||||||
|
|
||||||
ct = (
|
ct = (
|
||||||
connector.connector_type.value
|
connector.connector_type.value
|
||||||
|
|
@ -692,7 +699,6 @@ async def load_mcp_tools(
|
||||||
else str(connector.connector_type)
|
else str(connector.connector_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resolve the allowlist from the service registry (if any).
|
|
||||||
svc_cfg = get_service_by_connector_type(ct)
|
svc_cfg = get_service_by_connector_type(ct)
|
||||||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||||
|
|
|
||||||
|
|
@ -361,7 +361,12 @@ async def mcp_oauth_callback(
|
||||||
|
|
||||||
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
|
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
|
||||||
if account_meta:
|
if account_meta:
|
||||||
connector_config.update(account_meta)
|
_SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
|
||||||
|
"workspace_id", "workspace_name", "organization_name",
|
||||||
|
"organization_url_key", "cloud_id", "site_name", "base_url"}
|
||||||
|
for k, v in account_meta.items():
|
||||||
|
if k in _SAFE_META_KEYS:
|
||||||
|
connector_config[k] = v
|
||||||
logger.info(
|
logger.info(
|
||||||
"Stored account metadata for %s: display_name=%s",
|
"Stored account metadata for %s: display_name=%s",
|
||||||
svc_key, account_meta.get("display_name", ""),
|
svc_key, account_meta.get("display_name", ""),
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
|
||||||
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
||||||
"""Get a friendly display name for a connector type."""
|
"""Get a friendly display name for a connector type."""
|
||||||
return BASE_NAME_FOR_TYPE.get(
|
return BASE_NAME_FOR_TYPE.get(
|
||||||
connector_type, connector_type.replace("_", " ").title()
|
connector_type, connector_type.value.replace("_", " ").title()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
|
||||||
base = get_base_name_for_type(connector_type)
|
base = get_base_name_for_type(connector_type)
|
||||||
|
|
||||||
if identifier:
|
if identifier:
|
||||||
return f"{base} - {identifier}"
|
name = f"{base} - {identifier}"
|
||||||
|
return await ensure_unique_connector_name(
|
||||||
|
session, name, search_space_id, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Fallback: use counter for uniqueness
|
|
||||||
count = await count_connectors_of_type(
|
count = await count_connectors_of_type(
|
||||||
session, connector_type, search_space_id, user_id
|
session, connector_type, search_space_id, user_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue