mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
feat: add MCP connector backend support
This commit is contained in:
parent
8646fecc8b
commit
305a981d14
14 changed files with 1083 additions and 29 deletions
185
surfsense_backend/app/agents/new_chat/tools/mcp_client.py
Normal file
185
surfsense_backend/app/agents/new_chat/tools/mcp_client.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""MCP Client Wrapper.
|
||||
|
||||
This module provides a client for communicating with MCP servers via stdio transport.
|
||||
It handles server lifecycle management, tool discovery, and tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
    """Client for communicating with an MCP server over stdio transport.

    The server process is spawned on demand via ``connect()``; tool discovery
    (``list_tools``) and execution (``call_tool``) require an active session.
    """

    def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
        """Initialize MCP client.

        Args:
            command: Command to spawn the MCP server (e.g., "uvx", "node")
            args: Arguments for the command (e.g., ["mcp-server-git"])
            env: Optional environment variables for the server process

        """
        self.command = command
        self.args = args
        self.env = env or {}
        self.session: ClientSession | None = None

    @asynccontextmanager
    async def connect(self):
        """Connect to the MCP server and manage its lifecycle.

        Yields:
            ClientSession: Active MCP session for making requests

        """
        connected = False
        try:
            # The child inherits our full environment plus connector-specific
            # overrides so credentials/config reach the server process.
            server_env = os.environ.copy()
            server_env.update(self.env)

            server_params = StdioServerParameters(
                command=self.command,
                args=self.args,
                env=server_env,
            )

            # Spawn server process and create session
            async with stdio_client(server=server_params) as (read, write):
                async with ClientSession(read, write) as session:
                    # Initialize the connection (MCP handshake)
                    await session.initialize()
                    self.session = session
                    connected = True
                    logger.info(
                        "Connected to MCP server: %s %s",
                        self.command,
                        " ".join(self.args),
                    )
                    yield session

        except Exception as e:
            logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
            raise
        finally:
            self.session = None
            # Only report a disconnect if we actually connected; otherwise the
            # log would be misleading after a spawn/initialize failure.
            if connected:
                logger.info("Disconnected from MCP server: %s", self.command)

    async def list_tools(self) -> list[dict[str, Any]]:
        """List all tools available from the MCP server.

        Returns:
            List of tool definitions with name, description, and input schema

        Raises:
            RuntimeError: If not connected to server

        """
        if not self.session:
            raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")

        try:
            # Call tools/list RPC method
            response = await self.session.list_tools()

            tools = [
                {
                    "name": tool.name,
                    "description": tool.description or "",
                    # Some SDK versions may omit inputSchema entirely.
                    "input_schema": getattr(tool, "inputSchema", {}),
                }
                for tool in response.tools
            ]

            logger.info("Listed %d tools from MCP server", len(tools))
            return tools

        except Exception as e:
            logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
            raise

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool

        Returns:
            Tool execution result, flattened to a newline-joined string

        Raises:
            RuntimeError: If not connected to server

        """
        if not self.session:
            raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")

        try:
            logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)

            # Call tools/call RPC method
            response = await self.session.call_tool(tool_name, arguments=arguments)

            # Flatten each content item of the response to text.
            result = []
            for content in response.content:
                if hasattr(content, "text"):
                    result.append(content.text)
                elif hasattr(content, "data"):
                    result.append(str(content.data))
                else:
                    result.append(str(content))

            result_str = "\n".join(result)
            logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
            return result_str

        except RuntimeError as e:
            # Some MCP servers (like server-memory) return extra fields not in
            # their declared schema; the SDK surfaces that as a validation
            # RuntimeError even though the operation itself succeeded.
            if "Invalid structured content" in str(e):
                logger.warning(
                    "MCP server returned data not matching its schema, but continuing: %s", e
                )
                # Try to extract result from error message or return a success message
                return "Operation completed (server returned unexpected format)"
            raise
        except Exception as e:
            logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
            return f"Error calling tool: {e!s}"
|
||||
|
||||
|
||||
async def test_mcp_connection(
    command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
    """Test connection to an MCP server and fetch available tools.

    Args:
        command: Command to spawn the MCP server
        args: Arguments for the command
        env: Optional environment variables

    Returns:
        Dict with connection status and available tools

    """
    probe = MCPClient(command, args, env)

    try:
        async with probe.connect():
            discovered = await probe.list_tools()
    except Exception as e:
        # Any failure (spawn, handshake, RPC) is reported as a single error result.
        return {
            "status": "error",
            "message": f"Failed to connect: {e!s}",
            "tools": [],
        }

    return {
        "status": "success",
        "message": f"Connected successfully. Found {len(discovered)} tools.",
        "tools": discovered,
    }
|
||||
250
surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
Normal file
250
surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""MCP Tool Factory.
|
||||
|
||||
This module creates LangChain tools from MCP servers using the Model Context Protocol.
|
||||
Tools are dynamically discovered from MCP servers - no manual configuration needed.
|
||||
|
||||
This implements real MCP protocol support similar to Cursor's implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_gemini_params(params: dict[str, Any], mcp_schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalize Gemini-transformed parameter names back to MCP schema format.
|
||||
|
||||
Gemini tends to transform field names like:
|
||||
- entityType -> type
|
||||
- from/to -> fromEntity/toEntity
|
||||
- relationType -> relation
|
||||
|
||||
This function maps them back to the original MCP schema field names.
|
||||
"""
|
||||
schema_properties = mcp_schema.get("properties", {})
|
||||
normalized = {}
|
||||
|
||||
for param_key, param_value in params.items():
|
||||
# Handle array parameters (need to normalize nested objects)
|
||||
if isinstance(param_value, list) and len(param_value) > 0:
|
||||
if isinstance(param_value[0], dict):
|
||||
# Get the items schema to know what fields should be present
|
||||
items_schema = schema_properties.get(param_key, {}).get("items", {})
|
||||
items_properties = items_schema.get("properties", {})
|
||||
|
||||
normalized_array = []
|
||||
for item in param_value:
|
||||
normalized_item = {}
|
||||
for item_key, item_value in item.items():
|
||||
# Map common Gemini transformations back to MCP names
|
||||
if item_key == "type" and "entityType" in items_properties:
|
||||
normalized_item["entityType"] = item_value
|
||||
elif item_key == "fromEntity" and "from" in items_properties:
|
||||
normalized_item["from"] = item_value
|
||||
elif item_key == "toEntity" and "to" in items_properties:
|
||||
normalized_item["to"] = item_value
|
||||
elif item_key == "relation" and "relationType" in items_properties:
|
||||
normalized_item["relationType"] = item_value
|
||||
else:
|
||||
# Use the original key if it exists in schema
|
||||
normalized_item[item_key] = item_value
|
||||
|
||||
# Add missing required fields with empty defaults if needed
|
||||
for required_field in items_properties.keys():
|
||||
if required_field not in normalized_item:
|
||||
# For arrays like observations, default to empty array
|
||||
if items_properties[required_field].get("type") == "array":
|
||||
normalized_item[required_field] = []
|
||||
else:
|
||||
normalized_item[required_field] = ""
|
||||
|
||||
normalized_array.append(normalized_item)
|
||||
normalized[param_key] = normalized_array
|
||||
else:
|
||||
normalized[param_key] = param_value
|
||||
else:
|
||||
normalized[param_key] = param_value
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _create_dynamic_input_model_from_schema(
    tool_name: str, input_schema: dict[str, Any],
) -> type[BaseModel]:
    """Create a Pydantic model from MCP tool's JSON schema.

    Args:
        tool_name: Name of the tool (used for model class name)
        input_schema: JSON schema from MCP server

    Returns:
        Pydantic model class for tool input validation

    """
    # Imported once here (not on every loop iteration, as before) while
    # keeping the module's top-level import surface unchanged.
    from pydantic import Field

    properties = input_schema.get("properties", {})
    required_fields = set(input_schema.get("required", []))

    # Use Any for every field so complex/nested schemas are preserved
    # structurally; the MCP server performs its own validation.
    field_definitions = {}
    for param_name, param_schema in properties.items():
        param_description = param_schema.get("description", "")

        if param_name in required_fields:
            field_definitions[param_name] = (Any, Field(..., description=param_description))
        else:
            field_definitions[param_name] = (
                Any | None,
                Field(None, description=param_description),
            )

    # Sanitize the tool name into an identifier-like model class name.
    model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
    return create_model(model_name, **field_definitions)
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition(
    tool_def: dict[str, Any],
    mcp_client: MCPClient,
) -> StructuredTool:
    """Create a LangChain tool from an MCP tool definition.

    Args:
        tool_def: Tool definition from MCP server with name, description, input_schema
        mcp_client: MCP client instance for calling the tool

    Returns:
        LangChain StructuredTool instance

    """
    tool_name = tool_def.get("name", "unnamed_tool")
    tool_description = tool_def.get("description", "No description provided")
    input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})

    # Log the actual schema for debugging
    logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")

    # Dynamic Pydantic model mirroring the server-declared schema.
    args_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)

    async def mcp_tool_call(**kwargs) -> str:
        """Execute the MCP tool call via the client."""
        logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")

        # Undo Gemini's field-name rewrites (entityType->type, from/to->
        # fromEntity/toEntity, relationType->relation) before calling the server.
        cleaned = _normalize_gemini_params(kwargs, input_schema)

        try:
            # Each invocation spawns/connects to the server for its lifetime.
            async with mcp_client.connect():
                outcome = await mcp_client.call_tool(tool_name, cleaned)
                return str(outcome)
        except Exception as e:
            error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
            logger.exception(error_msg)
            return f"Error: {error_msg}"

    structured = StructuredTool(
        name=tool_name,
        description=tool_description,
        coroutine=mcp_tool_call,
        args_schema=args_model,
        # Keep the untouched MCP schema available to later consumers.
        metadata={"mcp_input_schema": input_schema},
    )

    logger.info(f"Created MCP tool: '{tool_name}'")
    return structured
