MCP client support (#427)

- MCP client service
- Tool request/response schema
- API gateway support for mcp-tool
- Message translation for tool request & response
- Make mcp-tool use the configuration service for information
  about where the MCP services are.
This commit is contained in:
cybermaggedon 2025-07-07 23:52:23 +01:00 committed by GitHub
parent 21bee4cd83
commit e56186054a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 356 additions and 2 deletions

View file

@ -28,4 +28,5 @@ from . triples_client import TriplesClientSpec
from . document_embeddings_client import DocumentEmbeddingsClientSpec
from . agent_service import AgentService
from . graph_rag_client import GraphRagClientSpec
from . tool_service import ToolService

View file

@ -0,0 +1,121 @@
"""
Tool invocation base class
"""
import json
from prometheus_client import Counter
from .. schema import ToolRequest, ToolResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_concurrency = 1

class ToolService(FlowProcessor):
    """
    Base class for tool-invocation flow processors.

    Consumes ToolRequest messages on the "request" queue, delegates to a
    subclass-supplied invoke_tool(name, parameters) coroutine, and
    publishes the outcome on the "response" queue as a ToolResponse: a
    str result is sent as unstructured text, anything else is
    JSON-encoded into the structured "object" field.
    """

    def __init__(self, **params):

        id = params.get("id")
        # Use the module-level default rather than a duplicated literal so
        # the runtime default stays in sync with the --concurrency help.
        concurrency = params.get("concurrency", default_concurrency)

        super(ToolService, self).__init__(**params | {
            "id": id,
            "concurrency": concurrency,
        })

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = ToolRequest,
                handler = self.on_request,
                concurrency = concurrency,
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = ToolResponse
            )
        )

        # Class-level metric: created once regardless of how many service
        # instances exist in-process.
        if not hasattr(__class__, "tool_invocation_metric"):
            __class__.tool_invocation_metric = Counter(
                'tool_invocation_count', 'Tool invocation count',
                ["id", "flow", "name"],
            )

    async def on_request(self, msg, consumer, flow):
        """
        Handle one ToolRequest: invoke the tool and send a ToolResponse
        carrying either text, a JSON object, or an error.  TooManyRequests
        is re-raised so the framework can apply its retry policy.
        """

        # Sender-produced correlation ID.  Read it before any work so the
        # error path can still correlate the response; previously a failure
        # before this read left `id` unbound in the except branch.
        id = msg.properties().get("id")

        try:

            request = msg.value()

            response = await self.invoke_tool(
                request.name,
                json.loads(request.parameters) if request.parameters else {},
            )

            # A plain string is "unstructured" text; any other value is
            # JSON-encoded into the "structured" object payload.
            if isinstance(response, str):
                resp = ToolResponse(error=None, text=response, object=None)
            else:
                resp = ToolResponse(
                    error=None, text=None, object=json.dumps(response),
                )

            await flow("response").send(resp, properties={"id": id})

            __class__.tool_invocation_metric.labels(
                id = self.id, flow = flow.name, name = request.name,
            ).inc()

        except TooManyRequests:
            # Rate-limiting is recoverable: propagate for retry.
            raise

        except Exception as e:

            # Apart from rate limits, treat all exceptions as unrecoverable
            print(f"Exception: {e}")
            print("Send error response...", flush=True)

            # NOTE(review): the original used flow.producer["response"]
            # here but flow("response") on the success path; normalized to
            # the callable form used everywhere else in this handler.
            await flow("response").send(
                ToolResponse(
                    error=Error(
                        type = "tool-error",
                        message = str(e),
                    ),
                    text=None,
                    object=None,
                ),
                properties={"id": id}
            )

    @staticmethod
    def add_args(parser):
        """Register tool-service CLI arguments on an argparse parser."""

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        FlowProcessor.add_args(parser)

View file

@ -16,6 +16,7 @@ from .translators.document_loading import DocumentTranslator, TextDocumentTransl
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
from .translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
@ -88,6 +89,12 @@ TranslatorRegistry.register_service(
PromptResponseTranslator()
)
TranslatorRegistry.register_service(
"tool",
ToolRequestTranslator(),
ToolResponseTranslator()
)
TranslatorRegistry.register_service(
"document-embeddings-query",
DocumentEmbeddingsRequestTranslator(),

View file

@ -0,0 +1,51 @@
import json
from typing import Dict, Any, Tuple
from ...schema import ToolRequest, ToolResponse
from .base import MessageTranslator
class ToolRequestTranslator(MessageTranslator):
    """Converts gateway tool-call dicts to and from ToolRequest objects."""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolRequest:
        # "parameters", when present, travels as a JSON-encoded string;
        # an absent key maps to None rather than an empty object.
        encoded = (
            json.dumps(data["parameters"]) if "parameters" in data else None
        )
        return ToolRequest(
            name = data.get("name", ""),
            parameters = encoded,
        )

    def from_pulsar(self, obj: ToolRequest) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if obj.name:
            out["name"] = obj.name
        if obj.parameters is not None:
            # Decode the JSON parameter payload back to native types.
            out["parameters"] = json.loads(obj.parameters)
        return out
class ToolResponseTranslator(MessageTranslator):
    """Converts ToolResponse schema objects into gateway response dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolResponse:
        # The gateway only ever receives tool responses; constructing one
        # from a dict is unsupported.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: ToolResponse) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if obj.text:
            out["text"] = obj.text
        if obj.object:
            # "object" carries a JSON-encoded structured result.
            out["object"] = json.loads(obj.object)
        return out

    def from_response_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); tool calls are single-shot."""
        return self.from_pulsar(obj), True

View file

@ -30,3 +30,22 @@ class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double()))
############################################################################
# Tool request/response
class ToolRequest(Record):
    # Name of the tool to invoke
    name = String()
    # Parameters are JSON encoded
    parameters = String()
class ToolResponse(Record):
    # Populated when tool invocation failed; otherwise empty
    error = Error()
    # Plain text aka "unstructured"
    text = String()
    # JSON-encoded object aka "structured"
    object = String()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3

"""Launcher script for the MCP tool-calling service."""

from trustgraph.agent.mcp_tool import run

# Guard the entry point so importing this script (e.g. by tooling) does
# not start the service; matches the package __main__ convention.
if __name__ == '__main__':
    run()

View file

@ -49,6 +49,7 @@ setuptools.setup(
"langchain-community",
"langchain-core",
"langchain-text-splitters",
"mcp",
"minio",
"mistralai",
"neo4j",
@ -99,6 +100,7 @@ setuptools.setup(
"scripts/kg-store",
"scripts/kg-manager",
"scripts/librarian",
"scripts/mcp-tool",
"scripts/metering",
"scripts/object-extract-row",
"scripts/oe-write-milvus",

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . service import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,105 @@
"""
MCP tool-calling service, calls an external MCP tool. Input is
name + parameters, output is the response, either a string or an object.
"""
import json
from mcp.client.streamable_http import streamablehttp_client
from mcp import ClientSession
from ... base import ToolService
default_ident = "mcp-tool"

class Service(ToolService):
    """
    MCP tool service.  Maps a local tool name to an MCP service URL (and
    optional remote tool name) taken from the "mcp" section of the
    configuration service, then proxies the call over streamable-HTTP.
    """

    def __init__(self, **params):

        super(Service, self).__init__(
            **params
        )

        # Initialise the map before registering the handler so there is
        # no window in which a lookup sees a missing attribute.  Empty
        # until the first configuration version arrives.
        self.mcp_services = {}

        self.register_config_handler(self.on_mcp_config)

    async def on_mcp_config(self, config, version):
        """Refresh the tool -> service map from the "mcp" config group."""

        print("Got config version", version)

        if "mcp" not in config: return

        # Each config value is a JSON object describing one MCP service.
        self.mcp_services = {
            k: json.loads(v)
            for k, v in config["mcp"].items()
        }

    async def invoke_tool(self, name, parameters):
        """
        Call tool `name` with `parameters` (a dict) on its configured MCP
        service.  Returns the structured result object when the server
        provides one, otherwise the concatenated text content.  Raises
        RuntimeError when the tool is not configured.
        """

        try:

            try:
                service = self.mcp_services[name]
            except KeyError:
                raise RuntimeError(f"MCP service {name} not known")

            if "url" not in service:
                raise RuntimeError(f"MCP service {name} URL not defined")

            url = service["url"]

            # The remote tool name may differ from the local alias.
            remote_name = service.get("name", name)

            print("Invoking", remote_name, "at", url, flush=True)

            # Connect to a streamable HTTP server
            async with streamablehttp_client(url) as (
                read_stream,
                write_stream,
                _,
            ):

                # Create a session using the client streams
                async with ClientSession(read_stream, write_stream) as session:

                    # Initialize the connection
                    await session.initialize()

                    # Call a tool
                    result = await session.call_tool(
                        remote_name,
                        parameters
                    )

                    # Prefer the structured payload; getattr guards both
                    # attributes so a result lacking structuredContent (or
                    # with content=None) cannot raise here.
                    if getattr(result, "structuredContent", None):
                        return result.structuredContent

                    content = getattr(result, "content", None)
                    if content is not None:
                        return "".join(x.text for x in content)

                    return "No content"

        except BaseExceptionGroup as e:
            # anyio task groups wrap failures; log every child and
            # surface the first as the representative cause.
            for child in e.exceptions:
                print(child)
            raise e.exceptions[0]

        except Exception as e:
            print(e)
            # Bare raise preserves the original traceback.
            raise

    @staticmethod
    def add_args(parser):
        """Register CLI arguments (inherits the tool-service set)."""
        ToolService.add_args(parser)
def run():
    """Entry point: launch the MCP tool service under its default ident."""
    Service.launch(default_ident, __doc__)

View file

@ -17,7 +17,7 @@ from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . prompt import PromptRequestor
from . mcp_tool import McpToolRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -40,6 +40,7 @@ request_response_dispatchers = {
"agent": AgentRequestor,
"text-completion": TextCompletionRequestor,
"prompt": PromptRequestor,
"mcp-tool": McpToolRequestor,
"graph-rag": GraphRagRequestor,
"document-rag": DocumentRagRequestor,
"embeddings": EmbeddingsRequestor,

View file

@ -0,0 +1,32 @@
from ... schema import ToolRequest, ToolResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class McpToolRequestor(ServiceRequestor):
    """Gateway requestor bridging API tool calls to the Pulsar tool service."""

    def __init__(
        self, pulsar_client, request_queue, response_queue, timeout,
        consumer, subscriber,
    ):

        super().__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=ToolRequest,
            response_schema=ToolResponse,
            subscription = subscriber,
            consumer_name = consumer,
            timeout=timeout,
        )

        # Shared registry translators keep the JSON <-> schema mapping in
        # one place for the "tool" service.
        registry = TranslatorRegistry
        self.request_translator = registry.get_request_translator("tool")
        self.response_translator = registry.get_response_translator("tool")

    def to_request(self, body):
        """Translate an API-level dict into a ToolRequest."""
        return self.request_translator.to_pulsar(body)

    def from_response(self, message):
        """Translate a ToolResponse; returns (dict, is_final)."""
        return self.response_translator.from_response_with_completion(message)

View file

@ -1 +0,0 @@
__version__ = "1.1.0"