diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 4b204ffa9..89aa13620 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -314,6 +314,20 @@ async def create_surfsense_deep_agent( _t0 = time.perf_counter() _enabled_tool_names = {t.name for t in tools} _user_disabled_tool_names = set(disabled_tools) if disabled_tools else set() + + # Collect generic MCP connector info so the system prompt can route queries + # to their tools instead of falling back to "not in knowledge base". + _mcp_connector_tools: dict[str, list[str]] = {} + for t in tools: + meta = getattr(t, "metadata", None) or {} + if meta.get("mcp_is_generic") and meta.get("mcp_connector_name"): + _mcp_connector_tools.setdefault( + meta["mcp_connector_name"], [], + ).append(t.name) + + if _mcp_connector_tools: + _perf_log.info("MCP connector tool routing: %s", _mcp_connector_tools) + if agent_config is not None: system_prompt = build_configurable_system_prompt( custom_system_instructions=agent_config.system_instructions, @@ -322,12 +336,14 @@ async def create_surfsense_deep_agent( thread_visibility=thread_visibility, enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, + mcp_connector_tools=_mcp_connector_tools, ) else: system_prompt = build_surfsense_system_prompt( thread_visibility=thread_visibility, enabled_tool_names=_enabled_tool_names, disabled_tool_names=_user_disabled_tool_names, + mcp_connector_tools=_mcp_connector_tools, ) _perf_log.info( "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0 diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 3182735d9..e77132182 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -815,11 +815,36 @@ Your goal is to provide helpful, informative answers in a clean, readable format """ +def _build_mcp_routing_block( + mcp_connector_tools: dict[str, list[str]] | None, +) -> str: + """Build an additional tool routing block for generic MCP connectors. + + When users add MCP servers (e.g. GitLab, GitHub), the LLM needs to know + those tools exist and should be called directly — not searched in the + knowledge base. + """ + if not mcp_connector_tools: + return "" + + lines = [ + "\n", + "You also have direct tools from these user-connected MCP servers.", + "Their data is NEVER in the knowledge base — call their tools directly.", + "", + ] + for server_name, tool_names in mcp_connector_tools.items(): + lines.append(f"- {server_name} → {', '.join(tool_names)}") + lines.append("\n") + return "\n".join(lines) + + def build_surfsense_system_prompt( today: datetime | None = None, thread_visibility: ChatVisibility | None = None, enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, ) -> str: """ Build the SurfSense system prompt with default settings. @@ -834,6 +859,9 @@ def build_surfsense_system_prompt( thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. + mcp_connector_tools: Mapping of MCP server display name → list of tool names + for generic MCP connectors. Injected into the system prompt so the LLM + knows to call these tools directly. Returns: Complete system prompt string @@ -841,6 +869,7 @@ def build_surfsense_system_prompt( visibility = thread_visibility or ChatVisibility.PRIVATE system_instructions = _get_system_instructions(visibility, today) + system_instructions += _build_mcp_routing_block(mcp_connector_tools) tools_instructions = _get_tools_instructions( visibility, enabled_tool_names, disabled_tool_names ) @@ -856,6 +885,7 @@ def build_configurable_system_prompt( thread_visibility: ChatVisibility | None = None, enabled_tool_names: set[str] | None = None, disabled_tool_names: set[str] | None = None, + mcp_connector_tools: dict[str, list[str]] | None = None, ) -> str: """ Build a configurable SurfSense system prompt based on NewLLMConfig settings. @@ -877,6 +907,9 @@ def build_configurable_system_prompt( thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None. enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included. disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user. + mcp_connector_tools: Mapping of MCP server display name → list of tool names + for generic MCP connectors. Injected into the system prompt so the LLM + knows to call these tools directly. Returns: Complete system prompt string @@ -894,6 +927,8 @@ def build_configurable_system_prompt( else: system_instructions = "" + system_instructions += _build_mcp_routing_block(mcp_connector_tools) + # Tools instructions: only include enabled tools, note disabled ones tools_instructions = _get_tools_instructions( thread_visibility, enabled_tool_names, disabled_tool_names diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py index 44c48344c..b46ddbcc5 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -45,6 +45,18 @@ class MCPClient: async def connect(self, max_retries: int = MAX_RETRIES): """Connect to the MCP server and manage its lifecycle. + Retries only apply to the **connection** phase (spawning the process, + initialising the session). Once the session is yielded to the caller, + any exception raised by the caller propagates normally -- the context + manager will NOT retry after ``yield``. + + Previous implementation wrapped both connection AND yield inside the + retry loop. Because ``@asynccontextmanager`` only allows a single + ``yield``, a failure after yield caused the generator to attempt a + second yield on retry, triggering + ``RuntimeError("generator didn't stop after athrow()")`` and orphaning + the stdio subprocess. + Args: max_retries: Maximum number of connection retry attempts @@ -57,26 +69,22 @@ class MCPClient: """ last_error = None delay = RETRY_DELAY + connected = False for attempt in range(max_retries): try: - # Merge env vars with current environment server_env = os.environ.copy() server_env.update(self.env) - # Create server parameters with env server_params = StdioServerParameters( command=self.command, args=self.args, env=server_env ) - # Spawn server process and create session - # Note: Cannot combine these context managers because ClientSession - # needs the read/write streams from stdio_client async with stdio_client(server=server_params) as (read, write): # noqa: SIM117 async with ClientSession(read, write) as session: - # Initialize the connection await session.initialize() self.session = session + connected = True if attempt > 0: logger.info( @@ -91,10 +99,16 @@ class MCPClient: self.command, " ".join(self.args), ) - yield session - return # Success, exit retry loop + try: + yield session + finally: + self.session = None + return except Exception as e: + self.session = None + if connected: + raise last_error = e if attempt < max_retries - 1: logger.warning( @@ -105,7 +119,7 @@ class MCPClient: delay, ) await asyncio.sleep(delay) - delay *= RETRY_BACKOFF # Exponential backoff + delay *= RETRY_BACKOFF else: logger.error( "Failed to connect to MCP server after %d attempts: %s", @@ -113,10 +127,7 @@ class MCPClient: e, exc_info=True, ) - finally: - self.session = None - # All retries exhausted error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts" if last_error: error_msg += f": {last_error}" @@ -161,12 +172,18 @@ class MCPClient: logger.error("Failed to list tools from MCP server: %s", e, exc_info=True) raise - async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + async def call_tool( + self, + tool_name: str, + arguments: dict[str, Any], + timeout: float = 60.0, + ) -> Any: """Call a tool on the MCP server. Args: tool_name: Name of the tool to call arguments: Arguments to pass to the tool + timeout: Maximum seconds to wait for the tool to respond Returns: Tool execution result @@ -185,10 +202,11 @@ class MCPClient: "Calling MCP tool '%s' with arguments: %s", tool_name, arguments ) - # Call tools/call RPC method - response = await self.session.call_tool(tool_name, arguments=arguments) + response = await asyncio.wait_for( + self.session.call_tool(tool_name, arguments=arguments), + timeout=timeout, + ) - # Extract content from response result = [] for content in response.content: if hasattr(content, "text"): @@ -202,15 +220,17 @@ class MCPClient: logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) return result_str + except asyncio.TimeoutError: + logger.error( + "MCP tool '%s' timed out after %.0fs", tool_name, timeout + ) + return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s" except RuntimeError as e: - # Handle validation errors from MCP server responses - # Some MCP servers (like server-memory) return extra fields not in their schema if "Invalid structured content" in str(e): logger.warning( "MCP server returned data not matching its schema, but continuing: %s", e, ) - # Try to extract result from error message or return a success message return "Operation completed (server returned unexpected format)" raise except (ValueError, TypeError, AttributeError, KeyError) as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 8f8e5007f..dfee24516 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -16,6 +16,7 @@ clicking "Always Allow", which adds the tool name to the connector's from __future__ import annotations +import asyncio import logging import time from collections import defaultdict @@ -27,7 +28,7 @@ if TYPE_CHECKING: from langchain_core.tools import StructuredTool from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, create_model from sqlalchemy import cast, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession @@ -41,6 +42,9 @@ logger = logging.getLogger(__name__) _MCP_CACHE_TTL_SECONDS = 300 # 5 minutes _MCP_CACHE_MAX_SIZE = 50 +_MCP_DISCOVERY_TIMEOUT_SECONDS = 30 +_TOOL_CALL_MAX_RETRIES = 3 +_TOOL_CALL_RETRY_DELAY = 1.5 # seconds, doubles per attempt _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {} @@ -62,7 +66,18 @@ def _create_dynamic_input_model_from_schema( tool_name: str, input_schema: dict[str, Any], ) -> type[BaseModel]: - """Create a Pydantic model from MCP tool's JSON schema.""" + """Create a Pydantic model from MCP tool's JSON schema. + + Models always allow extra fields (``extra="allow"``) so that parameters + missing from a broken or incomplete JSON schema (e.g. ``zod-to-json-schema`` + producing an empty ``$schema``-only object) can still be forwarded to the + MCP server. + + When the schema declares **no** properties, a synthetic ``input_data`` + field of type ``dict`` is injected so the LLM has a visible parameter to + populate. The caller should unpack ``input_data`` before forwarding to + the MCP server (see ``_unpack_synthetic_input_data``). + """ properties = input_schema.get("properties", {}) required_fields = input_schema.get("required", []) @@ -82,8 +97,35 @@ def _create_dynamic_input_model_from_schema( Field(None, description=param_description), ) + if not properties: + field_definitions["input_data"] = ( + dict[str, Any] | None, + Field( + None, + description=( + "Arguments to pass to this tool as a JSON object. " + "Infer sensible key names from the tool name and description " + "(e.g. {\"search\": \"my query\"} for a search tool)." + ), + ), + ) + model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" - return create_model(model_name, **field_definitions) + model = create_model(model_name, __config__=ConfigDict(extra="allow"), **field_definitions) + return model + + +def _unpack_synthetic_input_data(kwargs: dict[str, Any]) -> dict[str, Any]: + """Unpack the synthetic ``input_data`` field into top-level kwargs. + + When the MCP tool schema is empty, ``_create_dynamic_input_model_from_schema`` + adds a catch-all ``input_data: dict`` field. This helper merges that dict + back into the top-level kwargs so the MCP server receives flat arguments. + """ + input_data = kwargs.pop("input_data", None) + if isinstance(input_data, dict): + kwargs.update(input_data) + return kwargs async def _create_mcp_tool_from_definition_stdio( @@ -101,7 +143,12 @@ async def _create_mcp_tool_from_definition_stdio( ``GraphInterrupt`` propagates cleanly to LangGraph. """ tool_name = tool_def.get("name", "unnamed_tool") - tool_description = tool_def.get("description", "No description provided") + raw_description = tool_def.get("description", "No description provided") + tool_description = ( + f"[MCP server: {connector_name}] {raw_description}" + if connector_name + else raw_description + ) input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema) @@ -119,7 +166,7 @@ async def _create_mcp_tool_from_definition_stdio( params=kwargs, context={ "mcp_server": connector_name, - "tool_description": tool_description, + "tool_description": raw_description, "mcp_transport": "stdio", "mcp_connector_id": connector_id, }, @@ -127,18 +174,32 @@ async def _create_mcp_tool_from_definition_stdio( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) - try: - async with mcp_client.connect(): - result = await mcp_client.call_tool(tool_name, call_kwargs) - return str(result) - except RuntimeError as e: - logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e) - return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}" - except Exception as e: - logger.exception("MCP tool '%s' execution failed: %s", tool_name, e) - return f"Error: MCP tool '{tool_name}' execution failed: {e!s}" + last_error: Exception | None = None + for attempt in range(_TOOL_CALL_MAX_RETRIES): + try: + async with mcp_client.connect(): + result = await mcp_client.call_tool(tool_name, call_kwargs) + return str(result) + except Exception as e: + last_error = e + if attempt < _TOOL_CALL_MAX_RETRIES - 1: + delay = _TOOL_CALL_RETRY_DELAY * (2 ** attempt) + logger.warning( + "MCP tool '%s' failed (attempt %d/%d): %s. Retrying in %.1fs...", + tool_name, attempt + 1, _TOOL_CALL_MAX_RETRIES, e, delay, + ) + await asyncio.sleep(delay) + else: + logger.error( + "MCP tool '%s' failed after %d attempts: %s", + tool_name, _TOOL_CALL_MAX_RETRIES, e, exc_info=True, + ) + + return f"Error: MCP tool '{tool_name}' failed after {_TOOL_CALL_MAX_RETRIES} attempts: {last_error!s}" tool = StructuredTool( name=tool_name, @@ -148,6 +209,8 @@ async def _create_mcp_tool_from_definition_stdio( metadata={ "mcp_input_schema": input_schema, "mcp_transport": "stdio", + "mcp_connector_name": connector_name or None, + "mcp_is_generic": True, "hitl": True, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), }, @@ -167,6 +230,7 @@ async def _create_mcp_tool_from_definition_http( trusted_tools: list[str] | None = None, readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, + is_generic_mcp: bool = False, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). @@ -178,7 +242,7 @@ async def _create_mcp_tool_from_definition_http( but the actual MCP ``call_tool`` still uses the original name. """ original_tool_name = tool_def.get("name", "unnamed_tool") - tool_description = tool_def.get("description", "No description provided") + raw_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) is_readonly = readonly_tools is not None and original_tool_name in readonly_tools @@ -188,18 +252,51 @@ async def _create_mcp_tool_from_definition_http( else original_tool_name ) if tool_name_prefix: - tool_description = f"[Account: {connector_name}] {tool_description}" + tool_description = f"[Account: {connector_name}] {raw_description}" + elif is_generic_mcp and connector_name: + tool_description = f"[MCP server: {connector_name}] {raw_description}" + else: + tool_description = raw_description logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema) input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema) + async def _do_mcp_call( + call_headers: dict[str, str], + call_kwargs: dict[str, Any], + timeout: float = 60.0, + ) -> str: + """Execute a single MCP HTTP call with the given headers.""" + async with ( + streamablehttp_client(url, headers=call_headers) as (read, write, _), + ClientSession(read, write) as session, + ): + await session.initialize() + response = await asyncio.wait_for( + session.call_tool(original_tool_name, arguments=call_kwargs), + timeout=timeout, + ) + + result = [] + for content in response.content: + if hasattr(content, "text"): + result.append(content.text) + elif hasattr(content, "data"): + result.append(str(content.data)) + else: + result.append(str(content)) + + return "\n".join(result) if result else "" + async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" logger.debug("MCP HTTP tool '%s' called", exposed_name) if is_readonly: - call_kwargs = {k: v for k, v in kwargs.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in kwargs.items() if v is not None} + ) else: hitl_result = request_approval( action_type="mcp_tool_call", @@ -207,7 +304,7 @@ async def _create_mcp_tool_from_definition_http( params=kwargs, context={ "mcp_server": connector_name, - "tool_description": tool_description, + "tool_description": raw_description, "mcp_transport": "http", "mcp_connector_id": connector_id, }, @@ -215,34 +312,51 @@ async def _create_mcp_tool_from_definition_http( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} + call_kwargs = _unpack_synthetic_input_data( + {k: v for k, v in hitl_result.params.items() if v is not None} + ) try: - async with ( - streamablehttp_client(url, headers=headers) as (read, write, _), - ClientSession(read, write) as session, - ): - await session.initialize() - response = await session.call_tool( - original_tool_name, arguments=call_kwargs, + result_str = await _do_mcp_call(headers, call_kwargs) + logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) + return result_str + + except Exception as first_err: + if not _is_auth_error(first_err) or connector_id is None: + logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err) + return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}" + + logger.warning( + "MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s", + exposed_name, connector_id, + ) + fresh_headers = await _force_refresh_and_get_headers(connector_id) + if fresh_headers is None: + await _mark_connector_auth_expired(connector_id) + return ( + f"Error: MCP tool '{exposed_name}' authentication expired. " + "Please re-authenticate the connector in your settings." ) - result = [] - for content in response.content: - if hasattr(content, "text"): - result.append(content.text) - elif hasattr(content, "data"): - result.append(str(content.data)) - else: - result.append(str(content)) - - result_str = "\n".join(result) if result else "" - logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) + try: + result_str = await _do_mcp_call(fresh_headers, call_kwargs) + logger.info( + "MCP HTTP tool '%s' succeeded after 401 recovery", + exposed_name, + ) return result_str - - except Exception as e: - logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e) - return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}" + except Exception as retry_err: + logger.exception( + "MCP HTTP tool '%s' still failing after token refresh: %s", + exposed_name, retry_err, + ) + if _is_auth_error(retry_err): + await _mark_connector_auth_expired(connector_id) + return ( + f"Error: MCP tool '{exposed_name}' authentication expired. " + "Please re-authenticate the connector in your settings." + ) + return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}" tool = StructuredTool( name=exposed_name, @@ -253,6 +367,8 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, + "mcp_connector_name": connector_name or None, + "mcp_is_generic": is_generic_mcp, "hitl": not is_readonly, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), "mcp_original_tool_name": original_tool_name, @@ -334,6 +450,7 @@ async def _load_http_mcp_tools( allowed_tools: list[str] | None = None, readonly_tools: frozenset[str] | None = None, tool_name_prefix: str | None = None, + is_generic_mcp: bool = False, ) -> list[StructuredTool]: """Load tools from an HTTP-based MCP server. @@ -365,66 +482,99 @@ async def _load_http_mcp_tools( allowed_set = set(allowed_tools) if allowed_tools else None - try: + async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]: + """Connect, initialize, and list tools from the MCP server.""" async with ( - streamablehttp_client(url, headers=headers) as (read, write, _), + streamablehttp_client(url, headers=disc_headers) as (read, write, _), ClientSession(read, write) as session, ): await session.initialize() - response = await session.list_tools() - tool_definitions = [] - for tool in response.tools: - tool_definitions.append( - { - "name": tool.name, - "description": tool.description or "", - "input_schema": tool.inputSchema - if hasattr(tool, "inputSchema") - else {}, - } - ) + return [ + { + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.inputSchema + if hasattr(tool, "inputSchema") + else {}, + } + for tool in response.tools + ] - total_discovered = len(tool_definitions) + try: + tool_definitions = await _discover(headers) + except Exception as first_err: + if not _is_auth_error(first_err) or connector_id is None: + logger.exception( + "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", + url, connector_id, first_err, + ) + return tools - if allowed_set: - tool_definitions = [ - td for td in tool_definitions if td["name"] in allowed_set - ] - logger.info( - "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", - url, connector_id, len(tool_definitions), total_discovered, - ) - else: - logger.info( - "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", - total_discovered, url, connector_id, - ) - - for tool_def in tool_definitions: - try: - tool = await _create_mcp_tool_from_definition_http( - tool_def, - url, - headers, - connector_name=connector_name, - connector_id=connector_id, - trusted_tools=trusted_tools, - readonly_tools=readonly_tools, - tool_name_prefix=tool_name_prefix, - ) - tools.append(tool) - except Exception as e: - logger.exception( - "Failed to create HTTP tool '%s' from connector %d: %s", - tool_def.get("name"), connector_id, e, - ) - - except Exception as e: - logger.exception( - "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", - url, connector_id, e, + logger.warning( + "HTTP MCP discovery for connector %d got 401 — attempting token refresh", + connector_id, ) + fresh_headers = await _force_refresh_and_get_headers(connector_id) + if fresh_headers is None: + await _mark_connector_auth_expired(connector_id) + logger.error( + "HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired", + connector_id, + ) + return tools + + try: + tool_definitions = await _discover(fresh_headers) + headers = fresh_headers + logger.info( + "HTTP MCP discovery for connector %d succeeded after 401 recovery", + connector_id, + ) + except Exception as retry_err: + logger.exception( + "HTTP MCP discovery for connector %d still failing after refresh: %s", + connector_id, retry_err, + ) + if _is_auth_error(retry_err): + await _mark_connector_auth_expired(connector_id) + return tools + + total_discovered = len(tool_definitions) + + if allowed_set: + tool_definitions = [ + td for td in tool_definitions if td["name"] in allowed_set + ] + logger.info( + "HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter", + url, connector_id, len(tool_definitions), total_discovered, + ) + else: + logger.info( + "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", + total_discovered, url, connector_id, + ) + + for tool_def in tool_definitions: + try: + tool = await _create_mcp_tool_from_definition_http( + tool_def, + url, + headers, + connector_name=connector_name, + connector_id=connector_id, + trusted_tools=trusted_tools, + readonly_tools=readonly_tools, + tool_name_prefix=tool_name_prefix, + is_generic_mcp=is_generic_mcp, + ) + tools.append(tool) + except Exception as e: + logger.exception( + "Failed to create HTTP tool '%s' from connector %d: %s", + tool_def.get("name"), connector_id, e, + ) return tools @@ -476,6 +626,91 @@ def _inject_oauth_headers( return None +async def _refresh_connector_token( + session: AsyncSession, + connector: "SearchSourceConnector", +) -> str | None: + """Refresh the OAuth token for an MCP connector and persist the result. + + This is the shared core used by both proactive (pre-expiry) and reactive + (401 recovery) refresh paths. It handles: + - Decrypting the current refresh token / client secret + - Calling the token endpoint + - Encrypting and persisting the new tokens + - Clearing ``auth_expired`` if it was set + - Invalidating the MCP tools cache + + Returns the **plaintext** new access token on success, or ``None`` on + failure (no refresh token, IdP error, etc.). + """ + from datetime import UTC, datetime, timedelta + + from sqlalchemy.orm.attributes import flag_modified + + from app.services.mcp_oauth.discovery import refresh_access_token + + cfg = connector.config or {} + mcp_oauth = cfg.get("mcp_oauth", {}) + + refresh_token = mcp_oauth.get("refresh_token") + if not refresh_token: + logger.warning( + "MCP connector %s: no refresh_token available", + connector.id, + ) + return None + + enc = _get_token_enc() + decrypted_refresh = enc.decrypt_token(refresh_token) + decrypted_secret = ( + enc.decrypt_token(mcp_oauth["client_secret"]) + if mcp_oauth.get("client_secret") + else "" + ) + + token_json = await refresh_access_token( + token_endpoint=mcp_oauth["token_endpoint"], + refresh_token=decrypted_refresh, + client_id=mcp_oauth["client_id"], + client_secret=decrypted_secret, + ) + + new_access = token_json.get("access_token") + if not new_access: + logger.warning( + "MCP connector %s: token refresh returned no access_token", + connector.id, + ) + return None + + new_expires_at = None + if token_json.get("expires_in"): + new_expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + updated_oauth = dict(mcp_oauth) + updated_oauth["access_token"] = enc.encrypt_token(new_access) + if token_json.get("refresh_token"): + updated_oauth["refresh_token"] = enc.encrypt_token( + token_json["refresh_token"] + ) + updated_oauth["expires_at"] = ( + new_expires_at.isoformat() if new_expires_at else None + ) + + updated_cfg = {**cfg, "mcp_oauth": updated_oauth} + updated_cfg.pop("auth_expired", None) + connector.config = updated_cfg + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + invalidate_mcp_tools_cache(connector.search_space_id) + + return new_access + + async def _maybe_refresh_mcp_oauth_token( session: AsyncSession, connector: "SearchSourceConnector", @@ -504,73 +739,13 @@ async def _maybe_refresh_mcp_oauth_token( except (ValueError, TypeError): return server_config - refresh_token = mcp_oauth.get("refresh_token") - if not refresh_token: - logger.warning( - "MCP connector %s token expired but no refresh_token available", - connector.id, - ) - return server_config - try: - from app.services.mcp_oauth.discovery import refresh_access_token - - enc = _get_token_enc() - decrypted_refresh = enc.decrypt_token(refresh_token) - decrypted_secret = ( - enc.decrypt_token(mcp_oauth["client_secret"]) - if mcp_oauth.get("client_secret") - else "" - ) - - token_json = await refresh_access_token( - token_endpoint=mcp_oauth["token_endpoint"], - refresh_token=decrypted_refresh, - client_id=mcp_oauth["client_id"], - client_secret=decrypted_secret, - ) - - new_access = token_json.get("access_token") + new_access = await _refresh_connector_token(session, connector) if not new_access: - logger.warning( - "MCP connector %s token refresh returned no access_token", - connector.id, - ) return server_config - new_expires_at = None - if token_json.get("expires_in"): - new_expires_at = datetime.now(UTC) + timedelta( - seconds=int(token_json["expires_in"]) - ) + logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id) - updated_oauth = dict(mcp_oauth) - updated_oauth["access_token"] = enc.encrypt_token(new_access) - if token_json.get("refresh_token"): - updated_oauth["refresh_token"] = enc.encrypt_token( - token_json["refresh_token"] - ) - updated_oauth["expires_at"] = ( - new_expires_at.isoformat() if new_expires_at else None - ) - - from sqlalchemy.orm.attributes import flag_modified - - connector.config = { - **cfg, - "server_config": server_config, - "mcp_oauth": updated_oauth, - } - flag_modified(connector, "config") - await session.commit() - await session.refresh(connector) - - logger.info("Refreshed MCP OAuth token for connector %s", connector.id) - - # Invalidate cache so next call picks up the new token. - invalidate_mcp_tools_cache(connector.search_space_id) - - # Return server_config with the fresh token injected for immediate use. refreshed_config = dict(server_config) refreshed_config["headers"] = { **server_config.get("headers", {}), @@ -587,6 +762,117 @@ async def _maybe_refresh_mcp_oauth_token( return server_config +# --------------------------------------------------------------------------- +# Reactive 401 handling helpers +# --------------------------------------------------------------------------- + + +def _is_auth_error(exc: Exception) -> bool: + """Check if an exception indicates an HTTP 401 authentication failure.""" + try: + import httpx + + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code == 401 + except ImportError: + pass + err_str = str(exc).lower() + return "401" in err_str or "unauthorized" in err_str + + +async def _force_refresh_and_get_headers( + connector_id: int, +) -> dict[str, str] | None: + """Force-refresh OAuth token for a connector and return fresh HTTP headers. + + Opens a **new** DB session so this can be called from inside tool closures + that don't have access to the original session. + + Returns ``None`` when the connector is not OAuth-backed, has no + refresh token, or the refresh itself fails. + """ + from app.db import async_session_maker + + try: + async with async_session_maker() as session: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + ) + ) + connector = result.scalars().first() + if not connector: + return None + + cfg = connector.config or {} + if not cfg.get("mcp_oauth"): + return None + + server_config = cfg.get("server_config", {}) + + new_access = await _refresh_connector_token(session, connector) + if not new_access: + return None + + logger.info( + "Force-refreshed MCP OAuth token for connector %s (401 recovery)", + connector_id, + ) + return { + **server_config.get("headers", {}), + "Authorization": f"Bearer {new_access}", + } + + except Exception: + logger.warning( + "Failed to force-refresh MCP OAuth token for connector %s", + connector_id, + exc_info=True, + ) + return None + + +async def _mark_connector_auth_expired(connector_id: int) -> None: + """Set ``config.auth_expired = True`` so the frontend shows re-auth UI.""" + from app.db import async_session_maker + + try: + async with async_session_maker() as session: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + ) + ) + connector = result.scalars().first() + if not connector: + return + + cfg = dict(connector.config or {}) + if cfg.get("auth_expired"): + return + + cfg["auth_expired"] = True + connector.config = cfg + + from sqlalchemy.orm.attributes import flag_modified + + flag_modified(connector, "config") + await session.commit() + + logger.info( + "Marked MCP connector %s as auth_expired after unrecoverable 401", + connector_id, + ) + invalidate_mcp_tools_cache(connector.search_space_id) + + except Exception: + logger.warning( + "Failed to mark connector %s as auth_expired", + connector_id, + exc_info=True, + ) + + def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: """Invalidate cached MCP tools. @@ -661,7 +947,7 @@ async def load_mcp_tools( multi_account_types, ) - tools: list[StructuredTool] = [] + discovery_tasks: list[dict[str, Any]] = [] for connector in connectors: try: cfg = connector.config or {} @@ -674,14 +960,10 @@ async def load_mcp_tools( ) continue - # For MCP OAuth connectors: refresh if needed, then decrypt the - # access token and inject it into headers at runtime. The DB - # intentionally does NOT store plaintext tokens in server_config. if cfg.get("mcp_oauth"): server_config = await _maybe_refresh_mcp_oauth_token( session, connector, cfg, server_config, ) - # Re-read cfg after potential refresh (connector was reloaded from DB). cfg = connector.config or {} server_config = _inject_oauth_headers(cfg, server_config) if server_config is None: @@ -689,6 +971,7 @@ async def load_mcp_tools( "Skipping MCP connector %d — OAuth token decryption failed", connector.id, ) + await _mark_connector_auth_expired(connector.id) continue trusted_tools = cfg.get("trusted_tools", []) @@ -703,7 +986,6 @@ async def load_mcp_tools( allowed_tools = svc_cfg.allowed_tools if svc_cfg else [] readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset() - # Build a prefix only when multiple accounts share the same type. tool_name_prefix: str | None = None if ct in multi_account_types and svc_cfg: service_key = next( @@ -713,34 +995,68 @@ async def load_mcp_tools( if service_key: tool_name_prefix = f"{service_key}_{connector.id}" - transport = server_config.get("transport", "stdio") - - if transport in ("streamable-http", "http", "sse"): - connector_tools = await _load_http_mcp_tools( - connector.id, - connector.name, - server_config, - trusted_tools=trusted_tools, - allowed_tools=allowed_tools, - readonly_tools=readonly_tools, - tool_name_prefix=tool_name_prefix, - ) - else: - connector_tools = await _load_stdio_mcp_tools( - connector.id, - connector.name, - server_config, - trusted_tools=trusted_tools, - ) - - tools.extend(connector_tools) + discovery_tasks.append({ + "connector_id": connector.id, + "connector_name": connector.name, + "server_config": server_config, + "trusted_tools": trusted_tools, + "allowed_tools": allowed_tools, + "readonly_tools": readonly_tools, + "tool_name_prefix": tool_name_prefix, + "transport": server_config.get("transport", "stdio"), + "is_generic_mcp": svc_cfg is None, + }) except Exception as e: logger.exception( - "Failed to load tools from MCP connector %d: %s", + "Failed to prepare MCP connector %d: %s", connector.id, e, ) + async def _discover_one(task: dict[str, Any]) -> list[StructuredTool]: + try: + if task["transport"] in ("streamable-http", "http", "sse"): + return await asyncio.wait_for( + _load_http_mcp_tools( + task["connector_id"], + task["connector_name"], + task["server_config"], + trusted_tools=task["trusted_tools"], + allowed_tools=task["allowed_tools"], + readonly_tools=task["readonly_tools"], + tool_name_prefix=task["tool_name_prefix"], + is_generic_mcp=task.get("is_generic_mcp", False), + ), + timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, + ) + else: + return await asyncio.wait_for( + _load_stdio_mcp_tools( + task["connector_id"], + task["connector_name"], + task["server_config"], + trusted_tools=task["trusted_tools"], + ), + timeout=_MCP_DISCOVERY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + logger.error( + "MCP connector %d timed out after %ds during discovery", + task["connector_id"], _MCP_DISCOVERY_TIMEOUT_SECONDS, + ) + return [] + except Exception as e: + logger.exception( + "Failed to load tools from MCP connector %d: %s", + task["connector_id"], e, + ) + return [] + + results = await asyncio.gather(*[_discover_one(t) for t in discovery_tasks]) + tools: list[StructuredTool] = [ + tool for sublist in results for tool in sublist + ] + _mcp_tools_cache[search_space_id] = (now, tools) if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE: diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 989894003..b8142c192 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -3105,13 +3105,18 @@ async def trust_mcp_tool( """Add a tool to the MCP connector's trusted (always-allow) list. Once trusted, the tool executes without HITL approval on subsequent calls. + Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors + (LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``. """ try: + from sqlalchemy import cast + from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB + result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.user_id == user.id, + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 ) ) connector = result.scalars().first() @@ -3156,13 +3161,17 @@ async def untrust_mcp_tool( """Remove a tool from the MCP connector's trusted list. The tool will require HITL approval again on subsequent calls. + Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors. """ try: + from sqlalchemy import cast + from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB + result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.user_id == user.id, + cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"), # noqa: W601 ) ) connector = result.scalars().first() diff --git a/surfsense_backend/app/services/confluence/kb_sync_service.py b/surfsense_backend/app/services/confluence/kb_sync_service.py index f786a9920..cae2bef88 100644 --- a/surfsense_backend/app/services/confluence/kb_sync_service.py +++ b/surfsense_backend/app/services/confluence/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.confluence_history import ConfluenceHistoryConnector from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -66,6 +65,8 @@ class ConfluenceKBSyncService: if dup: content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -184,6 +185,8 @@ class ConfluenceKBSyncService: space_id = (document.document_metadata or {}).get("space_id", "") + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/dropbox/kb_sync_service.py b/surfsense_backend/app/services/dropbox/kb_sync_service.py index 2a74bdf4b..9d1951013 100644 --- a/surfsense_backend/app/services/dropbox/kb_sync_service.py +++ b/surfsense_backend/app/services/dropbox/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType from app.indexing_pipeline.document_hashing import compute_identifier_hash -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -73,6 +72,8 @@ class DropboxKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_backend/app/services/gmail/kb_sync_service.py b/surfsense_backend/app/services/gmail/kb_sync_service.py index b3b50d305..885ee4b94 100644 --- a/surfsense_backend/app/services/gmail/kb_sync_service.py +++ b/surfsense_backend/app/services/gmail/kb_sync_service.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -78,6 +77,8 @@ class GmailKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_backend/app/services/google_calendar/kb_sync_service.py b/surfsense_backend/app/services/google_calendar/kb_sync_service.py index 3cda02b9b..20426f3bc 100644 --- a/surfsense_backend/app/services/google_calendar/kb_sync_service.py +++ b/surfsense_backend/app/services/google_calendar/kb_sync_service.py @@ -14,7 +14,6 @@ from app.db import ( SearchSourceConnector, SearchSourceConnectorType, ) -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -91,6 +90,8 @@ class GoogleCalendarKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -249,6 +250,8 @@ class GoogleCalendarKBSyncService: if not indexable_content: return {"status": "error", "message": "Event produced empty content"} + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/google_drive/kb_sync_service.py b/surfsense_backend/app/services/google_drive/kb_sync_service.py index 92a39f7b9..0a8eb47a6 100644 --- a/surfsense_backend/app/services/google_drive/kb_sync_service.py +++ b/surfsense_backend/app/services/google_drive/kb_sync_service.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -75,6 +74,8 @@ class GoogleDriveKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_backend/app/services/jira/kb_sync_service.py b/surfsense_backend/app/services/jira/kb_sync_service.py index 4d2a66e52..8e88bee81 100644 --- a/surfsense_backend/app/services/jira/kb_sync_service.py +++ b/surfsense_backend/app/services/jira/kb_sync_service.py @@ -6,7 +6,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.jira_history import JiraHistoryConnector from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -75,6 +74,8 @@ class JiraKBSyncService: if dup: content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -190,6 +191,8 @@ class JiraKBSyncService: state = formatted.get("status", "Unknown") comment_count = len(formatted.get("comments", [])) + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/linear/kb_sync_service.py b/surfsense_backend/app/services/linear/kb_sync_service.py index dab42af55..471227602 100644 --- a/surfsense_backend/app/services/linear/kb_sync_service.py +++ b/surfsense_backend/app/services/linear/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.connectors.linear_connector import LinearConnector from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -85,6 +84,8 @@ class LinearKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -226,6 +227,8 @@ class LinearKBSyncService: comment_count = len(formatted_issue.get("comments", [])) formatted_issue.get("description", "") + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, search_space_id, disable_streaming=True ) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 79a72dd25..942a9b7af 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -7,7 +7,6 @@ from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.agents.new_chat.llm_config import SanitizedChatLiteLLM from app.config import config from app.db import NewLLMConfig, SearchSpace from app.services.llm_router_service import ( @@ -204,6 +203,8 @@ async def validate_llm_config( if litellm_params: litellm_kwargs.update(litellm_params) + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + llm = SanitizedChatLiteLLM(**litellm_kwargs) # Run the test call in a worker thread with a hard timeout. Some @@ -377,6 +378,8 @@ async def get_search_space_llm_instance( if disable_streaming: litellm_kwargs["disable_streaming"] = True + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) # Get the LLM configuration from database (NewLLMConfig) @@ -454,6 +457,8 @@ async def get_search_space_llm_instance( if disable_streaming: litellm_kwargs["disable_streaming"] = True + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) except Exception as e: @@ -555,6 +560,8 @@ async def get_vision_llm( if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) result = await session.execute( @@ -588,6 +595,8 @@ async def get_vision_llm( if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) + from app.agents.new_chat.llm_config import SanitizedChatLiteLLM + return SanitizedChatLiteLLM(**litellm_kwargs) except Exception as e: diff --git a/surfsense_backend/app/services/notion/kb_sync_service.py b/surfsense_backend/app/services/notion/kb_sync_service.py index be177c7ca..b10d1b157 100644 --- a/surfsense_backend/app/services/notion/kb_sync_service.py +++ b/surfsense_backend/app/services/notion/kb_sync_service.py @@ -4,7 +4,6 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -74,6 +73,8 @@ class NotionKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, @@ -244,6 +245,8 @@ class NotionKBSyncService: f"Final content length: {len(full_content)} chars, verified={content_verified}" ) + from app.services.llm_service import get_user_long_context_llm + logger.debug("Generating summary and embeddings") user_llm = await get_user_long_context_llm( self.db_session, diff --git a/surfsense_backend/app/services/onedrive/kb_sync_service.py b/surfsense_backend/app/services/onedrive/kb_sync_service.py index 962c19fc9..e9b2e38ea 100644 --- a/surfsense_backend/app/services/onedrive/kb_sync_service.py +++ b/surfsense_backend/app/services/onedrive/kb_sync_service.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import Document, DocumentType from app.indexing_pipeline.document_hashing import compute_identifier_hash -from app.services.llm_service import get_user_long_context_llm from app.utils.document_converters import ( create_document_chunks, embed_text, @@ -73,6 +72,8 @@ class OneDriveKBSyncService: ) content_hash = unique_hash + from app.services.llm_service import get_user_long_context_llm + user_llm = await get_user_long_context_llm( self.db_session, user_id, diff --git a/surfsense_web/components/assistant-ui/connector-popup.tsx b/surfsense_web/components/assistant-ui/connector-popup.tsx index 84361e25b..66333a9ef 100644 --- a/surfsense_web/components/assistant-ui/connector-popup.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup.tsx @@ -123,8 +123,9 @@ export const ConnectorIndicator = forwardRef ) : viewingMCPList ? ( - + handleDisconnectFromList(connector, () => refreshConnectors())} + onAddAccount={handleAddNewMCPFromList} + addButtonText="Add New MCP Server" + /> ) : viewingAccountsType ? ( - { + handleDisconnectFromList(connector, () => refreshConnectors())} + onAddAccount={() => { // Check both OAUTH_CONNECTORS and COMPOSIO_CONNECTORS const oauthConnector = OAUTH_CONNECTORS.find( diff --git a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx index 58d365128..fc9812240 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connect-forms/components/mcp-connect-form.tsx @@ -1,6 +1,6 @@ "use client"; -import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react"; +import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react"; import { type FC, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; @@ -212,7 +212,14 @@ export const MCPConnectForm: FC = ({ onSubmit, isSubmitting }) variant="secondary" className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > - {isTesting ? "Testing Connection" : "Test Connection"} + {isTesting ? ( + <> + + Testing Connection... + + ) : ( + "Test Connection" + )} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx index ca997a9ba..d6f60e824 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-config.tsx @@ -1,6 +1,6 @@ "use client"; -import { CheckCircle2, ChevronDown, ChevronUp, Server, XCircle } from "lucide-react"; +import { CheckCircle2, ChevronDown, ChevronUp, Loader2, Server, XCircle } from "lucide-react"; import type { FC } from "react"; import { useCallback, useEffect, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; @@ -217,7 +217,14 @@ export const MCPConfig: FC = ({ connector, onConfigChange, onNam variant="secondary" className="w-full h-8 text-[13px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" > - {isTesting ? "Testing Connection" : "Test Connection"} + {isTesting ? ( + <> + + Testing Connection... + + ) : ( + "Test Connection" + )} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index a69cf968f..44461c351 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -7,7 +7,6 @@ import { toast } from "sonner"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; -import { EnumConnectorName } from "@/contracts/enums/connector"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; @@ -16,23 +15,11 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants"; +import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; +import { MCPServiceConfig } from "../components/mcp-service-config"; import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; -const REAUTH_ENDPOINTS: Partial> = { - [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", - [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", - [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", - [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", - [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth", - [EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth", -}; - interface ConnectorEditViewProps { connector: SearchSourceConnector; startDate: Date | undefined; @@ -86,7 +73,7 @@ export const ConnectorEditView: FC = ({ }) => { const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom); const isAuthExpired = connector.config?.auth_expired === true; - const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type]; + const reauthEndpoint = getReauthEndpoint(connector); const [reauthing, setReauthing] = useState(false); const handleReauth = useCallback(async () => { @@ -124,10 +111,7 @@ export const ConnectorEditView: FC = ({ // Get connector-specific config component (MCP-backed connectors use a generic view) const ConnectorConfigComponent = useMemo(() => { - if (isMCPBacked) { - const { MCPServiceConfig } = require("../components/mcp-service-config"); - return MCPServiceConfig as FC; - } + if (isMCPBacked) return MCPServiceConfig; return getConnectorConfigComponent(connector.connector_type); }, [connector.connector_type, isMCPBacked]); const [isScrolled, setIsScrolled] = useState(false); diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 05f866d0f..2ee811c19 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -1,4 +1,5 @@ import { EnumConnectorName } from "@/contracts/enums/connector"; +import type { SearchSourceConnector } from "@/contracts/types/connector.types"; /** * Connectors that operate in real time (no background indexing). @@ -367,5 +368,45 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem }; } +// ============================================================================= +// REAUTH ENDPOINTS +// ============================================================================= + +/** + * Legacy (non-MCP) OAuth reauth endpoints, keyed by connector type. + * These are used for connectors that were NOT created via MCP OAuth. + */ +export const LEGACY_REAUTH_ENDPOINTS: Partial> = { + [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", + [EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth", + [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", + [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", + [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", + [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth", + [EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", + [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", + [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", + [EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth", + [EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth", + [EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth", + [EnumConnectorName.TEAMS_CONNECTOR]: "/api/v1/auth/teams/connector/reauth", + [EnumConnectorName.DISCORD_CONNECTOR]: "/api/v1/auth/discord/connector/reauth", +}; + +/** + * Resolve the reauth endpoint for a connector. + * + * MCP OAuth connectors (those with ``config.mcp_service``) dynamically build + * the URL from the service key. Legacy OAuth connectors fall back to the + * static ``LEGACY_REAUTH_ENDPOINTS`` map. + */ +export function getReauthEndpoint(connector: SearchSourceConnector): string | undefined { + const mcpService = connector.config?.mcp_service as string | undefined; + if (mcpService) { + return `/api/v1/auth/mcp/${mcpService}/connector/reauth`; + } + return LEGACY_REAUTH_ENDPOINTS[connector.connector_type]; +} + // Re-export IndexingConfigState from schemas for backward compatibility export type { IndexingConfigState } from "./connector-popup.schemas"; diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index a8d395e5c..a9223fee5 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -1311,6 +1311,25 @@ export const useConnectorDialog = () => { [editingConnector, searchSpaceId, deleteConnector, cameFromMCPList, setIsOpen] ); + const handleDisconnectFromList = useCallback( + async (connector: SearchSourceConnector, refreshConnectors: () => void) => { + if (!searchSpaceId) return; + try { + await deleteConnector({ id: connector.id }); + trackConnectorDeleted(Number(searchSpaceId), connector.connector_type, connector.id); + toast.success(`${connector.name} disconnected successfully`); + refreshConnectors(); + queryClient.invalidateQueries({ + queryKey: cacheKeys.logs.summary(Number(searchSpaceId)), + }); + } catch (error) { + console.error("Error disconnecting connector:", error); + toast.error("Failed to disconnect connector"); + } + }, + [searchSpaceId, deleteConnector] + ); + // Handle quick index (index with selected date range, or backend defaults if none selected) const handleQuickIndexConnector = useCallback( async ( @@ -1484,6 +1503,7 @@ export const useConnectorDialog = () => { handleStartEdit, handleSaveConnector, handleDisconnectConnector, + handleDisconnectFromList, handleBackFromEdit, handleBackFromConnect, handleBackFromYouTube, diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index b48b14ed2..b3c087599 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtomValue } from "jotai"; -import { ArrowLeft, Plus, RefreshCw, Server } from "lucide-react"; +import { ArrowLeft, Plus, RefreshCw, Server, Trash2 } from "lucide-react"; import { type FC, useCallback, useState } from "react"; import { toast } from "sonner"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; @@ -13,25 +13,10 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; -import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; +import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { getConnectorDisplayName } from "../tabs/all-connectors-tab"; -const REAUTH_ENDPOINTS: Partial> = { - [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", - [EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth", - [EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth", - [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth", - [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth", - [EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth", - [EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth", - [EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth", - [EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth", -}; - interface ConnectorAccountsListViewProps { connectorType: string; connectorTitle: string; @@ -39,15 +24,12 @@ interface ConnectorAccountsListViewProps { indexingConnectorIds: Set; onBack: () => void; onManage: (connector: SearchSourceConnector) => void; + onDisconnect?: (connector: SearchSourceConnector) => Promise | void; onAddAccount: () => void; isConnecting?: boolean; addButtonText?: string; } -function isLiveConnector(connectorType: string): boolean { - return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR"; -} - export const ConnectorAccountsListView: FC = ({ connectorType, connectorTitle, @@ -55,12 +37,15 @@ export const ConnectorAccountsListView: FC = ({ indexingConnectorIds, onBack, onManage, + onDisconnect, onAddAccount, isConnecting = false, addButtonText, }) => { const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const [reauthingId, setReauthingId] = useState(null); + const [confirmDisconnectId, setConfirmDisconnectId] = useState(null); + const [disconnectingId, setDisconnectingId] = useState(null); // Get connector status const { isConnectorEnabled, getConnectorStatusMessage } = useConnectorStatus(); @@ -68,16 +53,15 @@ export const ConnectorAccountsListView: FC = ({ const isEnabled = isConnectorEnabled(connectorType); const statusMessage = getConnectorStatusMessage(connectorType); - const reauthEndpoint = REAUTH_ENDPOINTS[connectorType]; - const handleReauth = useCallback( - async (connectorId: number) => { - if (!searchSpaceId || !reauthEndpoint) return; - setReauthingId(connectorId); + async (connector: SearchSourceConnector) => { + const endpoint = getReauthEndpoint(connector); + if (!searchSpaceId || !endpoint) return; + setReauthingId(connector.id); try { const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const url = new URL(`${backendUrl}${reauthEndpoint}`); - url.searchParams.set("connector_id", String(connectorId)); + const url = new URL(`${backendUrl}${endpoint}`); + url.searchParams.set("connector_id", String(connector.id)); url.searchParams.set("space_id", String(searchSpaceId)); url.searchParams.set("return_url", window.location.pathname); const response = await authenticatedFetch(url.toString()); @@ -99,7 +83,7 @@ export const ConnectorAccountsListView: FC = ({ setReauthingId(null); } }, - [searchSpaceId, reauthEndpoint] + [searchSpaceId] ); // Filter connectors to only show those of this type @@ -198,9 +182,11 @@ export const ConnectorAccountsListView: FC = ({ ) : (
- {typeConnectors.map((connector) => { - const isIndexing = indexingConnectorIds.has(connector.id); - const isAuthExpired = !!reauthEndpoint && connector.config?.auth_expired === true; + {typeConnectors.map((connector) => { + const isIndexing = indexingConnectorIds.has(connector.id); + const connectorReauthEndpoint = getReauthEndpoint(connector); + const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true; + const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type) || Boolean(connector.config?.server_config); return (
= ({ Syncing

- ) : !isLiveConnector(connector.connector_type) ? ( + ) : !isLive ? (

{connector.last_indexed_at ? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}` @@ -239,28 +225,73 @@ export const ConnectorAccountsListView: FC = ({

) : null}
- {isAuthExpired ? ( - + {isAuthExpired ? ( + + ) : isLive && onDisconnect ? ( + confirmDisconnectId === connector.id ? ( +
+ + +
) : ( - )} + ) + ) : ( + + )}
); })} diff --git a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx index 809b76c38..d21f249ee 100644 --- a/surfsense_web/components/tool-ui/generic-hitl-approval.tsx +++ b/surfsense_web/components/tool-ui/generic-hitl-approval.tsx @@ -3,6 +3,7 @@ import type { ToolCallMessagePartComponent } from "@assistant-ui/react"; import { CornerDownLeftIcon, Pen } from "lucide-react"; import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; import { TextShimmerLoader } from "@/components/prompt-kit/loader"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -116,8 +117,8 @@ function GenericApprovalCard({ if (phase !== "pending" || !isMCPTool) return; setProcessing(); onDecision({ type: "approve" }); - connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch((err) => { - console.error("Failed to trust MCP tool:", err); + connectorsApiService.trustMCPTool(mcpConnectorId, toolName).catch(() => { + toast.error("Failed to save 'Always Allow' preference. The tool will still require approval next time."); }); }, [phase, setProcessing, onDecision, isMCPTool, mcpConnectorId, toolName]); diff --git a/surfsense_web/lib/apis/connectors-api.service.ts b/surfsense_web/lib/apis/connectors-api.service.ts index 3eaa767c5..f4137c787 100644 --- a/surfsense_web/lib/apis/connectors-api.service.ts +++ b/surfsense_web/lib/apis/connectors-api.service.ts @@ -414,16 +414,8 @@ class ConnectorsApiService { * Subsequent calls to this tool will skip HITL approval. */ trustMCPTool = async (connectorId: number, toolName: string): Promise => { - const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const token = - typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined; - await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/trust-tool`, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...(token ? { Authorization: `Bearer ${token}` } : {}), - }, - body: JSON.stringify({ tool_name: toolName }), + await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/trust-tool`, undefined, { + body: { tool_name: toolName }, }); }; @@ -431,16 +423,8 @@ class ConnectorsApiService { * Remove a tool from the MCP connector's "Always Allow" list. */ untrustMCPTool = async (connectorId: number, toolName: string): Promise => { - const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000"; - const token = - typeof window !== "undefined" ? document.cookie.match(/fapiToken=([^;]+)/)?.[1] : undefined; - await fetch(`${backendUrl}/api/v1/connectors/mcp/${connectorId}/untrust-tool`, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...(token ? { Authorization: `Bearer ${token}` } : {}), - }, - body: JSON.stringify({ tool_name: toolName }), + await baseApiService.post(`/api/v1/connectors/mcp/${connectorId}/untrust-tool`, undefined, { + body: { tool_name: toolName }, }); }; }