|
||||
|
||||
|
||||
async def load_mcp_tools(
    session: AsyncSession, search_space_id: int,
) -> list[StructuredTool]:
    """Load all MCP tools from user's active MCP server connectors.

    This discovers tools dynamically from MCP servers using the protocol.
    Failures are isolated per connector and per tool definition so one broken
    server never prevents the others from loading.

    Args:
        session: Database session
        search_space_id: User's search space ID

    Returns:
        List of LangChain StructuredTool instances

    """
    try:
        # Fetch all ACTIVE MCP connectors for this search space
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.MCP_CONNECTOR,
                SearchSourceConnector.search_space_id == search_space_id,
                # .is_(True) is the idiomatic SQLAlchemy boolean test
                # (renders IS TRUE) instead of the lint-flagged `== True`.
                SearchSourceConnector.is_active.is_(True),
            ),
        )

        tools: list[StructuredTool] = []
        for connector in result.scalars():
            try:
                # Extract server config from the connector's JSON config blob.
                config = connector.config or {}
                server_config = config.get("server_config", {})

                command = server_config.get("command")
                args = server_config.get("args", [])
                env = server_config.get("env", {})

                if not command:
                    logger.warning("MCP connector %s missing command, skipping", connector.id)
                    continue

                # Create MCP client
                mcp_client = MCPClient(command, args, env)

                # Connect once to discover this server's tool definitions.
                async with mcp_client.connect():
                    tool_definitions = await mcp_client.list_tools()

                logger.info(
                    f"Discovered {len(tool_definitions)} tools from MCP server "
                    f"'{command}' (connector {connector.id})"
                )

                # Create LangChain tools from definitions
                for tool_def in tool_definitions:
                    try:
                        tool = await _create_mcp_tool_from_definition(tool_def, mcp_client)
                        tools.append(tool)
                    except Exception as e:
                        logger.exception(
                            f"Failed to create tool '{tool_def.get('name')}' "
                            f"from connector {connector.id}: {e!s}",
                        )

            except Exception as e:
                # A bad connector (unreachable server, bad config) is skipped.
                logger.exception(
                    f"Failed to load tools from MCP connector {connector.id}: {e!s}",
                )

        logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
        return tools

    except Exception as e:
        # Best-effort: on any unexpected failure the agent runs without MCP tools.
        logger.exception(f"Failed to load MCP tools: {e!s}")
        return []
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
"""
|
||||
Tools registry for SurfSense deep agent.
|
||||
"""Tools registry for SurfSense deep agent.
|
||||
|
||||
This module provides a registry pattern for managing tools in the SurfSense agent.
|
||||
It makes it easy for OSS contributors to add new tools by:
|
||||
|
|
@ -37,6 +36,7 @@ Example of adding a new tool:
|
|||
),
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
|
@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool
|
|||
from .display_image import create_display_image_tool
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .link_preview import create_link_preview_tool
|
||||
from .mcp_tool import load_mcp_tools
|
||||
from .podcast import create_generate_podcast_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
|
|
@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
|
|||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""
|
||||
Definition of a tool that can be added to the agent.
|
||||
"""Definition of a tool that can be added to the agent.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the tool
|
||||
|
|
@ -66,6 +66,7 @@ class ToolDefinition:
|
|||
factory: Callable that creates the tool. Receives a dict of dependencies.
|
||||
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
||||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
|
@ -178,8 +179,7 @@ def build_tools(
|
|||
disabled_tools: list[str] | None = None,
|
||||
additional_tools: list[BaseTool] | None = None,
|
||||
) -> list[BaseTool]:
|
||||
"""
|
||||
Build the list of tools for the agent.
|
||||
"""Build the list of tools for the agent.
|
||||
|
||||
Args:
|
||||
dependencies: Dict containing all possible dependencies:
|
||||
|
|
@ -206,6 +206,7 @@ def build_tools(
|
|||
|
||||
# Add custom tools
|
||||
tools = build_tools(deps, additional_tools=[my_custom_tool])
|
||||
|
||||
"""
|
||||
# Determine which tools to enable
|
||||
if enabled_tools is not None:
|
||||
|
|
@ -226,8 +227,9 @@ def build_tools(
|
|||
# Check that all required dependencies are provided
|
||||
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
|
||||
if missing_deps:
|
||||
msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
|
||||
raise ValueError(
|
||||
f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
|
||||
msg,
|
||||
)
|
||||
|
||||
# Create the tool
|
||||
|
|
@ -239,3 +241,61 @@ def build_tools(
|
|||
tools.extend(additional_tools)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def build_tools_async(
    dependencies: dict[str, Any],
    enabled_tools: list[str] | None = None,
    disabled_tools: list[str] | None = None,
    additional_tools: list[BaseTool] | None = None,
    include_mcp_tools: bool = True,
) -> list[BaseTool]:
    """Async version of build_tools that also loads MCP tools from database.

    Design Note:
        This function exists because MCP tools require database queries to load user configs,
        while built-in tools are created synchronously from static code.

        Alternative: We could make build_tools() itself async and always query the database,
        but that would force async everywhere even when only using built-in tools. The current
        design keeps the simple case (static tools only) synchronous while supporting dynamic
        database-loaded tools through this async wrapper.

    Args:
        dependencies: Dict containing all possible dependencies
        enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
        disabled_tools: List of tool names to disable (applied after enabled_tools).
        additional_tools: Extra tools to add (e.g., custom tools not in registry).
        include_mcp_tools: Whether to load user's MCP tools from database.

    Returns:
        List of configured tool instances ready for the agent, including MCP tools.

    """
    # Use a named module logger instead of the root logger: bare logging.info()
    # implicitly calls basicConfig() and bypasses the app's logging setup.
    logger = logging.getLogger(__name__)

    # Build standard tools
    tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)

    # Load MCP tools if requested and dependencies are available
    if (
        include_mcp_tools
        and "db_session" in dependencies
        and "search_space_id" in dependencies
    ):
        try:
            mcp_tools = await load_mcp_tools(
                dependencies["db_session"], dependencies["search_space_id"],
            )
            tools.extend(mcp_tools)
            logger.info(
                "Registered %d MCP tools: %s", len(mcp_tools), [t.name for t in mcp_tools],
            )
        except Exception:
            # Log error but don't fail - just continue without MCP tools
            logger.exception("Failed to load MCP tools")

    # Log all tools being returned to agent
    logger.info(
        "Total tools for agent: %d - %s", len(tools), [t.name for t in tools],
    )

    return tools
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue