MCP client support (#427)

- MCP client service
- Tool request/response schema
- API gateway support for mcp-tool
- Message translation for tool request & response
- Make mcp-tool use the configuration service for information
  about where the MCP services are.
This commit is contained in:
cybermaggedon 2025-07-07 23:52:23 +01:00 committed by GitHub
parent 21bee4cd83
commit e56186054a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 356 additions and 2 deletions

View file

@ -28,4 +28,5 @@ from . triples_client import TriplesClientSpec
from . document_embeddings_client import DocumentEmbeddingsClientSpec from . document_embeddings_client import DocumentEmbeddingsClientSpec
from . agent_service import AgentService from . agent_service import AgentService
from . graph_rag_client import GraphRagClientSpec from . graph_rag_client import GraphRagClientSpec
from . tool_service import ToolService

View file

@ -0,0 +1,121 @@
"""
Tool invocation base class
"""
import json
from prometheus_client import Counter
from .. schema import ToolRequest, ToolResponse, Error
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_concurrency = 1
class ToolService(FlowProcessor):
    """
    Base class for tool-invocation services.  Consumes ToolRequest
    messages on the "request" flow, invokes the subclass-provided
    invoke_tool coroutine, and produces a ToolResponse on the
    "response" flow.  Subclasses implement invoke_tool(name, parameters).
    """

    def __init__(self, **params):

        id = params.get("id")
        # Use the module-level default rather than a duplicated literal,
        # so the default stays consistent with add_args below
        concurrency = params.get("concurrency", default_concurrency)

        super(ToolService, self).__init__(**params | {
            "id": id,
            "concurrency": concurrency,
        })

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = ToolRequest,
                handler = self.on_request,
                concurrency = concurrency,
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = ToolResponse
            )
        )

        # Class-level metric: create only once, no matter how many
        # instances or subclasses are constructed
        if not hasattr(__class__, "tool_invocation_metric"):
            __class__.tool_invocation_metric = Counter(
                'tool_invocation_count', 'Tool invocation count',
                ["id", "flow", "name"],
            )

    async def on_request(self, msg, consumer, flow):
        """
        Handle one ToolRequest message: invoke the tool and send back a
        ToolResponse carrying either text (string result), object (JSON
        result), or an error.
        """

        # Sender-produced ID.  Fetched before the try block so the error
        # path below can always reference it; previously a failure in
        # msg.value() left `id` unbound in the except handler.
        id = msg.properties()["id"]

        try:

            request = msg.value()

            response = await self.invoke_tool(
                request.name,
                # Parameters are a JSON-encoded string; absent means no args
                json.loads(request.parameters) if request.parameters else {},
            )

            if isinstance(response, str):
                # Unstructured (plain-text) result
                await flow("response").send(
                    ToolResponse(
                        error=None,
                        text=response,
                        object=None,
                    ),
                    properties={"id": id}
                )
            else:
                # Structured result, JSON-encoded for transport
                await flow("response").send(
                    ToolResponse(
                        error=None,
                        text=None,
                        object=json.dumps(response),
                    ),
                    properties={"id": id}
                )

            __class__.tool_invocation_metric.labels(
                id = self.id, flow = flow.name, name = request.name,
            ).inc()

        except TooManyRequests:
            # Rate limiting is recoverable: propagate so the framework can
            # retry/back off.  Bare raise preserves the traceback.
            raise

        except Exception as e:

            # Apart from rate limits, treat all exceptions as unrecoverable
            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            # Use the same flow("response") accessor as the success path
            # (was flow.producer["response"], an inconsistent access style)
            await flow("response").send(
                ToolResponse(
                    error=Error(
                        type = "tool-error",
                        message = str(e),
                    ),
                    text=None,
                    object=None,
                ),
                properties={"id": id}
            )

    @staticmethod
    def add_args(parser):

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        FlowProcessor.add_args(parser)

View file

@ -16,6 +16,7 @@ from .translators.document_loading import DocumentTranslator, TextDocumentTransl
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
from .translators.embeddings_query import ( from .translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
@ -88,6 +89,12 @@ TranslatorRegistry.register_service(
PromptResponseTranslator() PromptResponseTranslator()
) )
TranslatorRegistry.register_service(
"tool",
ToolRequestTranslator(),
ToolResponseTranslator()
)
TranslatorRegistry.register_service( TranslatorRegistry.register_service(
"document-embeddings-query", "document-embeddings-query",
DocumentEmbeddingsRequestTranslator(), DocumentEmbeddingsRequestTranslator(),

View file

@ -0,0 +1,51 @@
import json
from typing import Dict, Any, Tuple
from ...schema import ToolRequest, ToolResponse
from .base import MessageTranslator
class ToolRequestTranslator(MessageTranslator):
    """Translates ToolRequest schema objects to and from dict form"""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolRequest:
        # Handle both "name" and "parameters" input keys; parameters,
        # when supplied, are serialised to a JSON string for transport
        params = (
            json.dumps(data["parameters"])
            if "parameters" in data
            else None
        )

        return ToolRequest(
            name = data.get("name", ""),
            parameters = params,
        )

    def from_pulsar(self, obj: ToolRequest) -> Dict[str, Any]:
        # Omit empty name and absent parameters from the output dict
        out: Dict[str, Any] = {}

        if obj.name:
            out["name"] = obj.name

        if obj.parameters is not None:
            out["parameters"] = json.loads(obj.parameters)

        return out
class ToolResponseTranslator(MessageTranslator):
    """Translates ToolResponse schema objects to and from dict form"""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolResponse:
        # Responses flow out of Pulsar only; inbound conversion unsupported
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: ToolResponse) -> Dict[str, Any]:
        out: Dict[str, Any] = {}

        if obj.text:
            out["text"] = obj.text

        if obj.object:
            # "object" carries a JSON-encoded structured result
            out["object"] = json.loads(obj.object)

        return out

    def from_response_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); tool responses are single-shot, so always final."""
        return self.from_pulsar(obj), True

View file

@ -30,3 +30,22 @@ class EmbeddingsResponse(Record):
error = Error() error = Error()
vectors = Array(Array(Double())) vectors = Array(Array(Double()))
############################################################################
# Tool request/response
class ToolRequest(Record):
    # Name of the tool to invoke
    name = String()
    # Parameters are JSON encoded
    parameters = String()

class ToolResponse(Record):
    # Populated on failure; None on success (see ToolService error path)
    error = Error()
    # Plain text aka "unstructured"
    text = String()
    # JSON-encoded object aka "structured"
    object = String()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3

# Launcher script for the MCP tool service.
from trustgraph.agent.mcp_tool import run

# Guard the entry point so importing this script has no side effects,
# consistent with the package's __main__ module.
if __name__ == '__main__':
    run()

View file

@ -49,6 +49,7 @@ setuptools.setup(
"langchain-community", "langchain-community",
"langchain-core", "langchain-core",
"langchain-text-splitters", "langchain-text-splitters",
"mcp",
"minio", "minio",
"mistralai", "mistralai",
"neo4j", "neo4j",
@ -99,6 +100,7 @@ setuptools.setup(
"scripts/kg-store", "scripts/kg-store",
"scripts/kg-manager", "scripts/kg-manager",
"scripts/librarian", "scripts/librarian",
"scripts/mcp-tool",
"scripts/metering", "scripts/metering",
"scripts/object-extract-row", "scripts/object-extract-row",
"scripts/oe-write-milvus", "scripts/oe-write-milvus",

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Package entry point: `python -m ...` runs the MCP tool service.
from . service import run

if __name__ == '__main__':
    run()

View file

@ -0,0 +1,105 @@
"""
MCP tool-calling service, calls an external MCP tool. Input is
name + parameters, output is the response, either a string or an object.
"""
import json
from mcp.client.streamable_http import streamablehttp_client
from mcp import ClientSession
from ... base import ToolService
default_ident = "mcp-tool"
class Service(ToolService):
    """
    MCP tool-calling ToolService.  Tool definitions arrive via the
    configuration service under the "mcp" key; each definition is a
    JSON object carrying at least a "url" and optionally a "name"
    (the remote tool name, when it differs from the local alias).
    """

    def __init__(self, **params):

        super(Service, self).__init__(
            **params
        )

        # Initialise before registering the handler so a handler firing
        # during registration cannot be clobbered by this assignment
        self.mcp_services = {}

        self.register_config_handler(self.on_mcp_config)

    async def on_mcp_config(self, config, version):
        """Replace the service table from the "mcp" configuration group."""

        print("Got config version", version)

        if "mcp" not in config: return

        # Values are JSON-encoded definition objects
        self.mcp_services = {
            k: json.loads(v)
            for k, v in config["mcp"].items()
        }

    async def invoke_tool(self, name, parameters):
        """
        Invoke tool `name` with `parameters` (a dict) on its configured
        MCP endpoint.  Returns the structured content (dict) when the
        server provides it, otherwise the concatenated text content.
        Raises RuntimeError for unknown/misconfigured tools.
        """

        try:

            # Hoist the definition lookup; previously repeated per access
            svc = self.mcp_services.get(name)

            if svc is None:
                raise RuntimeError(f"MCP service {name} not known")

            if "url" not in svc:
                raise RuntimeError(f"MCP service {name} URL not defined")

            url = svc["url"]

            # The remote tool name may differ from the local alias
            remote_name = svc.get("name", name)

            print("Invoking", remote_name, "at", url, flush=True)

            # Connect to a streamable HTTP server
            async with streamablehttp_client(url) as (
                read_stream,
                write_stream,
                _,
            ):

                # Create a session using the client streams
                async with ClientSession(read_stream, write_stream) as session:

                    # Initialize the connection
                    await session.initialize()

                    # Call a tool
                    result = await session.call_tool(
                        remote_name,
                        parameters
                    )

                    # Prefer the structured (JSON) content when present
                    if result.structuredContent:
                        return result.structuredContent
                    elif hasattr(result, "content"):
                        # Concatenate the text parts of unstructured content
                        return "".join(
                            x.text
                            for x in result.content
                        )
                    else:
                        return "No content"

        except BaseExceptionGroup as e:
            # The anyio-based MCP client raises exception groups; log all
            # children and surface the first as the representative error
            for child in e.exceptions:
                print(child)
            raise e.exceptions[0]

        except Exception as e:
            print(e)
            # Bare raise preserves the original traceback
            raise

    @staticmethod
    def add_args(parser):
        ToolService.add_args(parser)
def run():
    # Entry point: launch the service under its default identity, using
    # the module docstring as the description text
    Service.launch(default_ident, __doc__)

View file

@ -17,7 +17,7 @@ from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor from . triples_query import TriplesQueryRequestor
from . embeddings import EmbeddingsRequestor from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . prompt import PromptRequestor from . mcp_tool import McpToolRequestor
from . text_load import TextLoad from . text_load import TextLoad
from . document_load import DocumentLoad from . document_load import DocumentLoad
@ -40,6 +40,7 @@ request_response_dispatchers = {
"agent": AgentRequestor, "agent": AgentRequestor,
"text-completion": TextCompletionRequestor, "text-completion": TextCompletionRequestor,
"prompt": PromptRequestor, "prompt": PromptRequestor,
"mcp-tool": McpToolRequestor,
"graph-rag": GraphRagRequestor, "graph-rag": GraphRagRequestor,
"document-rag": DocumentRagRequestor, "document-rag": DocumentRagRequestor,
"embeddings": EmbeddingsRequestor, "embeddings": EmbeddingsRequestor,

View file

@ -0,0 +1,32 @@
from ... schema import ToolRequest, ToolResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class McpToolRequestor(ServiceRequestor):
    """Gateway requestor bridging API tool calls to the Pulsar tool service."""

    def __init__(
        self, pulsar_client, request_queue, response_queue, timeout,
        consumer, subscriber,
    ):

        super(McpToolRequestor, self).__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=ToolRequest,
            response_schema=ToolResponse,
            subscription = subscriber,
            consumer_name = consumer,
            timeout=timeout,
        )

        # Shared "tool" translators perform dict <-> schema conversion
        registry = TranslatorRegistry
        self.request_translator = registry.get_request_translator("tool")
        self.response_translator = registry.get_response_translator("tool")

    def to_request(self, body):
        # body is the caller-supplied dict form of a tool request
        return self.request_translator.to_pulsar(body)

    def from_response(self, message):
        # Returns (response_dict, is_final)
        return self.response_translator.from_response_with_completion(message)

View file

@ -1 +0,0 @@
__version__ = "1.1.0"