mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 09:16:22 +02:00
feat: support multiple transport types for MCP server connections, including stdio and HTTP
This commit is contained in:
parent
bb5cb846b3
commit
9625a24475
9 changed files with 435 additions and 191 deletions
|
|
@ -1,6 +1,6 @@
|
|||
"""MCP Client Wrapper.
|
||||
|
||||
This module provides a client for communicating with MCP servers via stdio transport.
|
||||
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
|
||||
It handles server lifecycle management, tool discovery, and tool execution.
|
||||
"""
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ from typing import Any
|
|||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -222,7 +223,7 @@ class MCPClient:
|
|||
async def test_mcp_connection(
|
||||
command: str, args: list[str], env: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Test connection to an MCP server and fetch available tools.
|
||||
"""Test connection to an MCP server via stdio and fetch available tools.
|
||||
|
||||
Args:
|
||||
command: Command to spawn the MCP server
|
||||
|
|
@ -249,3 +250,51 @@ async def test_mcp_connection(
|
|||
"message": f"Failed to connect: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
|
||||
|
||||
async def test_mcp_http_connection(
|
||||
url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
|
||||
) -> dict[str, Any]:
|
||||
"""Test connection to an MCP server via HTTP and fetch available tools.
|
||||
|
||||
Args:
|
||||
url: URL of the MCP server
|
||||
headers: Optional HTTP headers for authentication
|
||||
transport: Transport type ("streamable-http", "http", or "sse")
|
||||
|
||||
Returns:
|
||||
Dict with connection status and available tools
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info("Testing HTTP MCP connection to: %s (transport: %s)", url, transport)
|
||||
|
||||
# Use streamable HTTP client for all HTTP-based transports
|
||||
async with streamablehttp_client(url, headers=headers or {}) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tools = []
|
||||
for tool in response.tools:
|
||||
tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||
})
|
||||
|
||||
logger.info("HTTP MCP connection successful. Found %d tools.", len(tools))
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Connected successfully. Found {len(tools)} tools.",
|
||||
"tools": tools,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to HTTP MCP server: %s", e, exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to connect: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@
|
|||
This module creates LangChain tools from MCP servers using the Model Context Protocol.
|
||||
Tools are dynamically discovered from MCP servers - no manual configuration needed.
|
||||
|
||||
Supports both transport types:
|
||||
- stdio: Local process-based MCP servers (command, args, env)
|
||||
- streamable-http/http/sse: Remote HTTP-based MCP servers (url, headers)
|
||||
|
||||
This implements real MCP protocol support similar to Cursor's implementation.
|
||||
"""
|
||||
|
||||
|
|
@ -10,6 +14,8 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -65,11 +71,11 @@ def _create_dynamic_input_model_from_schema(
|
|||
return create_model(model_name, **field_definitions)
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition(
|
||||
async def _create_mcp_tool_from_definition_stdio(
|
||||
tool_def: dict[str, Any],
|
||||
mcp_client: MCPClient,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition.
|
||||
"""Create a LangChain tool from an MCP tool definition (stdio transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
|
|
@ -116,13 +122,223 @@ async def _create_mcp_tool_from_definition(
|
|||
coroutine=mcp_tool_call,
|
||||
args_schema=input_model,
|
||||
# Store the original MCP schema as metadata so we can access it later
|
||||
metadata={"mcp_input_schema": input_schema},
|
||||
metadata={"mcp_input_schema": input_schema, "mcp_transport": "stdio"},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool: '{tool_name}'")
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
return tool
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition_http(
|
||||
tool_def: dict[str, Any],
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
Args:
|
||||
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||
url: URL of the MCP server
|
||||
headers: HTTP headers for authentication
|
||||
|
||||
Returns:
|
||||
LangChain StructuredTool instance
|
||||
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
# Log the actual schema for debugging
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
|
||||
# Create dynamic input model from schema
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
|
||||
try:
|
||||
async with streamablehttp_client(url, headers=headers) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Call the tool
|
||||
response = await session.call_tool(tool_name, arguments=kwargs)
|
||||
|
||||
# Extract content from response
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
result.append(str(content.data))
|
||||
else:
|
||||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}")
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
# Create StructuredTool
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_http_tool_call,
|
||||
args_schema=input_model,
|
||||
metadata={"mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
|
||||
return tool
|
||||
|
||||
|
||||
async def _load_stdio_mcp_tools(
|
||||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from a stdio-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with command, args, env
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required command field
|
||||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Validate args field (must be list if present)
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Validate env field (must be dict if present)
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
|
||||
# Connect and discover tools
|
||||
async with mcp_client.connect():
|
||||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from stdio MCP server "
|
||||
f"'{command}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_stdio(tool_def, mcp_client)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def _load_http_mcp_tools(
|
||||
connector_id: int,
|
||||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
connector_id: Connector ID for logging
|
||||
connector_name: Connector name for logging
|
||||
server_config: Server configuration with url, headers
|
||||
|
||||
Returns:
|
||||
List of tools from the MCP server
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
# Validate required url field
|
||||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Validate headers field (must be dict if present)
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping"
|
||||
)
|
||||
return tools
|
||||
|
||||
# Connect and discover tools via HTTP
|
||||
try:
|
||||
async with streamablehttp_client(url, headers=headers) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tool_definitions = []
|
||||
for tool in response.tools:
|
||||
tool_definitions.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
|
||||
f"'{url}' (connector {connector_id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition_http(tool_def, url, headers)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create HTTP tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}"
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
|
|
@ -130,6 +346,7 @@ async def load_mcp_tools(
|
|||
"""Load all MCP tools from user's active MCP server connectors.
|
||||
|
||||
This discovers tools dynamically from MCP servers using the protocol.
|
||||
Supports both stdio (local process) and HTTP (remote server) transports.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
|
@ -163,54 +380,22 @@ async def load_mcp_tools(
|
|||
)
|
||||
continue
|
||||
|
||||
# Validate required command field
|
||||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') missing or invalid command field, skipping"
|
||||
# Determine transport type
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
# HTTP-based MCP server
|
||||
connector_tools = await _load_http_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate args field (must be list if present)
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid args field (must be list), skipping"
|
||||
else:
|
||||
# stdio-based MCP server (default)
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
connector.id, connector.name, server_config
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate env field (must be dict if present)
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid env field (must be dict), skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create MCP client
|
||||
mcp_client = MCPClient(command, args, env)
|
||||
|
||||
# Connect and discover tools
|
||||
async with mcp_client.connect():
|
||||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from MCP server "
|
||||
f"'{command}' (connector {connector.id})"
|
||||
)
|
||||
|
||||
# Create LangChain tools from definitions
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
tool = await _create_mcp_tool_from_definition(
|
||||
tool_def, mcp_client
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector.id}: {e!s}"
|
||||
)
|
||||
|
||||
tools.extend(connector_tools)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue