From e56186054a0eb91a92cbd0b19010346b5f2a4eae Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 7 Jul 2025 23:52:23 +0100 Subject: [PATCH] MCP client support (#427) - MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool using configuration service for information about where the MCP services are. --- trustgraph-base/trustgraph/base/__init__.py | 1 + .../trustgraph/base/tool_service.py | 121 ++++++++++++++++++ .../trustgraph/messaging/__init__.py | 7 + .../trustgraph/messaging/translators/tool.py | 51 ++++++++ trustgraph-base/trustgraph/schema/models.py | 19 +++ trustgraph-flow/scripts/mcp-tool | 6 + trustgraph-flow/setup.py | 2 + .../trustgraph/agent/mcp_tool/__init__.py | 3 + .../trustgraph/agent/mcp_tool/__main__.py | 7 + .../trustgraph/agent/mcp_tool/service.py | 105 +++++++++++++++ .../trustgraph/gateway/dispatch/manager.py | 3 +- .../trustgraph/gateway/dispatch/mcp_tool.py | 32 +++++ trustgraph-mcp/trustgraph/mcp_version.py | 1 - 13 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 trustgraph-base/trustgraph/base/tool_service.py create mode 100644 trustgraph-base/trustgraph/messaging/translators/tool.py create mode 100755 trustgraph-flow/scripts/mcp-tool create mode 100644 trustgraph-flow/trustgraph/agent/mcp_tool/__init__.py create mode 100644 trustgraph-flow/trustgraph/agent/mcp_tool/__main__.py create mode 100755 trustgraph-flow/trustgraph/agent/mcp_tool/service.py create mode 100644 trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py delete mode 100644 trustgraph-mcp/trustgraph/mcp_version.py diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 2accbb21..24b10390 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -28,4 +28,5 @@ from . triples_client import TriplesClientSpec from . document_embeddings_client import DocumentEmbeddingsClientSpec from . agent_service import AgentService from . graph_rag_client import GraphRagClientSpec +from . tool_service import ToolService diff --git a/trustgraph-base/trustgraph/base/tool_service.py b/trustgraph-base/trustgraph/base/tool_service.py new file mode 100644 index 00000000..4f63bc53 --- /dev/null +++ b/trustgraph-base/trustgraph/base/tool_service.py @@ -0,0 +1,121 @@ + +""" +Tool invocation base class +""" + +import json +from prometheus_client import Counter + +from .. schema import ToolRequest, ToolResponse, Error +from .. exceptions import TooManyRequests +from .. base import FlowProcessor, ConsumerSpec, ProducerSpec + +default_concurrency = 1 + +class ToolService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + concurrency = params.get("concurrency", 1) + + super(ToolService, self).__init__(**params | { + "id": id, + "concurrency": concurrency, + }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = ToolRequest, + handler = self.on_request, + concurrency = concurrency, + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = ToolResponse + ) + ) + + if not hasattr(__class__, "tool_invocation_metric"): + __class__.tool_invocation_metric = Counter( + 'tool_invocation_count', 'Tool invocation count', + ["id", "flow", "name"], + ) + + async def on_request(self, msg, consumer, flow): + + try: + + request = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + response = await self.invoke_tool( + request.name, + json.loads(request.parameters) if request.parameters else {}, + ) + + if isinstance(response, str): + await flow("response").send( + ToolResponse( + error=None, + text=response, + object=None, + ), + properties={"id": id} + ) + else: + await flow("response").send( + ToolResponse( + error=None, + text=None, + object=json.dumps(response), + ), + properties={"id": id} + ) + + __class__.tool_invocation_metric.labels( + id = self.id, flow = flow.name, name = request.name, + ).inc() + + except TooManyRequests as e: + raise e + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + await flow.producer["response"].send( + ToolResponse( + error=Error( + type = "tool-error", + message = str(e), + ), + text=None, + object=None, + ), + properties={"id": id} + ) + + @staticmethod + def add_args(parser): + + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + + FlowProcessor.add_args(parser) + diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index a9caf950..1ed89be7 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -16,6 +16,7 @@ from .translators.document_loading import DocumentTranslator, TextDocumentTransl from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator from .translators.flow import FlowRequestTranslator, FlowResponseTranslator from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator +from .translators.tool import ToolRequestTranslator, ToolResponseTranslator from .translators.embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator @@ -88,6 +89,12 @@ TranslatorRegistry.register_service( PromptResponseTranslator() ) +TranslatorRegistry.register_service( + "tool", + ToolRequestTranslator(), + ToolResponseTranslator() +) + TranslatorRegistry.register_service( "document-embeddings-query", DocumentEmbeddingsRequestTranslator(), diff --git a/trustgraph-base/trustgraph/messaging/translators/tool.py b/trustgraph-base/trustgraph/messaging/translators/tool.py new file mode 100644 index 00000000..9f4d05cc --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/tool.py @@ -0,0 +1,51 @@ +import json +from typing import Dict, Any, Tuple +from ...schema import ToolRequest, ToolResponse +from .base import MessageTranslator + +class ToolRequestTranslator(MessageTranslator): + """Translator for ToolRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> ToolRequest: + # Handle both "name" and "parameters" input keys + name = data.get("name", "") + if "parameters" in data: + parameters = json.dumps(data["parameters"]) + else: + parameters = None + + return ToolRequest( + name = name, + parameters = parameters, + ) + + def from_pulsar(self, obj: ToolRequest) -> Dict[str, Any]: + result = {} + + if obj.name: + result["name"] = obj.name + if obj.parameters is not None: + result["parameters"] = json.loads(obj.parameters) + + return result + +class ToolResponseTranslator(MessageTranslator): + """Translator for ToolResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> ToolResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: ToolResponse) -> Dict[str, Any]: + + result = {} + + if obj.text: + result["text"] = obj.text + if obj.object: + result["object"] = json.loads(obj.object) + + return result + + def from_response_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True diff --git a/trustgraph-base/trustgraph/schema/models.py b/trustgraph-base/trustgraph/schema/models.py index ea3b9128..a3b37e4e 100644 --- a/trustgraph-base/trustgraph/schema/models.py +++ b/trustgraph-base/trustgraph/schema/models.py @@ -30,3 +30,22 @@ class EmbeddingsResponse(Record): error = Error() vectors = Array(Array(Double())) +############################################################################ + +# Tool request/response + +class ToolRequest(Record): + name = String() + + # Parameters are JSON encoded + parameters = String() + +class ToolResponse(Record): + error = Error() + + # Plain text aka "unstructured" + text = String() + + # JSON-encoded object aka "structured" + object = String() + diff --git a/trustgraph-flow/scripts/mcp-tool b/trustgraph-flow/scripts/mcp-tool new file mode 100755 index 00000000..369df360 --- /dev/null +++ b/trustgraph-flow/scripts/mcp-tool @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.agent.mcp_tool import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index 0f025894..5e8066f9 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -49,6 +49,7 @@ setuptools.setup( "langchain-community", "langchain-core", "langchain-text-splitters", + "mcp", "minio", "mistralai", "neo4j", @@ -99,6 +100,7 @@ setuptools.setup( "scripts/kg-store", "scripts/kg-manager", "scripts/librarian", + "scripts/mcp-tool", "scripts/metering", "scripts/object-extract-row", "scripts/oe-write-milvus", diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/__init__.py b/trustgraph-flow/trustgraph/agent/mcp_tool/__init__.py new file mode 100644 index 00000000..ba844705 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/__init__.py @@ -0,0 +1,3 @@ + +from . service import * + diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/__main__.py b/trustgraph-flow/trustgraph/agent/mcp_tool/__main__.py new file mode 100644 index 00000000..e9136855 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py new file mode 100755 index 00000000..b20f26b5 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -0,0 +1,105 @@ + +""" +MCP tool-calling service, calls an external MCP tool. Input is +name + parameters, output is the response, either a string or an object. +""" + +import json +from mcp.client.streamable_http import streamablehttp_client +from mcp import ClientSession + +from ... base import ToolService + +default_ident = "mcp-tool" + +class Service(ToolService): + + def __init__(self, **params): + + super(Service, self).__init__( + **params + ) + + self.register_config_handler(self.on_mcp_config) + + self.mcp_services = {} + + async def on_mcp_config(self, config, version): + + print("Got config version", version) + + if "mcp" not in config: return + + self.mcp_services = { + k: json.loads(v) + for k, v in config["mcp"].items() + } + + async def invoke_tool(self, name, parameters): + + try: + + if name not in self.mcp_services: + raise RuntimeError(f"MCP service {name} not known") + + if "url" not in self.mcp_services[name]: + raise RuntimeError(f"MCP service {name} URL not defined") + + url = self.mcp_services[name]["url"] + + if "name" in self.mcp_services[name]: + remote_name = self.mcp_services[name]["name"] + else: + remote_name = name + + print("Invoking", remote_name, "at", url, flush=True) + + # Connect to a streamable HTTP server + async with streamablehttp_client(url) as ( + read_stream, + write_stream, + _, + ): + + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + + # Initialize the connection + await session.initialize() + + # Call a tool + result = await session.call_tool( + remote_name, + parameters + ) + + if result.structuredContent: + return result.structuredContent + elif hasattr(result, "content"): + return "".join([ + x.text + for x in result.content + ]) + else: + return "No content" + + except BaseExceptionGroup as e: + + for child in e.exceptions: + print(child) + + raise e.exceptions[0] + + except Exception as e: + + print(e) + raise e + + @staticmethod + def add_args(parser): + + ToolService.add_args(parser) + +def run(): + Service.launch(default_ident, __doc__) + diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 0b5b26f1..b32a6253 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -17,7 +17,7 @@ from . document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor from . embeddings import EmbeddingsRequestor from . graph_embeddings_query import GraphEmbeddingsQueryRequestor -from . prompt import PromptRequestor +from . mcp_tool import McpToolRequestor from . text_load import TextLoad from . document_load import DocumentLoad @@ -40,6 +40,7 @@ request_response_dispatchers = { "agent": AgentRequestor, "text-completion": TextCompletionRequestor, "prompt": PromptRequestor, + "mcp-tool": McpToolRequestor, "graph-rag": GraphRagRequestor, "document-rag": DocumentRagRequestor, "embeddings": EmbeddingsRequestor, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py b/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py new file mode 100644 index 00000000..da2a7bb0 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py @@ -0,0 +1,32 @@ + +from ... schema import ToolRequest, ToolResponse +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + +class McpToolRequestor(ServiceRequestor): + def __init__( + self, pulsar_client, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(McpToolRequestor, self).__init__( + pulsar_client=pulsar_client, + request_queue=request_queue, + response_queue=response_queue, + request_schema=ToolRequest, + response_schema=ToolResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("tool") + self.response_translator = TranslatorRegistry.get_response_translator("tool") + + def to_request(self, body): + return self.request_translator.to_pulsar(body) + + def from_response(self, message): + return self.response_translator.from_response_with_completion(message) + diff --git a/trustgraph-mcp/trustgraph/mcp_version.py b/trustgraph-mcp/trustgraph/mcp_version.py deleted file mode 100644 index 6849410a..00000000 --- a/trustgraph-mcp/trustgraph/mcp_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "1.1.0"