mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
fix: encrypt tokens at rest, invalidate cache on refresh, clean up logging
This commit is contained in:
parent
4915675f45
commit
a1d03da896
3 changed files with 103 additions and 75 deletions
|
|
@ -22,7 +22,7 @@ from typing import Any
|
|||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -66,18 +66,14 @@ def _create_dynamic_input_model_from_schema(
|
|||
param_description = param_schema.get("description", "")
|
||||
is_required = param_name in required_fields
|
||||
|
||||
from typing import Any as AnyType
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
if is_required:
|
||||
field_definitions[param_name] = (
|
||||
AnyType,
|
||||
Any,
|
||||
Field(..., description=param_description),
|
||||
)
|
||||
else:
|
||||
field_definitions[param_name] = (
|
||||
AnyType | None,
|
||||
Any | None,
|
||||
Field(None, description=param_description),
|
||||
)
|
||||
|
||||
|
|
@ -103,13 +99,13 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
|
||||
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
logger.debug("MCP tool '%s' called", tool_name)
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
|
|
@ -133,13 +129,11 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
except Exception as e:
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
|
|
@ -154,7 +148,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
logger.debug("Created MCP tool (stdio): '%s'", tool_name)
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -191,13 +185,13 @@ async def _create_mcp_tool_from_definition_http(
|
|||
if tool_name_prefix:
|
||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
||||
|
||||
logger.info(f"MCP HTTP tool '{exposed_name}' input schema: {input_schema}")
|
||||
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{exposed_name}' called with params: {kwargs}")
|
||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||
|
||||
if is_readonly:
|
||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
|
@ -238,15 +232,12 @@ async def _create_mcp_tool_from_definition_http(
|
|||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info(
|
||||
f"MCP HTTP tool '{exposed_name}' succeeded: {result_str[:200]}"
|
||||
)
|
||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
|
||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=exposed_name,
|
||||
|
|
@ -264,7 +255,7 @@ async def _create_mcp_tool_from_definition_http(
|
|||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (HTTP): '{exposed_name}'")
|
||||
logger.debug("Created MCP tool (HTTP): '%s'", exposed_name)
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -280,21 +271,24 @@ async def _load_stdio_mcp_tools(
|
|||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping"
|
||||
"MCP connector %d (name: '%s') missing or invalid command field, skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid args field (must be list), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid env field (must be dict), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
|
|
@ -304,8 +298,8 @@ async def _load_stdio_mcp_tools(
|
|||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from stdio MCP server "
|
||||
f"'{command}' (connector {connector_id})"
|
||||
"Discovered %d tools from stdio MCP server '%s' (connector %d)",
|
||||
len(tool_definitions), command, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
|
|
@ -320,8 +314,8 @@ async def _load_stdio_mcp_tools(
|
|||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
"Failed to create tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
@ -351,14 +345,16 @@ async def _load_http_mcp_tools(
|
|||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping"
|
||||
"MCP connector %d (name: '%s') missing or invalid url field, skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
|
|
@ -415,13 +411,14 @@ async def _load_http_mcp_tools(
|
|||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create HTTP tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}"
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url, connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
@ -430,6 +427,42 @@ async def _load_http_mcp_tools(
|
|||
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
|
||||
|
||||
|
||||
def _inject_oauth_headers(
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
||||
|
||||
The DB never stores plaintext tokens in ``server_config.headers``. This
|
||||
function decrypts ``mcp_oauth.access_token`` at runtime and returns a
|
||||
*copy* of ``server_config`` with the Authorization header set.
|
||||
"""
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
encrypted_token = mcp_oauth.get("access_token")
|
||||
if not encrypted_token:
|
||||
return server_config
|
||||
|
||||
try:
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
enc = TokenEncryption(app_config.SECRET_KEY)
|
||||
access_token = enc.decrypt_token(encrypted_token)
|
||||
|
||||
result = dict(server_config)
|
||||
result["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
return result
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to decrypt MCP OAuth token for runtime injection",
|
||||
exc_info=True,
|
||||
)
|
||||
return server_config
|
||||
|
||||
|
||||
async def _maybe_refresh_mcp_oauth_token(
|
||||
session: AsyncSession,
|
||||
connector: "SearchSourceConnector",
|
||||
|
|
@ -510,17 +543,11 @@ async def _maybe_refresh_mcp_oauth_token(
|
|||
new_expires_at.isoformat() if new_expires_at else None
|
||||
)
|
||||
|
||||
updated_server_config = dict(server_config)
|
||||
updated_server_config["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {new_access}",
|
||||
}
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
connector.config = {
|
||||
**cfg,
|
||||
"server_config": updated_server_config,
|
||||
"server_config": server_config,
|
||||
"mcp_oauth": updated_oauth,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
|
|
@ -528,7 +555,17 @@ async def _maybe_refresh_mcp_oauth_token(
|
|||
await session.refresh(connector)
|
||||
|
||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
||||
return updated_server_config
|
||||
|
||||
# Invalidate cache so next call picks up the new token.
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
# Return server_config with the fresh token injected for immediate use.
|
||||
refreshed_config = dict(server_config)
|
||||
refreshed_config["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {new_access}",
|
||||
}
|
||||
return refreshed_config
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
|
|
@ -622,15 +659,21 @@ async def load_mcp_tools(
|
|||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
"MCP connector %d (name: '%s') has invalid or missing server_config, skipping",
|
||||
connector.id, connector.name,
|
||||
)
|
||||
continue
|
||||
|
||||
# Refresh OAuth token for MCP OAuth connectors before connecting
|
||||
# For MCP OAuth connectors: refresh if needed, then decrypt the
|
||||
# access token and inject it into headers at runtime. The DB
|
||||
# intentionally does NOT store plaintext tokens in server_config.
|
||||
if cfg.get("mcp_oauth"):
|
||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||
session, connector, cfg, server_config,
|
||||
)
|
||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||
cfg = connector.config or {}
|
||||
server_config = _inject_oauth_headers(cfg, server_config)
|
||||
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
|
|
@ -677,7 +720,8 @@ async def load_mcp_tools(
|
|||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
connector.id, e,
|
||||
)
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
|
@ -686,9 +730,9 @@ async def load_mcp_tools(
|
|||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||
del _mcp_tools_cache[oldest_key]
|
||||
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id)
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to load MCP tools: {e!s}")
|
||||
logger.exception("Failed to load MCP tools: %s", e)
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -78,11 +78,7 @@ from .google_drive import (
|
|||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
# NOTE: Native Jira CRUD tools (create/update/delete_jira_issue) have been
|
||||
# replaced by MCP equivalents (createJiraIssue, editJiraIssue). The native
|
||||
# tools used the REST API which is incompatible with MCP-scoped OAuth tokens.
|
||||
from .connected_accounts import create_get_connected_accounts_tool
|
||||
# NOTE: Native Linear delete tool disabled — see comment in BUILTIN_TOOLS.
|
||||
from .luma import (
|
||||
create_create_luma_event_tool,
|
||||
create_list_luma_events_tool,
|
||||
|
|
@ -279,12 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS — create/update handled by MCP save_issue. Delete/archive
|
||||
# is NOT available: the official Linear MCP server does not expose a
|
||||
# delete tool, and the native tool's GraphQL API call fails with
|
||||
# MCP-scoped tokens (401). Re-enable when Linear adds MCP delete support.
|
||||
# =========================================================================
|
||||
# =========================================================================
|
||||
# NOTION TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
|
|
@ -518,11 +508,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# JIRA TOOLS — Now fully handled by MCP (createJiraIssue, editJiraIssue,
|
||||
# searchJiraIssuesUsingJql, etc.). Native tools removed because the
|
||||
# MCP-scoped OAuth token cannot call the Jira REST API.
|
||||
# =========================================================================
|
||||
# =========================================================================
|
||||
# CONFLUENCE TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Confluence connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
|
|
@ -843,14 +828,15 @@ async def build_tools_async(
|
|||
)
|
||||
tools.extend(mcp_tools)
|
||||
logging.info(
|
||||
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
||||
"Registered %d MCP tools: %s",
|
||||
len(mcp_tools), [t.name for t in mcp_tools],
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail - just continue without MCP tools
|
||||
logging.exception(f"Failed to load MCP tools: {e!s}")
|
||||
logging.exception("Failed to load MCP tools: %s", e)
|
||||
|
||||
logging.info(
|
||||
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
|
||||
"Total tools for agent: %d — %s",
|
||||
len(tools), [t.name for t in tools],
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue