feat: add MCP connector backend support

This commit is contained in:
Manoj Aggarwal 2026-01-13 13:46:01 -08:00
parent 8646fecc8b
commit 305a981d14
14 changed files with 1083 additions and 29 deletions

View file

@ -0,0 +1,185 @@
"""MCP Client Wrapper.
This module provides a client for communicating with MCP servers via stdio transport.
It handles server lifecycle management, tool discovery, and tool execution.
"""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__)
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
"""Initialize MCP client.
Args:
command: Command to spawn the MCP server (e.g., "uvx", "node")
args: Arguments for the command (e.g., ["mcp-server-git"])
env: Optional environment variables for the server process
"""
self.command = command
self.args = args
self.env = env or {}
self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self):
"""Connect to the MCP server and manage its lifecycle.
Yields:
ClientSession: Active MCP session for making requests
"""
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command,
args=self.args,
env=server_env
)
# Spawn server process and create session
async with stdio_client(server=server_params) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
logger.info(
f"Connected to MCP server: {self.command} {' '.join(self.args)}"
)
yield session
except Exception as e:
logger.error(f"Failed to connect to MCP server: {e!s}", exc_info=True)
raise
finally:
self.session = None
logger.info(f"Disconnected from MCP server: {self.command}")
async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server.
Returns:
List of tool definitions with name, description, and input schema
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
try:
# Call tools/list RPC method
response = await self.session.list_tools()
tools = []
for tool in response.tools:
tools.append({
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
})
logger.info(f"Listed {len(tools)} tools from MCP server")
return tools
except Exception as e:
logger.error(f"Failed to list tools from MCP server: {e!s}", exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
Returns:
Tool execution result
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
try:
logger.info(f"Calling MCP tool '{tool_name}' with arguments: {arguments}")
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.info(f"MCP tool '{tool_name}' succeeded: {result_str[:200]}")
return result_str
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning(f"MCP server returned data not matching its schema, but continuing: {e}")
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except Exception as e:
logger.error(f"Failed to call MCP tool '{tool_name}': {e!s}", exc_info=True)
return f"Error calling tool: {e!s}"
async def test_mcp_connection(
    command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
    """Probe an MCP server: connect once and enumerate its tools.

    Args:
        command: Command to spawn the MCP server
        args: Arguments for the command
        env: Optional environment variables

    Returns:
        Dict with connection status and available tools
    """
    probe = MCPClient(command, args, env)
    try:
        async with probe.connect():
            discovered = await probe.list_tools()
    except Exception as exc:
        return {
            "status": "error",
            "message": f"Failed to connect: {exc!s}",
            "tools": [],
        }
    return {
        "status": "success",
        "message": f"Connected successfully. Found {len(discovered)} tools.",
        "tools": discovered,
    }

View file

@ -0,0 +1,250 @@
"""MCP Tool Factory.
This module creates LangChain tools from MCP servers using the Model Context Protocol.
Tools are dynamically discovered from MCP servers - no manual configuration needed.
This implements real MCP protocol support similar to Cursor's implementation.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, create_model
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
def _normalize_gemini_params(params: dict[str, Any], mcp_schema: dict[str, Any]) -> dict[str, Any]:
"""Normalize Gemini-transformed parameter names back to MCP schema format.
Gemini tends to transform field names like:
- entityType -> type
- from/to -> fromEntity/toEntity
- relationType -> relation
This function maps them back to the original MCP schema field names.
"""
schema_properties = mcp_schema.get("properties", {})
normalized = {}
for param_key, param_value in params.items():
# Handle array parameters (need to normalize nested objects)
if isinstance(param_value, list) and len(param_value) > 0:
if isinstance(param_value[0], dict):
# Get the items schema to know what fields should be present
items_schema = schema_properties.get(param_key, {}).get("items", {})
items_properties = items_schema.get("properties", {})
normalized_array = []
for item in param_value:
normalized_item = {}
for item_key, item_value in item.items():
# Map common Gemini transformations back to MCP names
if item_key == "type" and "entityType" in items_properties:
normalized_item["entityType"] = item_value
elif item_key == "fromEntity" and "from" in items_properties:
normalized_item["from"] = item_value
elif item_key == "toEntity" and "to" in items_properties:
normalized_item["to"] = item_value
elif item_key == "relation" and "relationType" in items_properties:
normalized_item["relationType"] = item_value
else:
# Use the original key if it exists in schema
normalized_item[item_key] = item_value
# Add missing required fields with empty defaults if needed
for required_field in items_properties.keys():
if required_field not in normalized_item:
# For arrays like observations, default to empty array
if items_properties[required_field].get("type") == "array":
normalized_item[required_field] = []
else:
normalized_item[required_field] = ""
normalized_array.append(normalized_item)
normalized[param_key] = normalized_array
else:
normalized[param_key] = param_value
else:
normalized[param_key] = param_value
return normalized
def _create_dynamic_input_model_from_schema(
    tool_name: str, input_schema: dict[str, Any],
) -> type[BaseModel]:
    """Create a Pydantic model from MCP tool's JSON schema.

    Args:
        tool_name: Name of the tool (used for model class name)
        input_schema: JSON schema from MCP server

    Returns:
        Pydantic model class for tool input validation
    """
    # Hoisted out of the loop below: these were previously re-imported on
    # every iteration.
    from typing import Any as AnyType

    from pydantic import Field

    properties = input_schema.get("properties", {})
    required_fields = input_schema.get("required", [])
    # Build Pydantic field definitions
    field_definitions = {}
    for param_name, param_schema in properties.items():
        param_description = param_schema.get("description", "")
        is_required = param_name in required_fields
        # Use Any type for complex schemas to preserve structure.
        # This allows the MCP server to do its own validation.
        if is_required:
            field_definitions[param_name] = (AnyType, Field(..., description=param_description))
        else:
            field_definitions[param_name] = (
                AnyType | None,
                Field(None, description=param_description),
            )
    # Sanitize the tool name into a valid Python class name for the model.
    model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
    return create_model(model_name, **field_definitions)
async def _create_mcp_tool_from_definition(
    tool_def: dict[str, Any],
    mcp_client: MCPClient,
) -> StructuredTool:
    """Build a LangChain tool that proxies invocations to an MCP server tool.

    Args:
        tool_def: Tool definition from MCP server with name, description, input_schema
        mcp_client: MCP client instance for calling the tool

    Returns:
        LangChain StructuredTool instance
    """
    tool_name = tool_def.get("name", "unnamed_tool")
    tool_description = tool_def.get("description", "No description provided")
    input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
    # Log the actual schema for debugging
    logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
    # Dynamic Pydantic model so LangChain can validate arguments.
    input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)

    async def mcp_tool_call(**kwargs) -> str:
        """Proxy one invocation of the MCP tool through the client."""
        logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
        # Undo Gemini's field renames (entityType->type, from/to->fromEntity/toEntity,
        # relationType->relation) before handing arguments to the server.
        normalized_kwargs = _normalize_gemini_params(kwargs, input_schema)
        try:
            # Each invocation opens its own connection to the server.
            async with mcp_client.connect():
                result = await mcp_client.call_tool(tool_name, normalized_kwargs)
                return str(result)
        except Exception as e:
            error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
            logger.exception(error_msg)
            return f"Error: {error_msg}"

    wrapped = StructuredTool(
        name=tool_name,
        description=tool_description,
        coroutine=mcp_tool_call,
        args_schema=input_model,
        # Store the original MCP schema as metadata so we can access it later
        metadata={"mcp_input_schema": input_schema},
    )
    logger.info(f"Created MCP tool: '{tool_name}'")
    return wrapped
async def load_mcp_tools(
    session: AsyncSession, search_space_id: int,
) -> list[StructuredTool]:
    """Discover and wrap tools from every active MCP connector in a search space.

    Tools are discovered dynamically from the MCP servers over the protocol.
    A failing connector -- or a single bad tool definition -- is logged and
    skipped so one misconfigured server cannot take down the whole toolset.

    Args:
        session: Database session
        search_space_id: User's search space ID

    Returns:
        List of LangChain StructuredTool instances (empty list on total failure)
    """
    tools: list[StructuredTool] = []
    try:
        # Only ACTIVE MCP connectors belonging to this search space.
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.MCP_CONNECTOR,
                SearchSourceConnector.search_space_id == search_space_id,
                SearchSourceConnector.is_active == True,  # noqa: E712 -- SQLAlchemy column expression, not a bool
            ),
        )
        for connector in result.scalars():
            try:
                # Pull the stdio server spec out of the connector config.
                server_config = (connector.config or {}).get("server_config", {})
                command = server_config.get("command")
                if not command:
                    logger.warning(f"MCP connector {connector.id} missing command, skipping")
                    continue
                client = MCPClient(
                    command,
                    server_config.get("args", []),
                    server_config.get("env", {}),
                )
                # One connection per connector to enumerate its tools.
                async with client.connect():
                    tool_definitions = await client.list_tools()
                    logger.info(
                        f"Discovered {len(tool_definitions)} tools from MCP server "
                        f"'{command}' (connector {connector.id})"
                    )
                    for tool_def in tool_definitions:
                        try:
                            tools.append(
                                await _create_mcp_tool_from_definition(tool_def, client)
                            )
                        except Exception as e:
                            logger.exception(
                                f"Failed to create tool '{tool_def.get('name')}' "
                                f"from connector {connector.id}: {e!s}",
                            )
            except Exception as e:
                logger.exception(
                    f"Failed to load tools from MCP connector {connector.id}: {e!s}",
                )
        logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
        return tools
    except Exception as e:
        logger.exception(f"Failed to load MCP tools: {e!s}")
        return []

View file

@ -1,5 +1,4 @@
"""
Tools registry for SurfSense deep agent.
"""Tools registry for SurfSense deep agent.
This module provides a registry pattern for managing tools in the SurfSense agent.
It makes it easy for OSS contributors to add new tools by:
@ -37,6 +36,7 @@ Example of adding a new tool:
),
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool
from .display_image import create_display_image_tool
from .knowledge_base import create_search_knowledge_base_tool
from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools
from .podcast import create_generate_podcast_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
@dataclass
class ToolDefinition:
"""
Definition of a tool that can be added to the agent.
"""Definition of a tool that can be added to the agent.
Attributes:
name: Unique identifier for the tool
@ -66,6 +66,7 @@ class ToolDefinition:
factory: Callable that creates the tool. Receives a dict of dependencies.
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
enabled_by_default: Whether the tool is enabled when no explicit config is provided
"""
name: str
@ -178,8 +179,7 @@ def build_tools(
disabled_tools: list[str] | None = None,
additional_tools: list[BaseTool] | None = None,
) -> list[BaseTool]:
"""
Build the list of tools for the agent.
"""Build the list of tools for the agent.
Args:
dependencies: Dict containing all possible dependencies:
@ -206,6 +206,7 @@ def build_tools(
# Add custom tools
tools = build_tools(deps, additional_tools=[my_custom_tool])
"""
# Determine which tools to enable
if enabled_tools is not None:
@ -226,8 +227,9 @@ def build_tools(
# Check that all required dependencies are provided
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
if missing_deps:
msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
raise ValueError(
f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
msg,
)
# Create the tool
@ -239,3 +241,61 @@ def build_tools(
tools.extend(additional_tools)
return tools
async def build_tools_async(
    dependencies: dict[str, Any],
    enabled_tools: list[str] | None = None,
    disabled_tools: list[str] | None = None,
    additional_tools: list[BaseTool] | None = None,
    include_mcp_tools: bool = True,
) -> list[BaseTool]:
    """Async version of build_tools that also loads MCP tools from database.

    Design Note:
        MCP tools require database queries to load user configs, while the
        built-in tools are created synchronously from static code. Making
        build_tools() itself async would force async everywhere even for the
        static-only case, so this wrapper keeps the simple path synchronous
        and layers the database-loaded tools on top.

    Args:
        dependencies: Dict containing all possible dependencies
        enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
        disabled_tools: List of tool names to disable (applied after enabled_tools).
        additional_tools: Extra tools to add (e.g., custom tools not in registry).
        include_mcp_tools: Whether to load user's MCP tools from database.

    Returns:
        List of configured tool instances ready for the agent, including MCP tools.
    """
    # Standard (static) tools first.
    tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)

    # MCP tools need both a DB session and a search space to query against.
    has_mcp_deps = "db_session" in dependencies and "search_space_id" in dependencies
    if include_mcp_tools and has_mcp_deps:
        try:
            mcp_tools = await load_mcp_tools(
                dependencies["db_session"], dependencies["search_space_id"],
            )
            tools.extend(mcp_tools)
            logging.info(
                f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
            )
        except Exception as e:
            # Log error but don't fail - just continue without MCP tools
            logging.exception(f"Failed to load MCP tools: {e!s}")

    # Log all tools being returned to agent
    logging.info(
        f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
    )
    return tools