mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-31 19:45:15 +02:00
Merge pull request #693 from manojag115/feat/mcp-connector-backend
Add backend code for adding MCPs as a connector
This commit is contained in:
commit
4fb9017eac
13 changed files with 990 additions and 32 deletions
|
|
@ -0,0 +1,50 @@
|
||||||
|
"""allow_multiple_connectors_with_unique_names
|
||||||
|
|
||||||
|
Revision ID: 5263aa4e7f94
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2026-01-13 12:23:31.481643
|
||||||
|
|
||||||
|
"""
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '5263aa4e7f94'
|
||||||
|
down_revision: str | None = 'a1b2c3d4e5f6'
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# Drop the old unique constraint
|
||||||
|
op.drop_constraint(
|
||||||
|
'uq_searchspace_user_connector_type',
|
||||||
|
'search_source_connectors',
|
||||||
|
type_='unique'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new unique constraint that includes name
|
||||||
|
op.create_unique_constraint(
|
||||||
|
'uq_searchspace_user_connector_type_name',
|
||||||
|
'search_source_connectors',
|
||||||
|
['search_space_id', 'user_id', 'connector_type', 'name']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# Drop the new constraint
|
||||||
|
op.drop_constraint(
|
||||||
|
'uq_searchspace_user_connector_type_name',
|
||||||
|
'search_source_connectors',
|
||||||
|
type_='unique'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Restore the old constraint
|
||||||
|
op.create_unique_constraint(
|
||||||
|
'uq_searchspace_user_connector_type',
|
||||||
|
'search_source_connectors',
|
||||||
|
['search_space_id', 'user_id', 'connector_type']
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""Add MCP connector type
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: 61
|
||||||
|
Create Date: 2026-01-09 15:19:51.827647
|
||||||
|
|
||||||
|
"""
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'a1b2c3d4e5f6'
|
||||||
|
down_revision: str | None = '61'
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Add MCP_CONNECTOR to SearchSourceConnectorType enum."""
|
||||||
|
# Add new enum value using raw SQL
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR';
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove MCP_CONNECTOR from SearchSourceConnectorType enum."""
|
||||||
|
# Note: PostgreSQL does not support removing enum values directly.
|
||||||
|
# To downgrade, you would need to:
|
||||||
|
# 1. Create a new enum without MCP_CONNECTOR
|
||||||
|
# 2. Alter the column to use the new enum
|
||||||
|
# 3. Drop the old enum
|
||||||
|
# This is left as a manual operation if needed.
|
||||||
|
pass
|
||||||
|
|
@ -20,7 +20,7 @@ from app.agents.new_chat.system_prompt import (
|
||||||
build_configurable_system_prompt,
|
build_configurable_system_prompt,
|
||||||
build_surfsense_system_prompt,
|
build_surfsense_system_prompt,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.tools import build_tools
|
from app.agents.new_chat.tools.registry import build_tools_async
|
||||||
from app.services.connector_service import ConnectorService
|
from app.services.connector_service import ConnectorService
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -28,7 +28,7 @@ from app.services.connector_service import ConnectorService
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def create_surfsense_deep_agent(
|
async def create_surfsense_deep_agent(
|
||||||
llm: ChatLiteLLM,
|
llm: ChatLiteLLM,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
|
|
@ -120,8 +120,8 @@ def create_surfsense_deep_agent(
|
||||||
"firecrawl_api_key": firecrawl_api_key,
|
"firecrawl_api_key": firecrawl_api_key,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build tools using the registry
|
# Build tools using the async registry (includes MCP tools)
|
||||||
tools = build_tools(
|
tools = await build_tools_async(
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
enabled_tools=enabled_tools,
|
enabled_tools=enabled_tools,
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
|
|
|
||||||
188
surfsense_backend/app/agents/new_chat/tools/mcp_client.py
Normal file
188
surfsense_backend/app/agents/new_chat/tools/mcp_client.py
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
"""MCP Client Wrapper.
|
||||||
|
|
||||||
|
This module provides a client for communicating with MCP servers via stdio transport.
|
||||||
|
It handles server lifecycle management, tool discovery, and tool execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""Client for communicating with an MCP server."""
|
||||||
|
|
||||||
|
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
|
||||||
|
"""Initialize MCP client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: Command to spawn the MCP server (e.g., "uvx", "node")
|
||||||
|
args: Arguments for the command (e.g., ["mcp-server-git"])
|
||||||
|
env: Optional environment variables for the server process
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.command = command
|
||||||
|
self.args = args
|
||||||
|
self.env = env or {}
|
||||||
|
self.session: ClientSession | None = None
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connect(self):
|
||||||
|
"""Connect to the MCP server and manage its lifecycle.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
ClientSession: Active MCP session for making requests
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Merge env vars with current environment
|
||||||
|
server_env = os.environ.copy()
|
||||||
|
server_env.update(self.env)
|
||||||
|
|
||||||
|
# Create server parameters with env
|
||||||
|
server_params = StdioServerParameters(
|
||||||
|
command=self.command,
|
||||||
|
args=self.args,
|
||||||
|
env=server_env
|
||||||
|
)
|
||||||
|
|
||||||
|
# Spawn server process and create session
|
||||||
|
# Note: Cannot combine these context managers because ClientSession
|
||||||
|
# needs the read/write streams from stdio_client
|
||||||
|
async with stdio_client(server=server_params) as (read, write):
|
||||||
|
async with ClientSession(read, write) as session:
|
||||||
|
# Initialize the connection
|
||||||
|
await session.initialize()
|
||||||
|
self.session = session
|
||||||
|
logger.info(
|
||||||
|
"Connected to MCP server: %s %s",
|
||||||
|
self.command,
|
||||||
|
" ".join(self.args),
|
||||||
|
)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.session = None
|
||||||
|
logger.info("Disconnected from MCP server: %s", self.command)
|
||||||
|
|
||||||
|
async def list_tools(self) -> list[dict[str, Any]]:
|
||||||
|
"""List all tools available from the MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool definitions with name, description, and input schema
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If not connected to server
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not self.session:
|
||||||
|
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Call tools/list RPC method
|
||||||
|
response = await self.session.list_tools()
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for tool in response.tools:
|
||||||
|
tools.append({
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description or "",
|
||||||
|
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info("Listed %d tools from MCP server", len(tools))
|
||||||
|
return tools
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||||
|
"""Call a tool on the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to call
|
||||||
|
arguments: Arguments to pass to the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If not connected to server
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not self.session:
|
||||||
|
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)
|
||||||
|
|
||||||
|
# Call tools/call RPC method
|
||||||
|
response = await self.session.call_tool(tool_name, arguments=arguments)
|
||||||
|
|
||||||
|
# Extract content from response
|
||||||
|
result = []
|
||||||
|
for content in response.content:
|
||||||
|
if hasattr(content, "text"):
|
||||||
|
result.append(content.text)
|
||||||
|
elif hasattr(content, "data"):
|
||||||
|
result.append(str(content.data))
|
||||||
|
else:
|
||||||
|
result.append(str(content))
|
||||||
|
|
||||||
|
result_str = "\n".join(result) if result else ""
|
||||||
|
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||||
|
return result_str
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
# Handle validation errors from MCP server responses
|
||||||
|
# Some MCP servers (like server-memory) return extra fields not in their schema
|
||||||
|
if "Invalid structured content" in str(e):
|
||||||
|
logger.warning("MCP server returned data not matching its schema, but continuing: %s", e)
|
||||||
|
# Try to extract result from error message or return a success message
|
||||||
|
return "Operation completed (server returned unexpected format)"
|
||||||
|
raise
|
||||||
|
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||||
|
logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
|
||||||
|
return f"Error calling tool: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mcp_connection(
|
||||||
|
command: str, args: list[str], env: dict[str, str] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Test connection to an MCP server and fetch available tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: Command to spawn the MCP server
|
||||||
|
args: Arguments for the command
|
||||||
|
env: Optional environment variables
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with connection status and available tools
|
||||||
|
|
||||||
|
"""
|
||||||
|
client = MCPClient(command, args, env)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with client.connect():
|
||||||
|
tools = await client.list_tools()
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"Connected successfully. Found {len(tools)} tools.",
|
||||||
|
"tools": tools,
|
||||||
|
}
|
||||||
|
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Failed to connect: {e!s}",
|
||||||
|
"tools": [],
|
||||||
|
}
|
||||||
189
surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
Normal file
189
surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""MCP Tool Factory.
|
||||||
|
|
||||||
|
This module creates LangChain tools from MCP servers using the Model Context Protocol.
|
||||||
|
Tools are dynamically discovered from MCP servers - no manual configuration needed.
|
||||||
|
|
||||||
|
This implements real MCP protocol support similar to Cursor's implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, create_model
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||||
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_dynamic_input_model_from_schema(
|
||||||
|
tool_name: str, input_schema: dict[str, Any],
|
||||||
|
) -> type[BaseModel]:
|
||||||
|
"""Create a Pydantic model from MCP tool's JSON schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool (used for model class name)
|
||||||
|
input_schema: JSON schema from MCP server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pydantic model class for tool input validation
|
||||||
|
|
||||||
|
"""
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required_fields = input_schema.get("required", [])
|
||||||
|
|
||||||
|
# Build Pydantic field definitions
|
||||||
|
field_definitions = {}
|
||||||
|
for param_name, param_schema in properties.items():
|
||||||
|
param_description = param_schema.get("description", "")
|
||||||
|
is_required = param_name in required_fields
|
||||||
|
|
||||||
|
# Use Any type for complex schemas to preserve structure
|
||||||
|
# This allows the MCP server to do its own validation
|
||||||
|
from typing import Any as AnyType
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
if is_required:
|
||||||
|
field_definitions[param_name] = (AnyType, Field(..., description=param_description))
|
||||||
|
else:
|
||||||
|
field_definitions[param_name] = (
|
||||||
|
AnyType | None,
|
||||||
|
Field(None, description=param_description),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dynamic model
|
||||||
|
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
|
||||||
|
return create_model(model_name, **field_definitions)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_mcp_tool_from_definition(
|
||||||
|
tool_def: dict[str, Any],
|
||||||
|
mcp_client: MCPClient,
|
||||||
|
) -> StructuredTool:
|
||||||
|
"""Create a LangChain tool from an MCP tool definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_def: Tool definition from MCP server with name, description, input_schema
|
||||||
|
mcp_client: MCP client instance for calling the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LangChain StructuredTool instance
|
||||||
|
|
||||||
|
"""
|
||||||
|
tool_name = tool_def.get("name", "unnamed_tool")
|
||||||
|
tool_description = tool_def.get("description", "No description provided")
|
||||||
|
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||||
|
|
||||||
|
# Log the actual schema for debugging
|
||||||
|
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
|
||||||
|
|
||||||
|
# Create dynamic input model from schema
|
||||||
|
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||||
|
|
||||||
|
async def mcp_tool_call(**kwargs) -> str:
|
||||||
|
"""Execute the MCP tool call via the client."""
|
||||||
|
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect to server and call tool
|
||||||
|
async with mcp_client.connect():
|
||||||
|
result = await mcp_client.call_tool(tool_name, kwargs)
|
||||||
|
return str(result)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
|
||||||
|
logger.exception(error_msg)
|
||||||
|
return f"Error: {error_msg}"
|
||||||
|
|
||||||
|
# Create StructuredTool with response_format to preserve exact schema
|
||||||
|
tool = StructuredTool(
|
||||||
|
name=tool_name,
|
||||||
|
description=tool_description,
|
||||||
|
coroutine=mcp_tool_call,
|
||||||
|
args_schema=input_model,
|
||||||
|
# Store the original MCP schema as metadata so we can access it later
|
||||||
|
metadata={"mcp_input_schema": input_schema},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Created MCP tool: '{tool_name}'")
|
||||||
|
return tool
|
||||||
|
|
||||||
|
|
||||||
|
async def load_mcp_tools(
|
||||||
|
session: AsyncSession, search_space_id: int,
|
||||||
|
) -> list[StructuredTool]:
|
||||||
|
"""Load all MCP tools from user's active MCP server connectors.
|
||||||
|
|
||||||
|
This discovers tools dynamically from MCP servers using the protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
search_space_id: User's search space ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LangChain StructuredTool instances
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Fetch all MCP connectors for this search space
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
tools: list[StructuredTool] = []
|
||||||
|
for connector in result.scalars():
|
||||||
|
try:
|
||||||
|
# Extract server config
|
||||||
|
config = connector.config or {}
|
||||||
|
server_config = config.get("server_config", {})
|
||||||
|
|
||||||
|
command = server_config.get("command")
|
||||||
|
args = server_config.get("args", [])
|
||||||
|
env = server_config.get("env", {})
|
||||||
|
|
||||||
|
if not command:
|
||||||
|
logger.warning(f"MCP connector {connector.id} missing command, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create MCP client
|
||||||
|
mcp_client = MCPClient(command, args, env)
|
||||||
|
|
||||||
|
# Connect and discover tools
|
||||||
|
async with mcp_client.connect():
|
||||||
|
tool_definitions = await mcp_client.list_tools()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Discovered {len(tool_definitions)} tools from MCP server "
|
||||||
|
f"'{command}' (connector {connector.id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create LangChain tools from definitions
|
||||||
|
for tool_def in tool_definitions:
|
||||||
|
try:
|
||||||
|
tool = await _create_mcp_tool_from_definition(tool_def, mcp_client)
|
||||||
|
tools.append(tool)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Failed to create tool '{tool_def.get('name')}' "
|
||||||
|
f"from connector {connector.id}: {e!s}",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Failed to load tools from MCP connector {connector.id}: {e!s}",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||||
|
return tools
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to load MCP tools: {e!s}")
|
||||||
|
return []
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
"""
|
"""Tools registry for SurfSense deep agent.
|
||||||
Tools registry for SurfSense deep agent.
|
|
||||||
|
|
||||||
This module provides a registry pattern for managing tools in the SurfSense agent.
|
This module provides a registry pattern for managing tools in the SurfSense agent.
|
||||||
It makes it easy for OSS contributors to add new tools by:
|
It makes it easy for OSS contributors to add new tools by:
|
||||||
|
|
@ -37,6 +36,7 @@ Example of adding a new tool:
|
||||||
),
|
),
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool
|
||||||
from .display_image import create_display_image_tool
|
from .display_image import create_display_image_tool
|
||||||
from .knowledge_base import create_search_knowledge_base_tool
|
from .knowledge_base import create_search_knowledge_base_tool
|
||||||
from .link_preview import create_link_preview_tool
|
from .link_preview import create_link_preview_tool
|
||||||
|
from .mcp_tool import load_mcp_tools
|
||||||
from .podcast import create_generate_podcast_tool
|
from .podcast import create_generate_podcast_tool
|
||||||
from .scrape_webpage import create_scrape_webpage_tool
|
from .scrape_webpage import create_scrape_webpage_tool
|
||||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
|
|
@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolDefinition:
|
class ToolDefinition:
|
||||||
"""
|
"""Definition of a tool that can be added to the agent.
|
||||||
Definition of a tool that can be added to the agent.
|
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
name: Unique identifier for the tool
|
name: Unique identifier for the tool
|
||||||
|
|
@ -66,6 +66,7 @@ class ToolDefinition:
|
||||||
factory: Callable that creates the tool. Receives a dict of dependencies.
|
factory: Callable that creates the tool. Receives a dict of dependencies.
|
||||||
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
||||||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -178,8 +179,7 @@ def build_tools(
|
||||||
disabled_tools: list[str] | None = None,
|
disabled_tools: list[str] | None = None,
|
||||||
additional_tools: list[BaseTool] | None = None,
|
additional_tools: list[BaseTool] | None = None,
|
||||||
) -> list[BaseTool]:
|
) -> list[BaseTool]:
|
||||||
"""
|
"""Build the list of tools for the agent.
|
||||||
Build the list of tools for the agent.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dependencies: Dict containing all possible dependencies:
|
dependencies: Dict containing all possible dependencies:
|
||||||
|
|
@ -206,6 +206,7 @@ def build_tools(
|
||||||
|
|
||||||
# Add custom tools
|
# Add custom tools
|
||||||
tools = build_tools(deps, additional_tools=[my_custom_tool])
|
tools = build_tools(deps, additional_tools=[my_custom_tool])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Determine which tools to enable
|
# Determine which tools to enable
|
||||||
if enabled_tools is not None:
|
if enabled_tools is not None:
|
||||||
|
|
@ -226,8 +227,9 @@ def build_tools(
|
||||||
# Check that all required dependencies are provided
|
# Check that all required dependencies are provided
|
||||||
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
|
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
|
||||||
if missing_deps:
|
if missing_deps:
|
||||||
|
msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
|
msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the tool
|
# Create the tool
|
||||||
|
|
@ -239,3 +241,61 @@ def build_tools(
|
||||||
tools.extend(additional_tools)
|
tools.extend(additional_tools)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
async def build_tools_async(
|
||||||
|
dependencies: dict[str, Any],
|
||||||
|
enabled_tools: list[str] | None = None,
|
||||||
|
disabled_tools: list[str] | None = None,
|
||||||
|
additional_tools: list[BaseTool] | None = None,
|
||||||
|
include_mcp_tools: bool = True,
|
||||||
|
) -> list[BaseTool]:
|
||||||
|
"""Async version of build_tools that also loads MCP tools from database.
|
||||||
|
|
||||||
|
Design Note:
|
||||||
|
This function exists because MCP tools require database queries to load user configs,
|
||||||
|
while built-in tools are created synchronously from static code.
|
||||||
|
|
||||||
|
Alternative: We could make build_tools() itself async and always query the database,
|
||||||
|
but that would force async everywhere even when only using built-in tools. The current
|
||||||
|
design keeps the simple case (static tools only) synchronous while supporting dynamic
|
||||||
|
database-loaded tools through this async wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dependencies: Dict containing all possible dependencies
|
||||||
|
enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
|
||||||
|
disabled_tools: List of tool names to disable (applied after enabled_tools).
|
||||||
|
additional_tools: Extra tools to add (e.g., custom tools not in registry).
|
||||||
|
include_mcp_tools: Whether to load user's MCP tools from database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of configured tool instances ready for the agent, including MCP tools.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Build standard tools
|
||||||
|
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
|
||||||
|
|
||||||
|
# Load MCP tools if requested and dependencies are available
|
||||||
|
if (
|
||||||
|
include_mcp_tools
|
||||||
|
and "db_session" in dependencies
|
||||||
|
and "search_space_id" in dependencies
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
mcp_tools = await load_mcp_tools(
|
||||||
|
dependencies["db_session"], dependencies["search_space_id"],
|
||||||
|
)
|
||||||
|
tools.extend(mcp_tools)
|
||||||
|
logging.info(
|
||||||
|
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log error but don't fail - just continue without MCP tools
|
||||||
|
logging.exception(f"Failed to load MCP tools: {e!s}")
|
||||||
|
|
||||||
|
# Log all tools being returned to agent
|
||||||
|
logging.info(
|
||||||
|
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum):
|
||||||
WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR"
|
WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR"
|
||||||
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
|
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
|
||||||
CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR"
|
CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR"
|
||||||
|
MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(str, Enum):
|
class LiteLLMProvider(str, Enum):
|
||||||
|
|
@ -605,7 +606,8 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||||
"search_space_id",
|
"search_space_id",
|
||||||
"user_id",
|
"user_id",
|
||||||
"connector_type",
|
"connector_type",
|
||||||
name="uq_searchspace_user_connector_type",
|
"name",
|
||||||
|
name="uq_searchspace_user_connector_type_name",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,13 @@ PUT /search-source-connectors/{connector_id} - Update a specific connector
|
||||||
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
|
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
|
||||||
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
|
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
|
||||||
|
|
||||||
|
MCP (Model Context Protocol) Connector routes:
|
||||||
|
POST /connectors/mcp - Create a new MCP connector with custom API tools
|
||||||
|
GET /connectors/mcp - List all MCP connectors for the current user's search space
|
||||||
|
GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config
|
||||||
|
PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config
|
||||||
|
DELETE /connectors/mcp/{connector_id} - Delete an MCP connector
|
||||||
|
|
||||||
Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space.
|
Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space.
|
||||||
Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space.
|
Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space.
|
||||||
"""
|
"""
|
||||||
|
|
@ -32,6 +39,9 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
GoogleDriveIndexRequest,
|
GoogleDriveIndexRequest,
|
||||||
|
MCPConnectorCreate,
|
||||||
|
MCPConnectorRead,
|
||||||
|
MCPConnectorUpdate,
|
||||||
SearchSourceConnectorBase,
|
SearchSourceConnectorBase,
|
||||||
SearchSourceConnectorCreate,
|
SearchSourceConnectorCreate,
|
||||||
SearchSourceConnectorRead,
|
SearchSourceConnectorRead,
|
||||||
|
|
@ -127,18 +137,20 @@ async def create_search_source_connector(
|
||||||
|
|
||||||
# Check if a connector with the same type already exists for this search space
|
# Check if a connector with the same type already exists for this search space
|
||||||
# (for non-OAuth connectors that don't support multiple accounts)
|
# (for non-OAuth connectors that don't support multiple accounts)
|
||||||
result = await session.execute(
|
# Exception: MCP_CONNECTOR can have multiple instances with different names
|
||||||
select(SearchSourceConnector).filter(
|
if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR:
|
||||||
SearchSourceConnector.search_space_id == search_space_id,
|
result = await session.execute(
|
||||||
SearchSourceConnector.connector_type == connector.connector_type,
|
select(SearchSourceConnector).filter(
|
||||||
)
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
)
|
SearchSourceConnector.connector_type == connector.connector_type,
|
||||||
existing_connector = result.scalars().first()
|
)
|
||||||
if existing_connector:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
|
|
||||||
)
|
)
|
||||||
|
existing_connector = result.scalars().first()
|
||||||
|
if existing_connector:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare connector data
|
# Prepare connector data
|
||||||
connector_data = connector.model_dump()
|
connector_data = connector.model_dump()
|
||||||
|
|
@ -1964,3 +1976,348 @@ async def run_bookstack_indexing(
|
||||||
f"Critical error in run_bookstack_indexing for connector {connector_id}: {e}",
|
f"Critical error in run_bookstack_indexing for connector {connector_id}: {e}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MCP Connector Routes
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201)
|
||||||
|
async def create_mcp_connector(
|
||||||
|
connector_data: MCPConnectorCreate,
|
||||||
|
search_space_id: int = Query(..., description="Search space ID"),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new MCP (Model Context Protocol) connector.
|
||||||
|
|
||||||
|
MCP connectors allow users to connect to MCP servers (like in Cursor).
|
||||||
|
Tools are auto-discovered from the server - no manual configuration needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_data: MCP server configuration (command, args, env)
|
||||||
|
search_space_id: ID of the search space to attach the connector to
|
||||||
|
session: Database session
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created MCP connector with server configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If search space not found or permission denied
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check user has permission to create connectors
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
search_space_id,
|
||||||
|
Permission.CONNECTORS_CREATE.value,
|
||||||
|
"You don't have permission to create connectors in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the connector with server config
|
||||||
|
db_connector = SearchSourceConnector(
|
||||||
|
name=connector_data.name,
|
||||||
|
connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
|
||||||
|
is_indexable=False, # MCP connectors are not indexable
|
||||||
|
config={"server_config": connector_data.server_config.model_dump()},
|
||||||
|
periodic_indexing_enabled=False,
|
||||||
|
indexing_frequency_minutes=None,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(db_connector)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(db_connector)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' "
|
||||||
|
f"for user {user.id} in search space {search_space_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to read schema
|
||||||
|
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||||
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True)
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to create MCP connector: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connectors/mcp", response_model=list[MCPConnectorRead])
|
||||||
|
async def list_mcp_connectors(
|
||||||
|
search_space_id: int = Query(..., description="Search space ID"),
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all MCP connectors for a search space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_space_id: ID of the search space
|
||||||
|
session: Database session
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of MCP connectors with their tool configurations
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check user has permission to read connectors
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
search_space_id,
|
||||||
|
Permission.CONNECTORS_READ.value,
|
||||||
|
"You don't have permission to view connectors in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch MCP connectors
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||||
|
SearchSourceConnector.search_space_id == search_space_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
connectors = result.scalars().all()
|
||||||
|
return [
|
||||||
|
MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c))
|
||||||
|
for c in connectors
|
||||||
|
]
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to list MCP connectors: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
|
||||||
|
async def get_mcp_connector(
|
||||||
|
connector_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a specific MCP connector by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_id: ID of the connector
|
||||||
|
session: Database session
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MCP connector with tool configurations
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Fetch connector
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
|
||||||
|
if not connector:
|
||||||
|
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||||
|
|
||||||
|
# Check user has permission to read connectors
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
connector.search_space_id,
|
||||||
|
Permission.CONNECTORS_READ.value,
|
||||||
|
"You don't have permission to view this connector",
|
||||||
|
)
|
||||||
|
|
||||||
|
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||||
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to get MCP connector: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
|
||||||
|
async def update_mcp_connector(
|
||||||
|
connector_id: int,
|
||||||
|
connector_update: MCPConnectorUpdate,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update an MCP connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_id: ID of the connector to update
|
||||||
|
connector_update: Updated connector data
|
||||||
|
session: Database session
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated MCP connector
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Fetch connector
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
|
||||||
|
if not connector:
|
||||||
|
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||||
|
|
||||||
|
# Check user has permission to update connectors
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
connector.search_space_id,
|
||||||
|
Permission.CONNECTORS_UPDATE.value,
|
||||||
|
"You don't have permission to update this connector",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
if connector_update.name is not None:
|
||||||
|
connector.name = connector_update.name
|
||||||
|
|
||||||
|
if connector_update.server_config is not None:
|
||||||
|
connector.config = {
|
||||||
|
"server_config": connector_update.server_config.model_dump()
|
||||||
|
}
|
||||||
|
|
||||||
|
connector.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(connector)
|
||||||
|
|
||||||
|
logger.info(f"Updated MCP connector {connector_id}")
|
||||||
|
|
||||||
|
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||||
|
return MCPConnectorRead.from_connector(connector_read)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True)
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to update MCP connector: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/connectors/mcp/{connector_id}", status_code=204)
|
||||||
|
async def delete_mcp_connector(
|
||||||
|
connector_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete an MCP connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_id: ID of the connector to delete
|
||||||
|
session: Database session
|
||||||
|
user: Current authenticated user
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Fetch connector
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
SearchSourceConnector.connector_type
|
||||||
|
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
|
||||||
|
if not connector:
|
||||||
|
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||||
|
|
||||||
|
# Check user has permission to delete connectors
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
connector.search_space_id,
|
||||||
|
Permission.CONNECTORS_DELETE.value,
|
||||||
|
"You don't have permission to delete this connector",
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.delete(connector)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(f"Deleted MCP connector {connector_id}")
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True)
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to delete MCP connector: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/connectors/mcp/test")
|
||||||
|
async def test_mcp_server_connection(
|
||||||
|
server_config: dict = Body(...),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test connection to an MCP server and fetch available tools.
|
||||||
|
|
||||||
|
This endpoint allows users to test their MCP server configuration
|
||||||
|
before saving it, similar to Cursor's flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_config: Server configuration with command, args, env
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Connection status and list of available tools
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.agents.new_chat.tools.mcp_client import test_mcp_connection
|
||||||
|
|
||||||
|
command = server_config.get("command")
|
||||||
|
args = server_config.get("args", [])
|
||||||
|
env = server_config.get("env", {})
|
||||||
|
|
||||||
|
if not command:
|
||||||
|
raise HTTPException(status_code=400, detail="Server command is required")
|
||||||
|
|
||||||
|
# Test the connection
|
||||||
|
result = await test_mcp_connection(command, args, env)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"Failed to test connection: {e!s}",
|
||||||
|
"tools": [],
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,10 @@ from .rbac_schemas import (
|
||||||
UserSearchSpaceAccess,
|
UserSearchSpaceAccess,
|
||||||
)
|
)
|
||||||
from .search_source_connector import (
|
from .search_source_connector import (
|
||||||
|
MCPConnectorCreate,
|
||||||
|
MCPConnectorRead,
|
||||||
|
MCPConnectorUpdate,
|
||||||
|
MCPServerConfig,
|
||||||
SearchSourceConnectorBase,
|
SearchSourceConnectorBase,
|
||||||
SearchSourceConnectorCreate,
|
SearchSourceConnectorCreate,
|
||||||
SearchSourceConnectorRead,
|
SearchSourceConnectorRead,
|
||||||
|
|
@ -108,6 +112,11 @@ __all__ = [
|
||||||
"LogFilter",
|
"LogFilter",
|
||||||
"LogRead",
|
"LogRead",
|
||||||
"LogUpdate",
|
"LogUpdate",
|
||||||
|
# Search source connector schemas
|
||||||
|
"MCPConnectorCreate",
|
||||||
|
"MCPConnectorRead",
|
||||||
|
"MCPConnectorUpdate",
|
||||||
|
"MCPServerConfig",
|
||||||
"MembershipRead",
|
"MembershipRead",
|
||||||
"MembershipReadWithUser",
|
"MembershipReadWithUser",
|
||||||
"MembershipUpdate",
|
"MembershipUpdate",
|
||||||
|
|
@ -135,7 +144,6 @@ __all__ = [
|
||||||
"RoleCreate",
|
"RoleCreate",
|
||||||
"RoleRead",
|
"RoleRead",
|
||||||
"RoleUpdate",
|
"RoleUpdate",
|
||||||
# Search source connector schemas
|
|
||||||
"SearchSourceConnectorBase",
|
"SearchSourceConnectorBase",
|
||||||
"SearchSourceConnectorCreate",
|
"SearchSourceConnectorCreate",
|
||||||
"SearchSourceConnectorRead",
|
"SearchSourceConnectorRead",
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class SearchSourceConnectorBase(BaseModel):
|
||||||
@field_validator("config")
|
@field_validator("config")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config_for_connector_type(
|
def validate_config_for_connector_type(
|
||||||
cls, config: dict[str, Any], values: dict[str, Any]
|
cls, config: dict[str, Any], values: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
connector_type = values.data.get("connector_type")
|
connector_type = values.data.get("connector_type")
|
||||||
return validate_connector_config(connector_type, config)
|
return validate_connector_config(connector_type, config)
|
||||||
|
|
@ -38,15 +38,18 @@ class SearchSourceConnectorBase(BaseModel):
|
||||||
"""
|
"""
|
||||||
if self.periodic_indexing_enabled:
|
if self.periodic_indexing_enabled:
|
||||||
if not self.is_indexable:
|
if not self.is_indexable:
|
||||||
|
msg = "periodic_indexing_enabled can only be True for indexable connectors"
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"periodic_indexing_enabled can only be True for indexable connectors"
|
msg,
|
||||||
)
|
)
|
||||||
if self.indexing_frequency_minutes is None:
|
if self.indexing_frequency_minutes is None:
|
||||||
|
msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True"
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"indexing_frequency_minutes is required when periodic_indexing_enabled is True"
|
msg,
|
||||||
)
|
)
|
||||||
if self.indexing_frequency_minutes <= 0:
|
if self.indexing_frequency_minutes <= 0:
|
||||||
raise ValueError("indexing_frequency_minutes must be greater than 0")
|
msg = "indexing_frequency_minutes must be greater than 0"
|
||||||
|
raise ValueError(msg)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,3 +73,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MCP-specific schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerConfig(BaseModel):
|
||||||
|
"""Configuration for an MCP server connection (similar to Cursor's config)."""
|
||||||
|
|
||||||
|
command: str # e.g., "uvx", "node", "python"
|
||||||
|
args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"]
|
||||||
|
env: dict[str, str] = {} # Environment variables for the server process
|
||||||
|
transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectorCreate(BaseModel):
|
||||||
|
"""Schema for creating an MCP connector."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
server_config: MCPServerConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectorUpdate(BaseModel):
|
||||||
|
"""Schema for updating an MCP connector."""
|
||||||
|
|
||||||
|
name: str | None = None
|
||||||
|
server_config: MCPServerConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectorRead(BaseModel):
|
||||||
|
"""Schema for reading an MCP connector with server config."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
connector_type: SearchSourceConnectorType
|
||||||
|
server_config: MCPServerConfig
|
||||||
|
search_space_id: int
|
||||||
|
user_id: uuid.UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead":
|
||||||
|
"""Convert from base SearchSourceConnectorRead."""
|
||||||
|
config = connector.config or {}
|
||||||
|
server_config = MCPServerConfig(**config.get("server_config", {}))
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=connector.id,
|
||||||
|
name=connector.name,
|
||||||
|
connector_type=connector.connector_type,
|
||||||
|
server_config=server_config,
|
||||||
|
search_space_id=connector.search_space_id,
|
||||||
|
user_id=connector.user_id,
|
||||||
|
created_at=connector.created_at,
|
||||||
|
updated_at=connector.updated_at,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -237,7 +237,7 @@ async def stream_new_chat(
|
||||||
checkpointer = await get_checkpointer()
|
checkpointer = await get_checkpointer()
|
||||||
|
|
||||||
# Create the deep agent with checkpointer and configurable prompts
|
# Create the deep agent with checkpointer and configurable prompts
|
||||||
agent = create_surfsense_deep_agent(
|
agent = await create_surfsense_deep_agent(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
db_session=session,
|
db_session=session,
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,9 @@ from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||||
|
|
||||||
|
|
@ -27,6 +27,7 @@ BASE_NAME_FOR_TYPE = {
|
||||||
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
|
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
|
||||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
|
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
|
||||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
|
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
|
||||||
|
SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -75,7 +76,7 @@ def extract_identifier_from_credentials(
|
||||||
if ".atlassian.net" in hostname:
|
if ".atlassian.net" in hostname:
|
||||||
return hostname.replace(".atlassian.net", "")
|
return hostname.replace(".atlassian.net", "")
|
||||||
return hostname
|
return hostname
|
||||||
except Exception:
|
except (ValueError, TypeError, AttributeError):
|
||||||
pass
|
pass
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,9 @@ dependencies = [
|
||||||
"chonkie[all]>=1.5.0",
|
"chonkie[all]>=1.5.0",
|
||||||
"langgraph-checkpoint-postgres>=3.0.2",
|
"langgraph-checkpoint-postgres>=3.0.2",
|
||||||
"psycopg[binary,pool]>=3.3.2",
|
"psycopg[binary,pool]>=3.3.2",
|
||||||
|
"mcp>=1.25.0",
|
||||||
|
"starlette>=0.40.0,<0.51.0",
|
||||||
|
"sse-starlette>=3.1.1,<3.1.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue