refactor(agents): evict mac-only tools/middleware from shared kernel

These were never shared with anonymous_chat (nor podcaster/video_presentation)
-- only multi_agent_chat (subagents/main agent) and the boundary use them:

  shared/tools/mcp/             -> multi_agent_chat/shared/tools/mcp/
  shared/tools/hitl.py          -> multi_agent_chat/shared/tools/hitl.py
  shared/tools/catalog.py       -> multi_agent_chat/shared/tools/catalog.py
  shared/middleware/dedup_tool_calls.py
                                -> multi_agent_chat/shared/middleware/dedup_tool_calls.py

app/agents/shared/ now holds only the genuine anon<->mac kernel:
context, middleware/{compaction,retry_after}, tools/web_search.
This commit is contained in:
CREDO23 2026-06-05 12:50:46 +02:00
parent b7ea829371
commit d59bb2b5aa
21 changed files with 50 additions and 40 deletions

View file

@ -29,7 +29,7 @@ from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from app.agents.shared.middleware.dedup_tool_calls import (
from app.agents.multi_agent_chat.shared.middleware.dedup_tool_calls import (
DedupResolver,
wrap_dedup_key_by_arg_name,
)

View file

@ -9,7 +9,7 @@ factories for those few tools and nothing else, so the main agent's tool
surface stays self-contained and connector-free.
Tool *display* metadata for the whole app (the ``/agent/tools`` listing
endpoint) lives separately in :mod:`app.agents.shared.tools.catalog`, a
endpoint) lives separately in :mod:`app.agents.multi_agent_chat.shared.tools.catalog`, a
pure-data module that imports no connectors. This registry only governs what
the main agent actually builds and binds.
"""

View file

@ -0,0 +1,59 @@
"""Dedup-key resolvers for tool-call deduplication.
A *resolver* maps a tool's ``args`` dict to a stable signature string used to
collapse duplicate calls. These helpers are shared: the MCP tool layer uses
:func:`dedup_key_full_args` as a safe default, and the main-agent
``DedupHITLToolCallsMiddleware`` builds its resolver map from them.
Resolver resolution order (read from each tool's own ``metadata``):
1. ``tool.metadata["dedup_key"]`` callable mapping the args dict to a
stable signature string. This is the canonical mechanism.
2. ``tool.metadata["hitl_dedup_key"]`` string naming a primary arg;
used by MCP / Composio tools that only expose a single key field.
A tool with no resolver from either path simply opts out of dedup.
"""
from __future__ import annotations
import json
from collections.abc import Callable
from typing import Any
# Resolver type — given the tool ``args`` dict returns a stable
# string used to dedupe consecutive calls. ``None`` means no dedup.
DedupResolver = Callable[[dict[str, Any]], str]
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
"""Adapt a string-arg name into a :data:`DedupResolver`.
Convenience helper for tools that just want to dedupe on a single arg's
lowercased value (the most common case for HITL tools like
``send_gmail_email`` keyed on ``subject``). Set the result on the tool's
``metadata["dedup_key"]``.
"""
def _resolver(args: dict[str, Any]) -> str:
return str(args.get(arg_name, "")).lower()
return _resolver
def dedup_key_full_args(args: dict[str, Any]) -> str:
"""Resolver that collapses calls only when **every** argument is identical.
Safe default for tools where no single field uniquely identifies a call
(e.g. MCP tools whose first required field is a shared workspace id).
"""
try:
return json.dumps(args, sort_keys=True, default=str)
except (TypeError, ValueError):
return repr(sorted(args.items())) if isinstance(args, dict) else repr(args)
# Backwards-compatible alias for code that imported the original
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
_wrap_string_key = wrap_dedup_key_by_arg_name

View file

@ -0,0 +1 @@
"""Tools shared across multi_agent_chat (main agent + subagents + boundary)."""

View file

@ -0,0 +1,83 @@
"""Pure-data catalog of built-in agent tools.
This module advertises *what* tools exist and their display metadata. It is
intentionally free of any tool implementation imports (no connectors, no
factories) so it can be consumed without pulling the whole tool dependency
graph and so connector packages stay independently deletable.
The single live consumer is the ``GET /agent/tools`` endpoint, which renders
the tool picker in the web UI. Tool *construction* lives elsewhere:
* main-agent tools -> ``app.agents.multi_agent_chat.main_agent.tools.registry``
* subagent / connector tools -> ``app.agents.multi_agent_chat.subagents.*``
"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class ToolMetadata:
"""Display metadata for a single built-in tool.
Attributes:
name: Unique identifier for the tool.
description: Human-readable description of what the tool does.
enabled_by_default: Whether the tool is on when no explicit config
is provided.
hidden: WIP tools that should be excluded from public listings.
"""
name: str
description: str
enabled_by_default: bool = True
hidden: bool = False
# Catalog of all built-in tools. Contributors: add new tools here so they show
# up in the UI tool picker. This list carries metadata only — wire the actual
# implementation in the relevant builder/registry module.
TOOL_CATALOG: list[ToolMetadata] = [
ToolMetadata(name="generate_podcast", description="Generate an audio podcast from provided content"),
ToolMetadata(name="generate_video_presentation", description="Generate a video presentation with slides and narration from provided content"),
ToolMetadata(name="generate_report", description="Generate a structured report from provided content and export it"),
ToolMetadata(name="generate_resume", description="Generate a professional resume as a Typst document"),
ToolMetadata(name="generate_image", description="Generate images from text descriptions using AI image models"),
ToolMetadata(name="scrape_webpage", description="Scrape and extract the main content from a webpage"),
ToolMetadata(name="web_search", description="Search the web for real-time information using configured search engines"),
ToolMetadata(name="create_automation", description="Draft an automation from an NL intent; user approves the card; tool saves"),
ToolMetadata(name="update_memory", description="Save important long-term facts, preferences, and instructions to the (personal or team) memory"),
ToolMetadata(name="create_notion_page", description="Create a new page in the user's Notion workspace"),
ToolMetadata(name="update_notion_page", description="Append new content to an existing Notion page"),
ToolMetadata(name="delete_notion_page", description="Delete an existing Notion page"),
ToolMetadata(name="create_google_drive_file", description="Create a new Google Doc or Google Sheet in Google Drive"),
ToolMetadata(name="delete_google_drive_file", description="Move an indexed Google Drive file to trash"),
ToolMetadata(name="create_dropbox_file", description="Create a new file in Dropbox"),
ToolMetadata(name="delete_dropbox_file", description="Delete a file from Dropbox"),
ToolMetadata(name="create_onedrive_file", description="Create a new file in Microsoft OneDrive"),
ToolMetadata(name="delete_onedrive_file", description="Move a OneDrive file to the recycle bin"),
ToolMetadata(name="search_calendar_events", description="Search Google Calendar events within a date range"),
ToolMetadata(name="create_calendar_event", description="Create a new event on Google Calendar"),
ToolMetadata(name="update_calendar_event", description="Update an existing indexed Google Calendar event"),
ToolMetadata(name="delete_calendar_event", description="Delete an existing indexed Google Calendar event"),
ToolMetadata(name="search_gmail", description="Search emails in Gmail using Gmail search syntax"),
ToolMetadata(name="read_gmail_email", description="Read the full content of a specific Gmail email"),
ToolMetadata(name="create_gmail_draft", description="Create a draft email in Gmail"),
ToolMetadata(name="send_gmail_email", description="Send an email via Gmail"),
ToolMetadata(name="trash_gmail_email", description="Move an indexed email to trash in Gmail"),
ToolMetadata(name="update_gmail_draft", description="Update an existing Gmail draft"),
ToolMetadata(name="create_confluence_page", description="Create a new page in the user's Confluence space"),
ToolMetadata(name="update_confluence_page", description="Update an existing indexed Confluence page"),
ToolMetadata(name="delete_confluence_page", description="Delete an existing indexed Confluence page"),
ToolMetadata(name="list_discord_channels", description="List text channels in the connected Discord server"),
ToolMetadata(name="read_discord_messages", description="Read recent messages from a Discord text channel"),
ToolMetadata(name="send_discord_message", description="Send a message to a Discord text channel"),
ToolMetadata(name="list_teams_channels", description="List Microsoft Teams and their channels"),
ToolMetadata(name="read_teams_messages", description="Read recent messages from a Microsoft Teams channel"),
ToolMetadata(name="send_teams_message", description="Send a message to a Microsoft Teams channel"),
ToolMetadata(name="list_luma_events", description="List upcoming and recent Luma events"),
ToolMetadata(name="read_luma_event", description="Read detailed information about a specific Luma event"),
ToolMetadata(name="create_luma_event", description="Create a new event on Luma"),
]

View file

@ -0,0 +1,187 @@
"""Unified HITL (Human-in-the-Loop) approval utility.
Provides a single ``request_approval()`` function that encapsulates the
interrupt payload creation, decision parsing, and parameter merging logic
shared by every sensitive tool (native connectors and MCP tools alike).
Usage inside a tool::
from app.agents.multi_agent_chat.shared.tools.hitl import request_approval
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={"to": to, "subject": subject, "body": body},
context=context,
)
if result.rejected:
return {"status": "rejected", "message": "User declined."}
# result.params contains the final (possibly edited) parameters
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from langgraph.types import interrupt
logger = logging.getLogger(__name__)
# Tools that mirror the safety profile of ``write_file`` against the
# SurfSense KB: each call creates ONE artifact in the user's own workspace
# with no external visibility (drafts aren't sent; new files aren't shared
# unless the user shares them later). These are auto-approved by default
# so the agent can compose drafts and seed scratch files without a popup
# on every call.
#
# Members of this set still call ``request_approval`` exactly as before;
# the function returns immediately with ``decision_type="auto_approved"``
# and the original params untouched. This preserves the call-site shape
# (logging, metadata fetching, account fallbacks) so the only behavior
# change is "no interrupt fires".
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence in the permission ruleset
# layering assembled by the agent factory.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{
"create_gmail_draft",
"update_gmail_draft",
"create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
"create_dropbox_file",
"create_onedrive_file",
}
)
@dataclass(frozen=True, slots=True)
class HITLResult:
"""Outcome of a human-in-the-loop approval request."""
rejected: bool
decision_type: str
params: dict[str, Any] = field(default_factory=dict)
def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]:
"""Extract the first valid decision and its edited parameters.
Returns:
(decision_type, edited_params) where *decision_type* is one of
``"approve"``, ``"edit"``, or ``"reject"`` and *edited_params* is
the dict of user-modified arguments (empty when there are none).
Raises:
ValueError: when no usable decision dict can be found.
"""
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
decisions = [d for d in decisions if isinstance(d, dict)]
if not decisions:
raise ValueError("No approval decision received")
decision = decisions[0]
decision_type: str = (
decision.get("type") or decision.get("decision_type") or "approve"
)
edited_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")
if isinstance(edited_action, dict):
edited_args = edited_action.get("args")
if isinstance(edited_args, dict):
edited_params = edited_args
elif isinstance(decision.get("args"), dict):
edited_params = decision["args"]
return decision_type, edited_params
def request_approval(
*,
action_type: str,
tool_name: str,
params: dict[str, Any],
context: dict[str, Any] | None = None,
trusted_tools: list[str] | None = None,
) -> HITLResult:
"""Pause the graph for user approval and return the decision.
This is a **synchronous** helper (not ``async``) because
``langgraph.types.interrupt`` is itself synchronous it raises a
``GraphInterrupt`` exception that the LangGraph runtime catches.
Parameters
----------
action_type:
A label that the frontend uses to select the correct approval card
(e.g. ``"gmail_email_send"``, ``"mcp_tool_call"``).
tool_name:
The registered LangChain tool name (e.g. ``"send_gmail_email"``).
params:
The original tool arguments. These are shown in the approval card
and used as defaults when the user does not edit anything.
context:
Rich metadata from a ``*ToolMetadataService`` (accounts, folders,
labels, etc.). For MCP tools this can hold the server name and
tool description.
trusted_tools:
An allow-list of tool names the user has previously marked as
"Always Allow". If *tool_name* appears in this list, HITL is
skipped and the tool executes immediately.
Returns
-------
HITLResult
``result.rejected`` is ``True`` when the user chose to deny the
action. Otherwise ``result.params`` contains the final parameter
dict either the originals or the user-edited version merged on
top.
"""
if trusted_tools and tool_name in trusted_tools:
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
# Default policy: low-stakes creation tools (drafts + new-file
# creates) skip HITL because they're as recoverable as a local
# ``write_file`` against the SurfSense KB. The user can still
# delete the artifact in <30s if it's wrong.
logger.info(
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
tool_name,
)
return HITLResult(
rejected=False, decision_type="auto_approved", params=dict(params)
)
approval = interrupt(
{
"type": action_type,
"action": {"tool": tool_name, "params": params},
"context": context or {},
}
)
try:
decision_type, edited_params = _parse_decision(approval)
except ValueError:
logger.warning(
"No approval decision received for %s — rejecting for safety", tool_name
)
return HITLResult(rejected=True, decision_type="error", params=params)
logger.info("User decision for %s: %s", tool_name, decision_type)
if decision_type == "reject":
return HITLResult(rejected=True, decision_type="reject", params=params)
final_params = {**params, **edited_params} if edited_params else dict(params)
return HITLResult(rejected=False, decision_type=decision_type, params=final_params)

View file

@ -0,0 +1,7 @@
"""MCP (Model Context Protocol) integration: client, tool loading, and cache.
Split by responsibility:
- ``client``: the low-level :class:`MCPClient` connection wrapper.
- ``tool``: discovery + LangChain tool construction and cache invalidation.
- ``cache``: the connector tool-cache refresh helpers.
"""

View file

@ -0,0 +1,149 @@
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
class CachedMCPToolDef(BaseModel):
name: str
description: str = ""
input_schema: dict[str, Any] = Field(default_factory=dict)
class CachedMCPTools(BaseModel):
discovered_at: datetime
server_version: str | None = None
server_name: str | None = None
transport: str | None = None
tools: list[CachedMCPToolDef]
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
"""Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
cfg = connector.config or {}
raw = cfg.get("cached_tools")
if not raw or not isinstance(raw, dict):
return None
try:
return CachedMCPTools.model_validate(raw)
except ValidationError as exc:
logger.warning(
"MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
connector.id,
exc,
)
return None
async def write_cached_tools(
connector_id: int,
tool_definitions: list[dict[str, Any]],
*,
server_name: str | None = None,
server_version: str | None = None,
transport: str | None = None,
) -> None:
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
payload = CachedMCPTools(
discovered_at=datetime.now(UTC),
server_version=server_version,
server_name=server_name,
transport=transport,
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
)
try:
async with async_session_maker() as session:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
)
)
connector = result.scalars().first()
if connector is None:
return
cfg = dict(connector.config or {})
cfg["cached_tools"] = payload.model_dump(mode="json")
connector.config = cfg
flag_modified(connector, "config")
await session.commit()
logger.info(
"Persisted cached_tools for MCP connector %d (%d tools)",
connector_id,
len(payload.tools),
)
except Exception:
logger.warning(
"Failed to persist cached_tools for MCP connector %d",
connector_id,
exc_info=True,
)
def refresh_mcp_tools_cache_for_connector(
connector_id: int,
search_space_id: int,
) -> None:
"""Maintain the MCP tool cache after a single-connector lifecycle event.
Synchronously evicts the in-process LRU for the connector's search space
(LRU keys are per-space, so eviction cannot be scoped finer), then schedules
a background live discovery for this connector alone so its persisted
``cached_tools`` row is refreshed before the next user query.
Idempotent. Eviction is best-effort; prefetch is best-effort and only runs
when an event loop is available. Neither path raises.
"""
try:
from app.agents.multi_agent_chat.shared.tools.mcp.tool import (
invalidate_mcp_tools_cache,
)
invalidate_mcp_tools_cache(search_space_id)
except Exception:
logger.debug(
"MCP in-process cache eviction skipped for space %d",
search_space_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
task = loop.create_task(_run_connector_prefetch(connector_id))
_pending_prefetch_tasks.add(task)
task.add_done_callback(_pending_prefetch_tasks.discard)
async def _run_connector_prefetch(connector_id: int) -> None:
from app.agents.multi_agent_chat.shared.tools.mcp.tool import (
discover_single_mcp_connector,
)
try:
await discover_single_mcp_connector(connector_id)
except Exception:
logger.warning(
"MCP background prefetch failed for connector_id=%d",
connector_id,
exc_info=True,
)

View file

@ -0,0 +1,326 @@
"""MCP Client Wrapper.
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
It handles server lifecycle management, tool discovery, and tool execution.
"""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
logger = logging.getLogger(__name__)
# Retry configuration
MAX_RETRIES = 3
RETRY_DELAY = 1.0 # seconds
RETRY_BACKOFF = 2.0 # exponential backoff multiplier
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(
self, command: str, args: list[str], env: dict[str, str] | None = None
):
"""Initialize MCP client.
Args:
command: Command to spawn the MCP server (e.g., "uvx", "node")
args: Arguments for the command (e.g., ["mcp-server-git"])
env: Optional environment variables for the server process
"""
self.command = command
self.args = args
self.env = env or {}
self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle.
Retries only apply to the **connection** phase (spawning the process,
initialising the session). Once the session is yielded to the caller,
any exception raised by the caller propagates normally -- the context
manager will NOT retry after ``yield``.
Previous implementation wrapped both connection AND yield inside the
retry loop. Because ``@asynccontextmanager`` only allows a single
``yield``, a failure after yield caused the generator to attempt a
second yield on retry, triggering
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
the stdio subprocess.
Args:
max_retries: Maximum number of connection retry attempts
Yields:
ClientSession: Active MCP session for making requests
Raises:
RuntimeError: If all connection attempts fail
"""
last_error = None
delay = RETRY_DELAY
connected = False
for attempt in range(max_retries):
try:
server_env = os.environ.copy()
server_env.update(self.env)
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
await session.initialize()
self.session = session
connected = True
if attempt > 0:
logger.info(
"Connected to MCP server on attempt %d: %s %s",
attempt + 1,
self.command,
" ".join(self.args),
)
else:
logger.info(
"Connected to MCP server: %s %s",
self.command,
" ".join(self.args),
)
try:
yield session
finally:
self.session = None
return
except Exception as e:
self.session = None
if connected:
raise
last_error = e
if attempt < max_retries - 1:
logger.warning(
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
attempt + 1,
max_retries,
e,
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
max_retries,
e,
exc_info=True,
)
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
logger.error(error_msg)
raise RuntimeError(error_msg) from last_error
async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server.
Returns:
List of tool definitions with name, description, and input schema
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
# Call tools/list RPC method
response = await self.session.list_tools()
tools = []
for tool in response.tools:
tools.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
logger.info("Listed %d tools from MCP server", len(tools))
return tools
except Exception as e:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
timeout: float = 60.0,
) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
timeout: Maximum seconds to wait for the tool to respond
Returns:
Tool execution result
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
logger.info(
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
response = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments=arguments),
timeout=timeout,
)
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except TimeoutError:
logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout)
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
except RuntimeError as e:
if "Invalid structured content" in str(e):
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.error(
"Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
)
return f"Error calling tool: {e!s}"
async def test_mcp_connection(
command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
"""Test connection to an MCP server via stdio and fetch available tools.
Args:
command: Command to spawn the MCP server
args: Arguments for the command
env: Optional environment variables
Returns:
Dict with connection status and available tools
"""
client = MCPClient(command, args, env)
try:
async with client.connect():
tools = await client.list_tools()
return {
"status": "success",
"message": f"Connected successfully. Found {len(tools)} tools.",
"tools": tools,
}
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
return {
"status": "error",
"message": f"Failed to connect: {e!s}",
"tools": [],
}
async def test_mcp_http_connection(
url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
) -> dict[str, Any]:
"""Test connection to an MCP server via HTTP and fetch available tools.
Args:
url: URL of the MCP server
headers: Optional HTTP headers for authentication
transport: Transport type ("streamable-http", "http", or "sse")
Returns:
Dict with connection status and available tools
"""
try:
logger.info(
"Testing HTTP MCP connection to: %s (transport: %s)", url, transport
)
# Use streamable HTTP client for all HTTP-based transports
async with (
streamablehttp_client(url, headers=headers or {}) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
# List available tools
response = await session.list_tools()
tools = []
for tool in response.tools:
tools.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
logger.info("HTTP MCP connection successful. Found %d tools.", len(tools))
return {
"status": "success",
"message": f"Connected successfully. Found {len(tools)} tools.",
"tools": tools,
}
except Exception as e:
logger.error("Failed to connect to HTTP MCP server: %s", e, exc_info=True)
return {
"status": "error",
"message": f"Failed to connect: {e!s}",
"tools": [],
}

File diff suppressed because it is too large Load diff

View file

@ -21,7 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.multi_agent_chat.constants import (
CONNECTOR_TYPE_TO_CONNECTOR_AGENT_MAPS,
)
from app.agents.shared.tools.mcp.tool import load_mcp_tools
from app.agents.multi_agent_chat.shared.tools.mcp.tool import load_mcp_tools
from app.db import SearchSourceConnector
logger = logging.getLogger(__name__)