mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 00:46:22 +02:00
MCP client support (#427)
- MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool use the configuration service for information about where the MCP services are.
This commit is contained in:
parent
21bee4cd83
commit
e56186054a
13 changed files with 356 additions and 2 deletions
|
|
@ -28,4 +28,5 @@ from . triples_client import TriplesClientSpec
|
|||
from . document_embeddings_client import DocumentEmbeddingsClientSpec
|
||||
from . agent_service import AgentService
|
||||
from . graph_rag_client import GraphRagClientSpec
|
||||
from . tool_service import ToolService
|
||||
|
||||
|
|
|
|||
121
trustgraph-base/trustgraph/base/tool_service.py
Normal file
121
trustgraph-base/trustgraph/base/tool_service.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
|
||||
"""
|
||||
Tool invocation base class
|
||||
"""
|
||||
|
||||
import json
|
||||
from prometheus_client import Counter
|
||||
|
||||
from .. schema import ToolRequest, ToolResponse, Error
|
||||
from .. exceptions import TooManyRequests
|
||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
|
||||
default_concurrency = 1
|
||||
|
||||
class ToolService(FlowProcessor):
    """
    Base class for tool-invocation services.

    Consumes ToolRequest messages on the "request" queue, dispatches them to
    the subclass-provided ``invoke_tool`` coroutine, and publishes a
    ToolResponse on the "response" queue.  Subclasses must implement
    ``async def invoke_tool(self, name, parameters)`` returning either a
    plain string (unstructured text) or a JSON-serialisable object.
    """

    def __init__(self, **params):

        id = params.get("id")
        # Use the module-level default rather than a hard-coded 1 so the
        # constructor default and the --concurrency CLI default agree.
        concurrency = params.get("concurrency", default_concurrency)

        super(ToolService, self).__init__(**params | {
            "id": id,
            "concurrency": concurrency,
        })

        # Incoming tool requests are handled by on_request.
        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = ToolRequest,
                handler = self.on_request,
                concurrency = concurrency,
            )
        )

        # Outgoing tool responses.
        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = ToolResponse
            )
        )

        # Class-level metric, created once even if multiple instances are
        # constructed (Prometheus forbids re-registering the same counter).
        if not hasattr(__class__, "tool_invocation_metric"):
            __class__.tool_invocation_metric = Counter(
                'tool_invocation_count', 'Tool invocation count',
                ["id", "flow", "name"],
            )

    async def on_request(self, msg, consumer, flow):
        """
        Handle one ToolRequest message: invoke the tool and send back a
        ToolResponse carrying text (string result), object (JSON-encoded
        non-string result), or an error.
        """

        # Sender-produced correlation ID.  Extracted BEFORE the try block so
        # that the error path below can still correlate the response even when
        # decoding the payload raises; previously a missing/failed lookup left
        # `id` unbound and the handler died with a NameError instead of
        # sending an error response.
        id = msg.properties().get("id")

        try:

            request = msg.value()

            response = await self.invoke_tool(
                request.name,
                # Parameters travel as a JSON-encoded string; absent means {}.
                json.loads(request.parameters) if request.parameters else {},
            )

            # String results are "unstructured" text; anything else is
            # JSON-encoded into the structured `object` field.
            if isinstance(response, str):
                reply = ToolResponse(error=None, text=response, object=None)
            else:
                reply = ToolResponse(
                    error=None, text=None, object=json.dumps(response),
                )

            await flow("response").send(reply, properties={"id": id})

            __class__.tool_invocation_metric.labels(
                id = self.id, flow = flow.name, name = request.name,
            ).inc()

        except TooManyRequests:
            # Rate limiting is recoverable: propagate so the consumer retries.
            raise

        except Exception as e:

            # Apart from rate limits, treat all exceptions as unrecoverable

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            # Use the same flow("response") accessor as the success path
            # (previously flow.producer["response"], inconsistently).
            await flow("response").send(
                ToolResponse(
                    error=Error(
                        type = "tool-error",
                        message = str(e),
                    ),
                    text=None,
                    object=None,
                ),
                properties={"id": id}
            )

    @staticmethod
    def add_args(parser):
        """Register tool-service CLI arguments on an argparse parser."""

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        FlowProcessor.add_args(parser)
||||
|
||||
|
|
@ -16,6 +16,7 @@ from .translators.document_loading import DocumentTranslator, TextDocumentTransl
|
|||
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
|
||||
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
|
||||
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||
from .translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
|
|
@ -88,6 +89,12 @@ TranslatorRegistry.register_service(
|
|||
PromptResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"tool",
|
||||
ToolRequestTranslator(),
|
||||
ToolResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"document-embeddings-query",
|
||||
DocumentEmbeddingsRequestTranslator(),
|
||||
|
|
|
|||
51
trustgraph-base/trustgraph/messaging/translators/tool.py
Normal file
51
trustgraph-base/trustgraph/messaging/translators/tool.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
import json
|
||||
from typing import Dict, Any, Tuple
|
||||
from ...schema import ToolRequest, ToolResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
class ToolRequestTranslator(MessageTranslator):
    """Translator for ToolRequest schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolRequest:
        # Parameters, when supplied, are serialised to a JSON string for the
        # wire format; a missing key maps to None rather than "{}".
        encoded = (
            json.dumps(data["parameters"]) if "parameters" in data else None
        )

        return ToolRequest(
            name = data.get("name", ""),
            parameters = encoded,
        )

    def from_pulsar(self, obj: ToolRequest) -> Dict[str, Any]:
        # Only emit keys that carry information: a falsy name is dropped,
        # and parameters are decoded back from their JSON string form.
        out: Dict[str, Any] = {}

        if obj.name:
            out["name"] = obj.name

        if obj.parameters is not None:
            out["parameters"] = json.loads(obj.parameters)

        return out
|
||||
|
||||
class ToolResponseTranslator(MessageTranslator):
    """Translator for ToolResponse schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> ToolResponse:
        # Responses originate on the service side; clients never build them.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: ToolResponse) -> Dict[str, Any]:
        # Include only the populated result field: "text" for unstructured
        # results, "object" (JSON-decoded) for structured ones.
        out: Dict[str, Any] = {}

        if obj.text:
            out["text"] = obj.text

        if obj.object:
            out["object"] = json.loads(obj.object)

        return out

    def from_response_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final)"""
        # Tool responses are single-shot, so every response is final.
        return self.from_pulsar(obj), True
|
||||
|
|
@ -30,3 +30,22 @@ class EmbeddingsResponse(Record):
|
|||
error = Error()
|
||||
vectors = Array(Array(Double()))
|
||||
|
||||
############################################################################
|
||||
|
||||
# Tool request/response
|
||||
|
||||
# Request to invoke a named tool.
class ToolRequest(Record):
    # Tool identifier to invoke.
    name = String()

    # Parameters are JSON encoded
    parameters = String()
|
||||
|
||||
# Result of a tool invocation; exactly one of text/object is expected to be
# populated on success, and error on failure.
class ToolResponse(Record):
    # Error details when the invocation failed.
    error = Error()

    # Plain text aka "unstructured"
    text = String()

    # JSON-encoded object aka "structured"
    # NOTE(review): field name shadows the `object` builtin; kept as-is since
    # it is part of the wire schema.
    object = String()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue