mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 09:26:22 +02:00
MCP client support (#427)
- MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool use the configuration service for information about where the MCP services are.
This commit is contained in:
parent
21bee4cd83
commit
e56186054a
13 changed files with 356 additions and 2 deletions
|
|
@ -28,4 +28,5 @@ from . triples_client import TriplesClientSpec
|
||||||
from . document_embeddings_client import DocumentEmbeddingsClientSpec
|
from . document_embeddings_client import DocumentEmbeddingsClientSpec
|
||||||
from . agent_service import AgentService
|
from . agent_service import AgentService
|
||||||
from . graph_rag_client import GraphRagClientSpec
|
from . graph_rag_client import GraphRagClientSpec
|
||||||
|
from . tool_service import ToolService
|
||||||
|
|
||||||
|
|
|
||||||
121
trustgraph-base/trustgraph/base/tool_service.py
Normal file
121
trustgraph-base/trustgraph/base/tool_service.py
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tool invocation base class
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
from .. schema import ToolRequest, ToolResponse, Error
|
||||||
|
from .. exceptions import TooManyRequests
|
||||||
|
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
|
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
|
class ToolService(FlowProcessor):
    """
    Tool invocation base class.

    Consumes ToolRequest messages on the "request" flow, invokes the
    subclass-provided invoke_tool() coroutine, and publishes a
    ToolResponse on the "response" flow.  Subclasses must implement:

        async def invoke_tool(self, name, parameters)

    A str return value is sent as unstructured text; any other value is
    JSON-encoded and sent as a structured object.
    """

    def __init__(self, **params):

        id = params.get("id")
        concurrency = params.get("concurrency", 1)

        super(ToolService, self).__init__(**params | {
            "id": id,
            "concurrency": concurrency,
        })

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = ToolRequest,
                handler = self.on_request,
                concurrency = concurrency,
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = ToolResponse
            )
        )

        # Class-level metric: create once no matter how many instances
        # are constructed
        if not hasattr(__class__, "tool_invocation_metric"):
            __class__.tool_invocation_metric = Counter(
                'tool_invocation_count', 'Tool invocation count',
                ["id", "flow", "name"],
            )

    async def on_request(self, msg, consumer, flow):
        """Handle one ToolRequest; always answers the sender with a
        ToolResponse carrying either a result or an error."""

        # Sender-produced correlation ID.  Read it before the try block
        # so the error path below never references an unbound name
        # (the original read it inside try, so a failure before that
        # line made the handler itself raise UnboundLocalError).
        id = msg.properties()["id"]

        try:

            request = msg.value()

            response = await self.invoke_tool(
                request.name,
                json.loads(request.parameters) if request.parameters else {},
            )

            if isinstance(response, str):
                # Unstructured (plain-text) result
                await flow("response").send(
                    ToolResponse(
                        error=None,
                        text=response,
                        object=None,
                    ),
                    properties={"id": id}
                )
            else:
                # Structured result, JSON-encoded for transport
                await flow("response").send(
                    ToolResponse(
                        error=None,
                        text=None,
                        object=json.dumps(response),
                    ),
                    properties={"id": id}
                )

            __class__.tool_invocation_metric.labels(
                id = self.id, flow = flow.name, name = request.name,
            ).inc()

        except TooManyRequests as e:
            # Rate limiting is recoverable: re-raise so the framework
            # can retry the message
            raise e

        except Exception as e:

            # Apart from rate limits, treat all exceptions as
            # unrecoverable: report the failure back to the requestor

            print(f"Exception: {e}", flush=True)
            print("Send error response...", flush=True)

            # Use the same flow("response") accessor as the success
            # path (original used flow.producer["response"], which was
            # inconsistent with the rest of this handler)
            await flow("response").send(
                ToolResponse(
                    error=Error(
                        type = "tool-error",
                        message = str(e),
                    ),
                    text=None,
                    object=None,
                ),
                properties={"id": id}
            )

    @staticmethod
    def add_args(parser):
        """Register ToolService command-line arguments on parser."""

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
|
|
@ -16,6 +16,7 @@ from .translators.document_loading import DocumentTranslator, TextDocumentTransl
|
||||||
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
|
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
|
||||||
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
|
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
|
||||||
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
|
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||||
|
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||||
from .translators.embeddings_query import (
|
from .translators.embeddings_query import (
|
||||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||||
|
|
@ -88,6 +89,12 @@ TranslatorRegistry.register_service(
|
||||||
PromptResponseTranslator()
|
PromptResponseTranslator()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"tool",
|
||||||
|
ToolRequestTranslator(),
|
||||||
|
ToolResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
TranslatorRegistry.register_service(
|
TranslatorRegistry.register_service(
|
||||||
"document-embeddings-query",
|
"document-embeddings-query",
|
||||||
DocumentEmbeddingsRequestTranslator(),
|
DocumentEmbeddingsRequestTranslator(),
|
||||||
|
|
|
||||||
51
trustgraph-base/trustgraph/messaging/translators/tool.py
Normal file
51
trustgraph-base/trustgraph/messaging/translators/tool.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import ToolRequest, ToolResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
class ToolRequestTranslator(MessageTranslator):
    """Translator for ToolRequest schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolRequest:
        # Parameters travel over Pulsar as a JSON-encoded string; an
        # absent "parameters" key is encoded as None, not "{}"
        if "parameters" in data:
            encoded = json.dumps(data["parameters"])
        else:
            encoded = None

        return ToolRequest(
            name = data.get("name", ""),
            parameters = encoded,
        )

    def from_pulsar(self, obj: ToolRequest) -> Dict[str, Any]:
        # Build a plain dict, omitting empty/absent fields entirely
        out: Dict[str, Any] = {}

        if obj.name:
            out["name"] = obj.name

        if obj.parameters is not None:
            out["parameters"] = json.loads(obj.parameters)

        return out
|
||||||
|
|
||||||
|
class ToolResponseTranslator(MessageTranslator):
    """Translator for ToolResponse schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolResponse:
        # Gateway only consumes tool responses; producing one is not a
        # supported direction
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: ToolResponse) -> Dict[str, Any]:
        # Exactly one of text/object is expected to be populated;
        # omit whichever is empty
        out: Dict[str, Any] = {}

        if obj.text:
            out["text"] = obj.text

        if obj.object:
            out["object"] = json.loads(obj.object)

        return out

    def from_response_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final)"""
        # Tool responses are single-shot, so every response is final
        return self.from_pulsar(obj), True
|
||||||
|
|
@ -30,3 +30,22 @@ class EmbeddingsResponse(Record):
|
||||||
error = Error()
|
error = Error()
|
||||||
vectors = Array(Array(Double()))
|
vectors = Array(Array(Double()))
|
||||||
|
|
||||||
|
############################################################################

# Tool request/response

# Request to invoke a named tool
class ToolRequest(Record):
    # Tool name; the tool service resolves this against configuration
    name = String()

    # Parameters are JSON encoded
    parameters = String()

# Result of a tool invocation; on success one of text/object is set,
# on failure error is set
class ToolResponse(Record):
    error = Error()

    # Plain text aka "unstructured"
    text = String()

    # JSON-encoded object aka "structured"
    object = String()
|
||||||
|
|
||||||
|
|
|
||||||
6
trustgraph-flow/scripts/mcp-tool
Executable file
6
trustgraph-flow/scripts/mcp-tool
Executable file
|
|
@ -0,0 +1,6 @@
|
||||||
|
#!/usr/bin/env python3

# Entry-point script for the mcp-tool service; all logic lives in the
# trustgraph.agent.mcp_tool package
from trustgraph.agent.mcp_tool import run

run()
|
||||||
|
|
||||||
|
|
@ -49,6 +49,7 @@ setuptools.setup(
|
||||||
"langchain-community",
|
"langchain-community",
|
||||||
"langchain-core",
|
"langchain-core",
|
||||||
"langchain-text-splitters",
|
"langchain-text-splitters",
|
||||||
|
"mcp",
|
||||||
"minio",
|
"minio",
|
||||||
"mistralai",
|
"mistralai",
|
||||||
"neo4j",
|
"neo4j",
|
||||||
|
|
@ -99,6 +100,7 @@ setuptools.setup(
|
||||||
"scripts/kg-store",
|
"scripts/kg-store",
|
||||||
"scripts/kg-manager",
|
"scripts/kg-manager",
|
||||||
"scripts/librarian",
|
"scripts/librarian",
|
||||||
|
"scripts/mcp-tool",
|
||||||
"scripts/metering",
|
"scripts/metering",
|
||||||
"scripts/object-extract-row",
|
"scripts/object-extract-row",
|
||||||
"scripts/oe-write-milvus",
|
"scripts/oe-write-milvus",
|
||||||
|
|
|
||||||
3
trustgraph-flow/trustgraph/agent/mcp_tool/__init__.py
Normal file
3
trustgraph-flow/trustgraph/agent/mcp_tool/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
|
||||||
|
from . service import *
|
||||||
|
|
||||||
7
trustgraph-flow/trustgraph/agent/mcp_tool/__main__.py
Normal file
7
trustgraph-flow/trustgraph/agent/mcp_tool/__main__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
#!/usr/bin/env python3

# Allows `python -m trustgraph.agent.mcp_tool` to launch the service
from . service import run

if __name__ == '__main__':
    run()
|
||||||
|
|
||||||
105
trustgraph-flow/trustgraph/agent/mcp_tool/service.py
Executable file
105
trustgraph-flow/trustgraph/agent/mcp_tool/service.py
Executable file
|
|
@ -0,0 +1,105 @@
|
||||||
|
|
||||||
|
"""
|
||||||
|
MCP tool-calling service, calls an external MCP tool. Input is
|
||||||
|
name + parameters, output is the response, either a string or an object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
from mcp import ClientSession
|
||||||
|
|
||||||
|
from ... base import ToolService
|
||||||
|
|
||||||
|
default_ident = "mcp-tool"
|
||||||
|
|
||||||
|
class Service(ToolService):
    """
    MCP tool invocation service.  Looks up the target MCP server in the
    "mcp" configuration group, connects over streamable HTTP, and calls
    the (possibly remapped) remote tool.
    """

    def __init__(self, **params):

        super(Service, self).__init__(
            **params
        )

        self.register_config_handler(self.on_mcp_config)

        # Local tool name -> descriptor dict; descriptors carry "url"
        # and optionally "name" (remote tool name override)
        self.mcp_services = {}

    async def on_mcp_config(self, config, version):
        """Refresh the MCP service table from the "mcp" config group."""

        print("Got config version", version)

        if "mcp" not in config: return

        # Config values are JSON-encoded service descriptors
        self.mcp_services = {
            k: json.loads(v)
            for k, v in config["mcp"].items()
        }

    async def invoke_tool(self, name, parameters):
        """Invoke tool `name` with dict `parameters` on its MCP server.

        Returns the server's structured content (an object) when
        present, otherwise the concatenated text content.
        Raises RuntimeError for unknown/misconfigured services.
        """

        try:

            if name not in self.mcp_services:
                raise RuntimeError(f"MCP service {name} not known")

            if "url" not in self.mcp_services[name]:
                raise RuntimeError(f"MCP service {name} URL not defined")

            url = self.mcp_services[name]["url"]

            # The remote tool name may be aliased; default to the
            # local name (replaces the original if/else with .get)
            remote_name = self.mcp_services[name].get("name", name)

            print("Invoking", remote_name, "at", url, flush=True)

            # Connect to a streamable HTTP server
            async with streamablehttp_client(url) as (
                read_stream,
                write_stream,
                _,
            ):

                # Create a session using the client streams
                async with ClientSession(read_stream, write_stream) as session:

                    # Initialize the connection
                    await session.initialize()

                    # Call a tool
                    result = await session.call_tool(
                        remote_name,
                        parameters
                    )

                    # Prefer structured content.  Use getattr so result
                    # objects lacking the attribute don't raise here —
                    # the original guarded "content" with hasattr but
                    # accessed structuredContent unguarded.
                    if getattr(result, "structuredContent", None):
                        return result.structuredContent
                    elif hasattr(result, "content"):
                        return "".join([
                            x.text
                            for x in result.content
                        ])
                    else:
                        return "No content"

        except BaseExceptionGroup as e:

            # anyio-based transports raise exception groups; log all
            # children and surface the first as the representative
            # failure
            for child in e.exceptions:
                print(child)

            raise e.exceptions[0]

        except Exception as e:

            print(e)
            raise e

    @staticmethod
    def add_args(parser):
        """Register command-line arguments (delegates to ToolService)."""

        ToolService.add_args(parser)
|
||||||
|
|
||||||
|
def run():
    """Script entry point: launch the MCP tool service."""
    Service.launch(default_ident, __doc__)
|
||||||
|
|
||||||
|
|
@ -17,7 +17,7 @@ from . document_rag import DocumentRagRequestor
|
||||||
from . triples_query import TriplesQueryRequestor
|
from . triples_query import TriplesQueryRequestor
|
||||||
from . embeddings import EmbeddingsRequestor
|
from . embeddings import EmbeddingsRequestor
|
||||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||||
from . prompt import PromptRequestor
|
from . mcp_tool import McpToolRequestor
|
||||||
from . text_load import TextLoad
|
from . text_load import TextLoad
|
||||||
from . document_load import DocumentLoad
|
from . document_load import DocumentLoad
|
||||||
|
|
||||||
|
|
@ -40,6 +40,7 @@ request_response_dispatchers = {
|
||||||
"agent": AgentRequestor,
|
"agent": AgentRequestor,
|
||||||
"text-completion": TextCompletionRequestor,
|
"text-completion": TextCompletionRequestor,
|
||||||
"prompt": PromptRequestor,
|
"prompt": PromptRequestor,
|
||||||
|
"mcp-tool": McpToolRequestor,
|
||||||
"graph-rag": GraphRagRequestor,
|
"graph-rag": GraphRagRequestor,
|
||||||
"document-rag": DocumentRagRequestor,
|
"document-rag": DocumentRagRequestor,
|
||||||
"embeddings": EmbeddingsRequestor,
|
"embeddings": EmbeddingsRequestor,
|
||||||
|
|
|
||||||
32
trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py
Normal file
32
trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
|
||||||
|
from ... schema import ToolRequest, ToolResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
class McpToolRequestor(ServiceRequestor):
    """Gateway requestor bridging mcp-tool API calls onto the Pulsar
    ToolRequest/ToolResponse service."""

    def __init__(
            self, pulsar_client, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):

        super(McpToolRequestor, self).__init__(
            pulsar_client=pulsar_client,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=ToolRequest,
            response_schema=ToolResponse,
            subscription = subscriber,
            consumer_name = consumer,
            timeout=timeout,
        )

        # Reuse the registered "tool" translators for both directions
        self.request_translator = TranslatorRegistry.get_request_translator("tool")
        self.response_translator = TranslatorRegistry.get_response_translator("tool")

    def to_request(self, body):
        # JSON body dict -> ToolRequest
        return self.request_translator.to_pulsar(body)

    def from_response(self, message):
        # ToolResponse -> (dict, is_final)
        return self.response_translator.from_response_with_completion(message)
|
||||||
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
__version__ = "1.1.0"
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue