mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): evict mac-only tools/middleware from shared kernel
These were never shared with anonymous_chat (nor podcaster/video_presentation)
-- only multi_agent_chat (subagents/main agent) and the boundary use them:
shared/tools/mcp/ -> multi_agent_chat/shared/tools/mcp/
shared/tools/hitl.py -> multi_agent_chat/shared/tools/hitl.py
shared/tools/catalog.py -> multi_agent_chat/shared/tools/catalog.py
shared/middleware/dedup_tool_calls.py
-> multi_agent_chat/shared/middleware/dedup_tool_calls.py
app/agents/shared/ now holds only the genuine anon<->mac kernel:
context, middleware/{compaction,retry_after}, tools/web_search.
This commit is contained in:
parent
b7ea829371
commit
d59bb2b5aa
21 changed files with 50 additions and 40 deletions
|
|
@ -0,0 +1,59 @@
|
|||
"""Dedup-key resolvers for tool-call deduplication.
|
||||
|
||||
A *resolver* maps a tool's ``args`` dict to a stable signature string used to
|
||||
collapse duplicate calls. These helpers are shared: the MCP tool layer uses
|
||||
:func:`dedup_key_full_args` as a safe default, and the main-agent
|
||||
``DedupHITLToolCallsMiddleware`` builds its resolver map from them.
|
||||
|
||||
Resolver resolution order (read from each tool's own ``metadata``):
|
||||
|
||||
1. ``tool.metadata["dedup_key"]`` — callable mapping the args dict to a
|
||||
stable signature string. This is the canonical mechanism.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string naming a primary arg;
|
||||
used by MCP / Composio tools that only expose a single key field.
|
||||
|
||||
A tool with no resolver from either path simply opts out of dedup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
# Resolver type — given the tool ``args`` dict returns a stable
|
||||
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
||||
DedupResolver = Callable[[dict[str, Any]], str]
|
||||
|
||||
|
||||
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
||||
"""Adapt a string-arg name into a :data:`DedupResolver`.
|
||||
|
||||
Convenience helper for tools that just want to dedupe on a single arg's
|
||||
lowercased value (the most common case for HITL tools like
|
||||
``send_gmail_email`` keyed on ``subject``). Set the result on the tool's
|
||||
``metadata["dedup_key"]``.
|
||||
"""
|
||||
|
||||
def _resolver(args: dict[str, Any]) -> str:
|
||||
return str(args.get(arg_name, "")).lower()
|
||||
|
||||
return _resolver
|
||||
|
||||
|
||||
def dedup_key_full_args(args: dict[str, Any]) -> str:
|
||||
"""Resolver that collapses calls only when **every** argument is identical.
|
||||
|
||||
Safe default for tools where no single field uniquely identifies a call
|
||||
(e.g. MCP tools whose first required field is a shared workspace id).
|
||||
"""
|
||||
|
||||
try:
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
return repr(sorted(args.items())) if isinstance(args, dict) else repr(args)
|
||||
|
||||
|
||||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tools shared across multi_agent_chat (main agent + subagents + boundary)."""
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
"""Pure-data catalog of built-in agent tools.
|
||||
|
||||
This module advertises *what* tools exist and their display metadata. It is
|
||||
intentionally free of any tool implementation imports (no connectors, no
|
||||
factories) so it can be consumed without pulling the whole tool dependency
|
||||
graph — and so connector packages stay independently deletable.
|
||||
|
||||
The single live consumer is the ``GET /agent/tools`` endpoint, which renders
|
||||
the tool picker in the web UI. Tool *construction* lives elsewhere:
|
||||
|
||||
* main-agent tools -> ``app.agents.multi_agent_chat.main_agent.tools.registry``
|
||||
* subagent / connector tools -> ``app.agents.multi_agent_chat.subagents.*``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolMetadata:
|
||||
"""Display metadata for a single built-in tool.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the tool.
|
||||
description: Human-readable description of what the tool does.
|
||||
enabled_by_default: Whether the tool is on when no explicit config
|
||||
is provided.
|
||||
hidden: WIP tools that should be excluded from public listings.
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
enabled_by_default: bool = True
|
||||
hidden: bool = False
|
||||
|
||||
|
||||
# Catalog of all built-in tools. Contributors: add new tools here so they show
|
||||
# up in the UI tool picker. This list carries metadata only — wire the actual
|
||||
# implementation in the relevant builder/registry module.
|
||||
TOOL_CATALOG: list[ToolMetadata] = [
|
||||
ToolMetadata(name="generate_podcast", description="Generate an audio podcast from provided content"),
|
||||
ToolMetadata(name="generate_video_presentation", description="Generate a video presentation with slides and narration from provided content"),
|
||||
ToolMetadata(name="generate_report", description="Generate a structured report from provided content and export it"),
|
||||
ToolMetadata(name="generate_resume", description="Generate a professional resume as a Typst document"),
|
||||
ToolMetadata(name="generate_image", description="Generate images from text descriptions using AI image models"),
|
||||
ToolMetadata(name="scrape_webpage", description="Scrape and extract the main content from a webpage"),
|
||||
ToolMetadata(name="web_search", description="Search the web for real-time information using configured search engines"),
|
||||
ToolMetadata(name="create_automation", description="Draft an automation from an NL intent; user approves the card; tool saves"),
|
||||
ToolMetadata(name="update_memory", description="Save important long-term facts, preferences, and instructions to the (personal or team) memory"),
|
||||
ToolMetadata(name="create_notion_page", description="Create a new page in the user's Notion workspace"),
|
||||
ToolMetadata(name="update_notion_page", description="Append new content to an existing Notion page"),
|
||||
ToolMetadata(name="delete_notion_page", description="Delete an existing Notion page"),
|
||||
ToolMetadata(name="create_google_drive_file", description="Create a new Google Doc or Google Sheet in Google Drive"),
|
||||
ToolMetadata(name="delete_google_drive_file", description="Move an indexed Google Drive file to trash"),
|
||||
ToolMetadata(name="create_dropbox_file", description="Create a new file in Dropbox"),
|
||||
ToolMetadata(name="delete_dropbox_file", description="Delete a file from Dropbox"),
|
||||
ToolMetadata(name="create_onedrive_file", description="Create a new file in Microsoft OneDrive"),
|
||||
ToolMetadata(name="delete_onedrive_file", description="Move a OneDrive file to the recycle bin"),
|
||||
ToolMetadata(name="search_calendar_events", description="Search Google Calendar events within a date range"),
|
||||
ToolMetadata(name="create_calendar_event", description="Create a new event on Google Calendar"),
|
||||
ToolMetadata(name="update_calendar_event", description="Update an existing indexed Google Calendar event"),
|
||||
ToolMetadata(name="delete_calendar_event", description="Delete an existing indexed Google Calendar event"),
|
||||
ToolMetadata(name="search_gmail", description="Search emails in Gmail using Gmail search syntax"),
|
||||
ToolMetadata(name="read_gmail_email", description="Read the full content of a specific Gmail email"),
|
||||
ToolMetadata(name="create_gmail_draft", description="Create a draft email in Gmail"),
|
||||
ToolMetadata(name="send_gmail_email", description="Send an email via Gmail"),
|
||||
ToolMetadata(name="trash_gmail_email", description="Move an indexed email to trash in Gmail"),
|
||||
ToolMetadata(name="update_gmail_draft", description="Update an existing Gmail draft"),
|
||||
ToolMetadata(name="create_confluence_page", description="Create a new page in the user's Confluence space"),
|
||||
ToolMetadata(name="update_confluence_page", description="Update an existing indexed Confluence page"),
|
||||
ToolMetadata(name="delete_confluence_page", description="Delete an existing indexed Confluence page"),
|
||||
ToolMetadata(name="list_discord_channels", description="List text channels in the connected Discord server"),
|
||||
ToolMetadata(name="read_discord_messages", description="Read recent messages from a Discord text channel"),
|
||||
ToolMetadata(name="send_discord_message", description="Send a message to a Discord text channel"),
|
||||
ToolMetadata(name="list_teams_channels", description="List Microsoft Teams and their channels"),
|
||||
ToolMetadata(name="read_teams_messages", description="Read recent messages from a Microsoft Teams channel"),
|
||||
ToolMetadata(name="send_teams_message", description="Send a message to a Microsoft Teams channel"),
|
||||
ToolMetadata(name="list_luma_events", description="List upcoming and recent Luma events"),
|
||||
ToolMetadata(name="read_luma_event", description="Read detailed information about a specific Luma event"),
|
||||
ToolMetadata(name="create_luma_event", description="Create a new event on Luma"),
|
||||
]
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
"""Unified HITL (Human-in-the-Loop) approval utility.
|
||||
|
||||
Provides a single ``request_approval()`` function that encapsulates the
|
||||
interrupt payload creation, decision parsing, and parameter merging logic
|
||||
shared by every sensitive tool (native connectors and MCP tools alike).
|
||||
|
||||
Usage inside a tool::
|
||||
|
||||
from app.agents.multi_agent_chat.shared.tools.hitl import request_approval
|
||||
|
||||
result = request_approval(
|
||||
action_type="gmail_email_send",
|
||||
tool_name="send_gmail_email",
|
||||
params={"to": to, "subject": subject, "body": body},
|
||||
context=context,
|
||||
)
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined."}
|
||||
# result.params contains the final (possibly edited) parameters
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import interrupt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tools that mirror the safety profile of ``write_file`` against the
|
||||
# SurfSense KB: each call creates ONE artifact in the user's own workspace
|
||||
# with no external visibility (drafts aren't sent; new files aren't shared
|
||||
# unless the user shares them later). These are auto-approved by default
|
||||
# so the agent can compose drafts and seed scratch files without a popup
|
||||
# on every call.
|
||||
#
|
||||
# Members of this set still call ``request_approval`` exactly as before;
|
||||
# the function returns immediately with ``decision_type="auto_approved"``
|
||||
# and the original params untouched. This preserves the call-site shape
|
||||
# (logging, metadata fetching, account fallbacks) so the only behavior
|
||||
# change is "no interrupt fires".
|
||||
#
|
||||
# To re-enable prompting, the future per-search-space rules table
|
||||
# (``agent_permission_rules``) takes precedence in the permission ruleset
|
||||
# layering assembled by the agent factory.
|
||||
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
||||
{
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"create_calendar_event",
|
||||
"create_notion_page",
|
||||
"create_confluence_page",
|
||||
"create_google_drive_file",
|
||||
"create_dropbox_file",
|
||||
"create_onedrive_file",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class HITLResult:
|
||||
"""Outcome of a human-in-the-loop approval request."""
|
||||
|
||||
rejected: bool
|
||||
decision_type: str
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]:
|
||||
"""Extract the first valid decision and its edited parameters.
|
||||
|
||||
Returns:
|
||||
(decision_type, edited_params) where *decision_type* is one of
|
||||
``"approve"``, ``"edit"``, or ``"reject"`` and *edited_params* is
|
||||
the dict of user-modified arguments (empty when there are none).
|
||||
|
||||
Raises:
|
||||
ValueError: when no usable decision dict can be found.
|
||||
"""
|
||||
decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else []
|
||||
decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw]
|
||||
decisions = [d for d in decisions if isinstance(d, dict)]
|
||||
|
||||
if not decisions:
|
||||
raise ValueError("No approval decision received")
|
||||
|
||||
decision = decisions[0]
|
||||
decision_type: str = (
|
||||
decision.get("type") or decision.get("decision_type") or "approve"
|
||||
)
|
||||
|
||||
edited_params: dict[str, Any] = {}
|
||||
edited_action = decision.get("edited_action")
|
||||
if isinstance(edited_action, dict):
|
||||
edited_args = edited_action.get("args")
|
||||
if isinstance(edited_args, dict):
|
||||
edited_params = edited_args
|
||||
elif isinstance(decision.get("args"), dict):
|
||||
edited_params = decision["args"]
|
||||
|
||||
return decision_type, edited_params
|
||||
|
||||
|
||||
def request_approval(
|
||||
*,
|
||||
action_type: str,
|
||||
tool_name: str,
|
||||
params: dict[str, Any],
|
||||
context: dict[str, Any] | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
) -> HITLResult:
|
||||
"""Pause the graph for user approval and return the decision.
|
||||
|
||||
This is a **synchronous** helper (not ``async``) because
|
||||
``langgraph.types.interrupt`` is itself synchronous — it raises a
|
||||
``GraphInterrupt`` exception that the LangGraph runtime catches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action_type:
|
||||
A label that the frontend uses to select the correct approval card
|
||||
(e.g. ``"gmail_email_send"``, ``"mcp_tool_call"``).
|
||||
tool_name:
|
||||
The registered LangChain tool name (e.g. ``"send_gmail_email"``).
|
||||
params:
|
||||
The original tool arguments. These are shown in the approval card
|
||||
and used as defaults when the user does not edit anything.
|
||||
context:
|
||||
Rich metadata from a ``*ToolMetadataService`` (accounts, folders,
|
||||
labels, etc.). For MCP tools this can hold the server name and
|
||||
tool description.
|
||||
trusted_tools:
|
||||
An allow-list of tool names the user has previously marked as
|
||||
"Always Allow". If *tool_name* appears in this list, HITL is
|
||||
skipped and the tool executes immediately.
|
||||
|
||||
Returns
|
||||
-------
|
||||
HITLResult
|
||||
``result.rejected`` is ``True`` when the user chose to deny the
|
||||
action. Otherwise ``result.params`` contains the final parameter
|
||||
dict — either the originals or the user-edited version merged on
|
||||
top.
|
||||
"""
|
||||
if trusted_tools and tool_name in trusted_tools:
|
||||
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
||||
|
||||
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
|
||||
# Default policy: low-stakes creation tools (drafts + new-file
|
||||
# creates) skip HITL because they're as recoverable as a local
|
||||
# ``write_file`` against the SurfSense KB. The user can still
|
||||
# delete the artifact in <30s if it's wrong.
|
||||
logger.info(
|
||||
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
|
||||
tool_name,
|
||||
)
|
||||
return HITLResult(
|
||||
rejected=False, decision_type="auto_approved", params=dict(params)
|
||||
)
|
||||
|
||||
approval = interrupt(
|
||||
{
|
||||
"type": action_type,
|
||||
"action": {"tool": tool_name, "params": params},
|
||||
"context": context or {},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
decision_type, edited_params = _parse_decision(approval)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"No approval decision received for %s — rejecting for safety", tool_name
|
||||
)
|
||||
return HITLResult(rejected=True, decision_type="error", params=params)
|
||||
|
||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||
|
||||
if decision_type == "reject":
|
||||
return HITLResult(rejected=True, decision_type="reject", params=params)
|
||||
|
||||
final_params = {**params, **edited_params} if edited_params else dict(params)
|
||||
return HITLResult(rejected=False, decision_type=decision_type, params=final_params)
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""MCP (Model Context Protocol) integration: client, tool loading, and cache.
|
||||
|
||||
Split by responsibility:
|
||||
- ``client``: the low-level :class:`MCPClient` connection wrapper.
|
||||
- ``tool``: discovery + LangChain tool construction and cache invalidation.
|
||||
- ``cache``: the connector tool-cache refresh helpers.
|
||||
"""
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.db import SearchSourceConnector, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
|
||||
class CachedMCPToolDef(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CachedMCPTools(BaseModel):
|
||||
discovered_at: datetime
|
||||
server_version: str | None = None
|
||||
server_name: str | None = None
|
||||
transport: str | None = None
|
||||
tools: list[CachedMCPToolDef]
|
||||
|
||||
|
||||
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
|
||||
"""Return parsed cached tools or ``None`` if missing / corrupt (caller falls back to live discovery)."""
|
||||
cfg = connector.config or {}
|
||||
raw = cfg.get("cached_tools")
|
||||
if not raw or not isinstance(raw, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
return CachedMCPTools.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
|
||||
connector.id,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def write_cached_tools(
|
||||
connector_id: int,
|
||||
tool_definitions: list[dict[str, Any]],
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
server_version: str | None = None,
|
||||
transport: str | None = None,
|
||||
) -> None:
|
||||
"""Best-effort persist; uses its own session so a write failure cannot poison the caller's transaction."""
|
||||
payload = CachedMCPTools(
|
||||
discovered_at=datetime.now(UTC),
|
||||
server_version=server_version,
|
||||
server_name=server_name,
|
||||
transport=transport,
|
||||
tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
|
||||
)
|
||||
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if connector is None:
|
||||
return
|
||||
|
||||
cfg = dict(connector.config or {})
|
||||
cfg["cached_tools"] = payload.model_dump(mode="json")
|
||||
connector.config = cfg
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Persisted cached_tools for MCP connector %d (%d tools)",
|
||||
connector_id,
|
||||
len(payload.tools),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist cached_tools for MCP connector %d",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def refresh_mcp_tools_cache_for_connector(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
) -> None:
|
||||
"""Maintain the MCP tool cache after a single-connector lifecycle event.
|
||||
|
||||
Synchronously evicts the in-process LRU for the connector's search space
|
||||
(LRU keys are per-space, so eviction cannot be scoped finer), then schedules
|
||||
a background live discovery for this connector alone so its persisted
|
||||
``cached_tools`` row is refreshed before the next user query.
|
||||
|
||||
Idempotent. Eviction is best-effort; prefetch is best-effort and only runs
|
||||
when an event loop is available. Neither path raises.
|
||||
"""
|
||||
try:
|
||||
from app.agents.multi_agent_chat.shared.tools.mcp.tool import (
|
||||
invalidate_mcp_tools_cache,
|
||||
)
|
||||
|
||||
invalidate_mcp_tools_cache(search_space_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"MCP in-process cache eviction skipped for space %d",
|
||||
search_space_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
task = loop.create_task(_run_connector_prefetch(connector_id))
|
||||
_pending_prefetch_tasks.add(task)
|
||||
task.add_done_callback(_pending_prefetch_tasks.discard)
|
||||
|
||||
|
||||
async def _run_connector_prefetch(connector_id: int) -> None:
|
||||
from app.agents.multi_agent_chat.shared.tools.mcp.tool import (
|
||||
discover_single_mcp_connector,
|
||||
)
|
||||
|
||||
try:
|
||||
await discover_single_mcp_connector(connector_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"MCP background prefetch failed for connector_id=%d",
|
||||
connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
|
@ -0,0 +1,326 @@
|
|||
"""MCP Client Wrapper.
|
||||
|
||||
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
|
||||
It handles server lifecycle management, tool discovery, and tool execution.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 1.0 # seconds
|
||||
RETRY_BACKOFF = 2.0 # exponential backoff multiplier
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Client for communicating with an MCP server."""
|
||||
|
||||
def __init__(
|
||||
self, command: str, args: list[str], env: dict[str, str] | None = None
|
||||
):
|
||||
"""Initialize MCP client.
|
||||
|
||||
Args:
|
||||
command: Command to spawn the MCP server (e.g., "uvx", "node")
|
||||
args: Arguments for the command (e.g., ["mcp-server-git"])
|
||||
env: Optional environment variables for the server process
|
||||
|
||||
"""
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.env = env or {}
|
||||
self.session: ClientSession | None = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self, max_retries: int = MAX_RETRIES):
|
||||
"""Connect to the MCP server and manage its lifecycle.
|
||||
|
||||
Retries only apply to the **connection** phase (spawning the process,
|
||||
initialising the session). Once the session is yielded to the caller,
|
||||
any exception raised by the caller propagates normally -- the context
|
||||
manager will NOT retry after ``yield``.
|
||||
|
||||
Previous implementation wrapped both connection AND yield inside the
|
||||
retry loop. Because ``@asynccontextmanager`` only allows a single
|
||||
``yield``, a failure after yield caused the generator to attempt a
|
||||
second yield on retry, triggering
|
||||
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
|
||||
the stdio subprocess.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of connection retry attempts
|
||||
|
||||
Yields:
|
||||
ClientSession: Active MCP session for making requests
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all connection attempts fail
|
||||
|
||||
"""
|
||||
last_error = None
|
||||
delay = RETRY_DELAY
|
||||
connected = False
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command, args=self.args, env=server_env
|
||||
)
|
||||
|
||||
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
connected = True
|
||||
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
"Connected to MCP server on attempt %d: %s %s",
|
||||
attempt + 1,
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Connected to MCP server: %s %s",
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
self.session = None
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
self.session = None
|
||||
if connected:
|
||||
raise
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay *= RETRY_BACKOFF
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to connect to MCP server after %d attempts: %s",
|
||||
max_retries,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg) from last_error
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]]:
|
||||
"""List all tools available from the MCP server.
|
||||
|
||||
Returns:
|
||||
List of tool definitions with name, description, and input schema
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not connected to server
|
||||
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError(
|
||||
"Not connected to MCP server. Use 'async with client.connect():'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Call tools/list RPC method
|
||||
response = await self.session.list_tools()
|
||||
|
||||
tools = []
|
||||
for tool in response.tools:
|
||||
tools.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Listed %d tools from MCP server", len(tools))
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float = 60.0,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Arguments to pass to the tool
|
||||
timeout: Maximum seconds to wait for the tool to respond
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not connected to server
|
||||
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError(
|
||||
"Not connected to MCP server. Use 'async with client.connect():'"
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
|
||||
)
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
self.session.call_tool(tool_name, arguments=arguments),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
result.append(str(content.data))
|
||||
else:
|
||||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||
return result_str
|
||||
|
||||
except TimeoutError:
|
||||
logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout)
|
||||
return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
|
||||
except RuntimeError as e:
|
||||
if "Invalid structured content" in str(e):
|
||||
logger.warning(
|
||||
"MCP server returned data not matching its schema, but continuing: %s",
|
||||
e,
|
||||
)
|
||||
return "Operation completed (server returned unexpected format)"
|
||||
raise
|
||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||
logger.error(
|
||||
"Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
|
||||
)
|
||||
return f"Error calling tool: {e!s}"
|
||||
|
||||
|
||||
async def test_mcp_connection(
|
||||
command: str, args: list[str], env: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Test connection to an MCP server via stdio and fetch available tools.
|
||||
|
||||
Args:
|
||||
command: Command to spawn the MCP server
|
||||
args: Arguments for the command
|
||||
env: Optional environment variables
|
||||
|
||||
Returns:
|
||||
Dict with connection status and available tools
|
||||
|
||||
"""
|
||||
client = MCPClient(command, args, env)
|
||||
|
||||
try:
|
||||
async with client.connect():
|
||||
tools = await client.list_tools()
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Connected successfully. Found {len(tools)} tools.",
|
||||
"tools": tools,
|
||||
}
|
||||
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to connect: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
|
||||
|
||||
async def test_mcp_http_connection(
|
||||
url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
|
||||
) -> dict[str, Any]:
|
||||
"""Test connection to an MCP server via HTTP and fetch available tools.
|
||||
|
||||
Args:
|
||||
url: URL of the MCP server
|
||||
headers: Optional HTTP headers for authentication
|
||||
transport: Transport type ("streamable-http", "http", or "sse")
|
||||
|
||||
Returns:
|
||||
Dict with connection status and available tools
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Testing HTTP MCP connection to: %s (transport: %s)", url, transport
|
||||
)
|
||||
|
||||
# Use streamable HTTP client for all HTTP-based transports
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers or {}) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
response = await session.list_tools()
|
||||
tools = []
|
||||
for tool in response.tools:
|
||||
tools.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema
|
||||
if hasattr(tool, "inputSchema")
|
||||
else {},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("HTTP MCP connection successful. Found %d tools.", len(tools))
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Connected successfully. Found {len(tools)} tools.",
|
||||
"tools": tools,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to HTTP MCP server: %s", e, exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to connect: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue