From 5420a20d2901aea0efdebb7de4dbb048c4f316e1 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 10 Jun 2026 14:10:43 +0100 Subject: [PATCH] feat: per-caller Bearer token auth and new query tools for MCP server (#984) Replace the broken GATEWAY_SECRET auth (token was sent as a query parameter, silently ignored by the gateway) with end-to-end Bearer token forwarding. Each MCP caller gets a dedicated WebSocket authenticated via the gateway's in-band first-frame protocol, with whoami verification on first connect. Also fix and extend the tool surface: - embeddings: accept list of texts (was single string) - triples_query: use Term wire format with compact keys (was legacy Value format), add collection and graph parameters - sparql_query: new tool for SPARQL SELECT/ASK/CONSTRUCT/DESCRIBE - graphql_query: new tool for structured data (rows) GraphQL queries - all tools: add optional workspace parameter --- trustgraph-mcp/trustgraph/mcp_server/mcp.py | 1872 ++++++++--------- .../trustgraph/mcp_server/tg_socket.py | 168 +- 2 files changed, 1044 insertions(+), 996 deletions(-) diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index 7378db64..11b975b2 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -8,71 +8,180 @@ import logging import json import uuid import argparse -from dataclasses import dataclass +from dataclasses import dataclass, field from collections.abc import AsyncIterator from functools import partial from mcp.server.fastmcp import FastMCP, Context -from mcp.types import TextContent -from websockets.asyncio.client import connect +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.server.auth.middleware.auth_context import get_access_token from trustgraph.base.logging import add_logging_args, setup_logging -from . tg_socket import WebSocketManager +from . tg_socket import WebSocketManager, _token_key + +logger = logging.getLogger(__name__) + + +# Wire-format Term type codes (match TermTranslator compact keys) +_TERM_TYPES = { + "iri": "i", + "literal": "l", + "blank": "b", +} + + +def _make_term(value: str, term_type: str) -> dict: + """Build a compact-key Term dict for the gateway wire format. + + Args: + value: The term value (IRI string, literal text, or blank node id). + term_type: One of "iri", "literal", "blank". + """ + t = _TERM_TYPES.get(term_type) + if t is None: + raise ValueError( + f"Unknown term type '{term_type}' — " + f"expected one of: {', '.join(_TERM_TYPES)}" + ) + + if t == "i": + return {"t": t, "i": value} + elif t == "l": + return {"t": t, "v": value} + elif t == "b": + return {"t": t, "d": value} + return {"t": t} + +# ── Security boundary: MCP client → MCP server ── +# The MCP client authenticates to this server via a Bearer token in the +# HTTP Authorization header. The SDK's auth middleware extracts and +# verifies the token before any tool handler runs. +# +# We implement a pass-through TokenVerifier: the gateway is the real +# authority, so we accept any non-empty Bearer token here and forward +# it to the gateway for validation. The gateway's in-band auth +# protocol and IAM regime decide whether the token is valid. +# +# This means an invalid token will connect to the MCP server but will +# fail when the first WebSocket auth frame is sent to the gateway. +# That is intentional — the gateway is the single source of truth. + + +class PassthroughTokenVerifier(TokenVerifier): + """Accept any non-empty Bearer token and forward it downstream. + + The TrustGraph gateway is the authority for token validation, not + this MCP server. We store the raw token in the AccessToken so that + tool handlers can retrieve it via ``get_access_token().token`` and + forward it to the gateway. + """ + + async def verify_token(self, token: str) -> AccessToken | None: + if not token: + return None + return AccessToken( + token=token, + client_id="mcp-caller", + scopes=[], + ) + @dataclass class AppContext: - sockets: dict[str, WebSocketManager] - websocket_url: str - gateway_token: str + sockets: dict[str, WebSocketManager] = field(default_factory=dict) + websocket_url: str = "" + @asynccontextmanager -async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]: +async def app_lifespan( + server: FastMCP, + websocket_url: str = "ws://api-gateway:8088/api/v1/socket", +) -> AsyncIterator[AppContext]: + """Manage per-server state: the pool of per-caller WebSocket + connections to the gateway.""" - """ - Manage application lifecycle with type-safe context - """ - - # Initialize on startup - sockets = {} + sockets: dict[str, WebSocketManager] = {} try: - yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token) + yield AppContext(sockets=sockets, websocket_url=websocket_url) finally: - # Cleanup on shutdown - logging.info("Shutting down context") + logger.info("Shutting down — closing %d WebSocket(s)", len(sockets)) - for k, manager in sockets.items(): - logging.info(f"Closing socket for {k}") - await manager.stop() + for key, manager in sockets.items(): + try: + await manager.stop() + except Exception as e: + logger.warning("Error closing socket %s: %s", key, e) - logging.info("Shutdown complete") + logger.info("Shutdown complete") -async def get_socket_manager(ctx): + +def _require_token() -> str: + """Extract the caller's Bearer token from the MCP auth context. + + Raises RuntimeError if no token is present (the caller did not + authenticate). + """ + # ── Security boundary: token extraction ── + # get_access_token() reads the contextvar set by the SDK's + # AuthContextMiddleware. The token was placed there by + # PassthroughTokenVerifier.verify_token() and is the raw Bearer + # value from the MCP client's Authorization header. + access = get_access_token() + if access is None or not access.token: + raise RuntimeError( + "Authentication required — send a Bearer token in the " + "Authorization header" + ) + return access.token + + +async def get_socket_manager(ctx, token): + """Return (or create) an authenticated WebSocket for this token. + + Each unique token gets its own WebSocket connection so that + gateway-side identity, workspace binding, and capability scoping + are preserved per caller. + """ lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url - gateway_token = lifespan_context.gateway_token - if "default" in sockets: - logging.info("Return existing socket manager") - return sockets["default"] + key = _token_key(token) - logging.info(f"Opening socket to {websocket_url}...") + if key in sockets: + manager = sockets[key] + if manager.socket is not None: + return manager + # Socket was closed (e.g. server-side timeout) — reconnect. + del sockets[key] - # Create manager with empty pending requests - manager = WebSocketManager(websocket_url, token=gateway_token) + logger.info("Opening authenticated WebSocket to %s …", websocket_url) - # Start reader task with the proper manager + manager = WebSocketManager(websocket_url, token=token) await manager.start() - sockets["default"] = manager + # Verify the token is valid by calling whoami. This confirms the + # gateway accepted the token and gives us the caller's identity. + try: + identity = await manager.whoami() + logger.info( + "WebSocket ready — caller: %s", + identity.get("handle", "unknown"), + ) + except Exception as e: + await manager.stop() + raise RuntimeError( + f"Token rejected by gateway (whoami failed): {e}" + ) from e - logging.info("Return new socket manager") + sockets[key] = manager return manager + @dataclass class EmbeddingsResponse: vectors: List[List[float]] @@ -182,10 +291,23 @@ class PutConfigResponse: class DeleteConfigResponse: pass +@dataclass +class SparqlQueryResponse: + query_type: str + variables: List[str] + bindings: List[Dict[str, Any]] + ask_result: bool + triples: List[Dict[str, Any]] + +@dataclass +class GraphQLQueryResponse: + data: Any + errors: List[Dict[str, Any]] + @dataclass class GetPromptsResponse: prompts: List[str] - + @dataclass class GetPromptResponse: prompt: Dict[str, Any] @@ -194,31 +316,61 @@ class GetPromptResponse: class GetSystemPromptResponse: prompt: str + class McpServer: - def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""): + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8000, + websocket_url: str = "ws://api-gateway:8088/api/v1/socket", + auth_issuer: str = "", + auth_resource_url: str = "", + ): self.host = host self.port = port self.websocket_url = websocket_url - self.gateway_token = gateway_token - # Create a partial function to pass websocket_url to app_lifespan - lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token) - + lifespan_with_url = partial( + app_lifespan, websocket_url=websocket_url, + ) + + # ── Security: MCP-level auth configuration ── + # The SDK requires AuthSettings whenever a token_verifier is + # present. The issuer_url tells MCP clients where to obtain + # tokens; resource_server_url identifies this server in OAuth + # protected-resource metadata. + # + # The PassthroughTokenVerifier accepts any non-empty Bearer + # token — real validation happens at the gateway. This is + # intentional: the gateway is the single source of truth for + # identity and capability checks. + from mcp.server.auth.settings import AuthSettings + + auth_settings = AuthSettings( + issuer_url=auth_issuer or f"http://{host}:{port}", + resource_server_url=auth_resource_url or f"http://{host}:{port}", + ) + self.mcp = FastMCP( - "TrustGraph", dependencies=["trustgraph-base"], - host=self.host, port=self.port, + "TrustGraph", + dependencies=["trustgraph-base"], + host=self.host, + port=self.port, lifespan=lifespan_with_url, + token_verifier=PassthroughTokenVerifier(), + auth=auth_settings, ) self._register_tools() - + def _register_tools(self): """Register all MCP tools""" - # Register all the tools that were previously registered globally self.mcp.tool()(self.embeddings) self.mcp.tool()(self.text_completion) self.mcp.tool()(self.graph_rag) self.mcp.tool()(self.agent) self.mcp.tool()(self.triples_query) + self.mcp.tool()(self.sparql_query) + self.mcp.tool()(self.graphql_query) self.mcp.tool()(self.graph_embeddings_query) self.mcp.tool()(self.get_config_all) self.mcp.tool()(self.get_config) @@ -243,67 +395,69 @@ class McpServer: self.mcp.tool()(self.load_document) self.mcp.tool()(self.remove_document) self.mcp.tool()(self.add_processing) - + def run(self): """Run the MCP server""" self.mcp.run(transport="streamable-http") + async def _get_manager(self, ctx): + """Get an authenticated WebSocket manager for the current caller. + + Extracts the Bearer token from the MCP auth context and returns + a per-token WebSocket connection to the gateway. + """ + token = _require_token() + return await get_socket_manager(ctx, token) + async def embeddings( self, - text: str, + texts: List[str], flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> EmbeddingsResponse: """ - Generate vector embeddings for the given text using TrustGraph's embedding models. - + Generate vector embeddings for the given texts using TrustGraph's embedding models. + This tool converts text into high-dimensional vectors that capture semantic meaning, enabling similarity searches, clustering, and other vector-based operations. - + Args: - text: The input text to convert into embeddings. Can be a sentence, paragraph, - or document. The text will be processed by the configured embedding model. + texts: List of input texts to convert into embeddings. Each text can be a + sentence, paragraph, or document. flow_id: Optional flow identifier to use for processing (default: "default"). Different flows may use different embedding models or configurations. - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: - EmbeddingsResponse containing a list of vectors. Each vector is a list of floats - representing the text's semantic embedding in the model's vector space. - - Example usage: - - Convert a query into embeddings for similarity search - - Generate embeddings for documents before storing them - - Create embeddings for comparison with existing knowledge + EmbeddingsResponse containing a list of vectors, one per input text. """ - logging.info("Embeddings request made") + logger.info("Embeddings request") if flow_id is None: flow_id = "default" - manager = await get_socket_manager(ctx, "trustgraph") + manager = await self._get_manager(ctx) - if ctx is None: - raise RuntimeError("No context provided") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Computing embeddings via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - await ctx.session.send_log_message( - level="info", - data=f"Computing embeddings via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, + request_data = {"texts": texts} + + gen = manager.request( + "embeddings", request_data, flow_id, workspace=workspace, ) - # Send websocket request - request_data = {"text": text} - logging.info("making request") - - gen = manager.request("embeddings", request_data, flow_id) - async for response in gen: - - # Extract vectors from response vectors = response.get("vectors", [[]]) break - + return EmbeddingsResponse(vectors=vectors) async def text_completion( @@ -311,62 +465,47 @@ class McpServer: prompt: str, system: str | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> TextCompletionResponse: """ Generate text completions using TrustGraph's language models. - - This tool sends prompts to configured language models and returns generated text. - It supports both user prompts and system instructions for controlling generation. - + Args: prompt: The main prompt or question to send to the language model. - This is the primary input that guides the model's response. system: Optional system prompt that sets the context, role, or behavior - for the AI assistant (e.g., "You are a helpful coding assistant"). - System prompts influence how the model interprets and responds. - flow_id: Optional flow identifier (default: "default"). Different flows - may use different models, parameters, or processing pipelines. - + for the AI assistant. + flow_id: Optional flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: TextCompletionResponse containing the generated text response from the model. - - Example usage: - - Ask questions and get AI-generated answers - - Generate code, documentation, or creative content - - Perform text analysis, summarization, or transformation tasks - - Use system prompts to control tone, style, or domain expertise """ if system is None: system = "" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - # Use websocket if context is available - logging.info("Text completion request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Generating text completion via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Generating text completion via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Send websocket request request_data = {"system": system, "prompt": prompt} - gen = manager.request("text-completion", request_data, flow_id) + gen = manager.request( + "text-completion", request_data, flow_id, workspace=workspace, + ) async for response in gen: - - # Extract vectors from response text = response.get("response", "") break - + return TextCompletionResponse(response=text) async def graph_rag( @@ -378,58 +517,43 @@ class McpServer: max_subgraph_size: int | None = None, max_path_length: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> GraphRagResponse: """ Perform Graph-based Retrieval Augmented Generation (GraphRAG) queries. - + GraphRAG combines knowledge graph traversal with language model generation to provide - contextually rich answers. It explores relationships between entities to build relevant - context before generating responses. - + contextually rich answers. + Args: question: The question or query to answer using the knowledge graph. - The system will find relevant entities and relationships to inform the response. collection: Knowledge collection to query (default: "default"). - Different collections may contain domain-specific knowledge. entity_limit: Maximum number of entities to retrieve during graph traversal. - Higher limits provide more context but increase processing time. triple_limit: Maximum number of relationship triples to consider. - Controls the depth of relationship exploration. max_subgraph_size: Maximum size of the subgraph to extract for context. - Larger subgraphs provide richer context but use more resources. max_path_length: Maximum path length to traverse in the knowledge graph. - Longer paths can discover distant but relevant relationships. flow_id: Processing flow to use (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: GraphRagResponse containing the generated answer informed by knowledge graph context. - - Example usage: - - Answer complex questions requiring multi-hop reasoning - - Explore relationships between entities in your knowledge base - - Generate responses grounded in structured knowledge - - Perform research queries across connected information """ if collection is None: collection = "default" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("GraphRAG request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing GraphRAG query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Processing GraphRAG query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with all parameters request_data = { "query": question } @@ -440,20 +564,19 @@ class McpServer: if max_subgraph_size: request_data["max_subgraph_size"] = max_subgraph_size if max_path_length: request_data["max_path_length"] = max_path_length - gen = manager.request("graph-rag", request_data, flow_id) + gen = manager.request( + "graph-rag", request_data, flow_id, workspace=workspace, + ) text_chunks = [] async for response in gen: - # Handle new message format with message_type message_type = response.get("message_type", "chunk") - # Only collect text from chunk messages if message_type == "chunk": chunk_text = response.get("response", "") if chunk_text: text_chunks.append(chunk_text) - # Check if session is complete if response.get("end_of_session"): break @@ -464,404 +587,447 @@ class McpServer: question: str, collection: str | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> AgentResponse: """ Execute intelligent agent queries with reasoning and tool usage capabilities. - - The agent can perform complex multi-step reasoning, use tools, and provide - detailed thought processes. It's designed for tasks requiring planning, - analysis, and iterative problem-solving. - + Args: - question: The question or task for the agent to solve. Can be complex - queries requiring multiple steps, analysis, or tool usage. + question: The question or task for the agent to solve. collection: Knowledge collection the agent can access (default: "default"). - Determines what information and tools are available. - flow_id: Agent workflow to use (default: "default"). Different flows - may have different capabilities, tools, or reasoning strategies. - + flow_id: Agent workflow to use (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: AgentResponse containing the final answer after the agent's reasoning process. - During execution, you'll see intermediate thoughts and observations. - - Example usage: - - Solve complex analytical problems requiring multiple steps - - Perform research tasks across multiple information sources - - Handle queries that need tool usage and decision-making - - Get detailed explanations of reasoning processes - - Note: This tool provides real-time updates on the agent's thinking process - through log messages, so you can follow its reasoning steps. """ if collection is None: collection = "default" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Agent request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing agent query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Processing agent query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with all parameters request_data = { "question": question } if collection: request_data["collection"] = collection - gen = manager.request("agent", request_data, flow_id) + gen = manager.request( + "agent", request_data, flow_id, workspace=workspace, + ) async for response in gen: - logging.debug(f"Agent response: {response}") + logger.debug("Agent response: %s", response) - if "thought" in response: - await ctx.session.send_log_message( - level="info", - data=f"Thinking: {response['thought']}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + if "thought" in response: + await ctx.session.send_log_message( + level="info", + data=f"Thinking: {response['thought']}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - if "observation" in response: - await ctx.session.send_log_message( - level="info", - data=f"Observation: {response['observation']}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if "observation" in response: + await ctx.session.send_log_message( + level="info", + data=f"Observation: {response['observation']}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - # Extract vectors from response if "answer" in response: answer = response.get("answer", "") return AgentResponse(answer=answer) async def triples_query( self, - s_v: str | None = None, - s_e: bool | None = None, - p_v: str | None = None, - p_e: bool | None = None, - o_v: str | None = None, - o_e: bool | None = None, + s: str | None = None, + s_type: str | None = None, + p: str | None = None, + p_type: str | None = None, + o: str | None = None, + o_type: str | None = None, + collection: str | None = None, + graph: str | None = None, limit: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> TriplesQueryResponse: """ Query knowledge graph triples using subject-predicate-object patterns. - - Knowledge graphs store information as triples (subject, predicate, object). - This tool allows flexible querying by specifying any combination of these - components, with wildcards for unspecified parts. - + + Each of s, p, o is an RDF term value. Use the corresponding _type + parameter to specify the term kind: + - "iri" (default for s and p): an IRI / entity reference + - "literal" (default for o): a plain literal value + - "blank": a blank node identifier + Args: - s_v: Subject value to match (e.g., "John", "Apple Inc."). Leave None for wildcard. - s_e: Whether subject should be treated as an entity (True) or literal (False). - p_v: Predicate/relationship value (e.g., "works_for", "type_of"). Leave None for wildcard. - p_e: Whether predicate should be treated as an entity (True) or literal (False). - o_v: Object value to match (e.g., "Engineer", "Company"). Leave None for wildcard. - o_e: Whether object should be treated as an entity (True) or literal (False). + s: Subject value to match. Leave None for wildcard. + s_type: Subject term type: "iri" (default), "literal", or "blank". + p: Predicate value to match. Leave None for wildcard. + p_type: Predicate term type: "iri" (default), "literal", or "blank". + o: Object value to match. Leave None for wildcard. + o_type: Object term type: "iri", "literal" (default), or "blank". + collection: Knowledge collection to query (default: "default"). + graph: Named graph IRI to restrict the query. None = default graph, + "*" = all graphs. limit: Maximum number of triples to return (default: 20). flow_id: Processing flow identifier (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: TriplesQueryResponse containing matching triples from the knowledge graph. - - Example queries: - - Find all relationships for an entity: s_v="John", others None - - Find all instances of a relationship: p_v="works_for", others None - - Find specific facts: s_v="John", p_v="works_for", o_v=None - - Explore entity types: p_v="type_of", others None - - Use this for: - - Exploring knowledge graph structure - - Finding specific facts or relationships - - Discovering connections between entities - - Validating or debugging knowledge content """ if flow_id is None: flow_id = "default" if limit is None: limit = 20 + if collection is None: collection = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Triples query request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing triples query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Processing triples query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with Value objects request_data = { - "limit": limit + "limit": limit, + "collection": collection, } - # Add subject if provided - if s_v is not None: - request_data["s"] = {"v": s_v, "e": s_e } + if s is not None: + request_data["s"] = _make_term(s, s_type or "iri") - # Add predicate if provided - if p_v is not None: - request_data["p"] = {"v": p_v, "e": p_e } + if p is not None: + request_data["p"] = _make_term(p, p_type or "iri") - # Add object if provided - if o_v is not None: - request_data["o"] = {"v": o_v, "e": o_e } + if o is not None: + request_data["o"] = _make_term(o, o_type or "literal") - gen = manager.request("triples", request_data, flow_id) + if graph is not None: + request_data["g"] = graph + + gen = manager.request( + "triples", request_data, flow_id, workspace=workspace, + ) async for response in gen: - # Extract response data triples = response.get("response", []) break - + return TriplesQueryResponse(triples=triples) + async def sparql_query( + self, + query: str, + collection: str | None = None, + limit: int | None = None, + flow_id: str | None = None, + workspace: str | None = None, + ctx: Context = None, + ) -> SparqlQueryResponse: + """ + Execute a SPARQL query against the knowledge graph. + + Supports SELECT, ASK, CONSTRUCT, and DESCRIBE query forms. + + Args: + query: SPARQL query string (e.g. "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"). + collection: Knowledge collection to query (default: "default"). + limit: Safety limit on number of results (default: 10000). + flow_id: Processing flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + + Returns: + SparqlQueryResponse containing the query results. The structure depends + on query type: + - SELECT: variables (column names) and bindings (rows of Term values) + - ASK: ask_result (boolean) + - CONSTRUCT/DESCRIBE: triples + """ + + if collection is None: collection = "default" + if flow_id is None: flow_id = "default" + if limit is None: limit = 10000 + + manager = await self._get_manager(ctx) + + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing SPARQL query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + + request_data = { + "query": query, + "collection": collection, + "limit": limit, + } + + gen = manager.request( + "sparql", request_data, flow_id, workspace=workspace, + ) + + async for response in gen: + query_type = response.get("query-type", "") + return SparqlQueryResponse( + query_type=query_type, + variables=response.get("variables", []), + bindings=response.get("bindings", []), + ask_result=response.get("ask-result", False), + triples=response.get("triples", []), + ) + + async def graphql_query( + self, + query: str, + collection: str | None = None, + variables: Dict[str, Any] | None = None, + operation_name: str | None = None, + flow_id: str | None = None, + workspace: str | None = None, + ctx: Context = None, + ) -> GraphQLQueryResponse: + """ + Execute a GraphQL query against structured data (rows). + + Queries structured data schemas that have been loaded into TrustGraph. + The available types and fields depend on the schemas configured in the + target workspace. + + Args: + query: GraphQL query string (e.g. '{ customers(where: {status: {eq: "active"}}) { id name } }'). + collection: Data collection to query (default: "default"). + variables: Optional GraphQL variables as a dict. + operation_name: Optional operation name for multi-operation documents. + flow_id: Processing flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + + Returns: + GraphQLQueryResponse containing data (the query result) and errors + (any GraphQL field-level errors). + """ + + if collection is None: collection = "default" + if flow_id is None: flow_id = "default" + + manager = await self._get_manager(ctx) + + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing GraphQL query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + + request_data = { + "query": query, + "collection": collection, + "variables": variables or {}, + } + + if operation_name is not None: + request_data["operation_name"] = operation_name + + gen = manager.request( + "rows", request_data, flow_id, workspace=workspace, + ) + + async for response in gen: + return GraphQLQueryResponse( + data=response.get("data"), + errors=response.get("errors", []), + ) + async def graph_embeddings_query( self, vectors: List[List[float]], limit: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> GraphEmbeddingsQueryResponse: """ Find entities in the knowledge graph using vector similarity search. - - This tool performs semantic search by comparing embedding vectors to find - the most similar entities in the knowledge graph. It's useful for finding - conceptually related information even when exact text matches don't exist. - + Args: - vectors: List of embedding vectors to search with. Each vector should be - a list of floats representing semantic embeddings (typically from - the embeddings tool). Multiple vectors can be provided for batch queries. + vectors: List of embedding vectors to search with. limit: Maximum number of similar entities to return (default: 20). - Higher limits provide more results but may include less relevant matches. flow_id: Processing flow identifier (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: - GraphEmbeddingsQueryResponse containing entities ranked by similarity to the - input vectors, along with similarity scores and entity metadata. - - Example workflow: - 1. Use the 'embeddings' tool to convert text to vectors - 2. Use this tool to find similar entities in the knowledge graph - 3. Explore the returned entities for relevant information - - Use this for: - - Semantic search across knowledge entities - - Finding conceptually similar content - - Discovering related entities without exact keyword matches - - Building recommendation systems based on entity similarity + GraphEmbeddingsQueryResponse containing entities ranked by similarity. """ if flow_id is None: flow_id = "default" if limit is None: limit = 20 - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Graph embeddings query request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing graph embeddings query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Processing graph embeddings query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data request_data = { "vectors": vectors, "limit": limit } - gen = manager.request("graph-embeddings", request_data, flow_id) + gen = manager.request( + "graph-embeddings", request_data, flow_id, workspace=workspace, + ) async for response in gen: - # Extract entities from response entities = response.get("entities", []) break - + return GraphEmbeddingsQueryResponse(entities=entities) async def get_config_all( self, + workspace: str | None = None, ctx: Context = None, ) -> ConfigResponse: """ Retrieve the complete TrustGraph system configuration. - - This tool returns all configuration settings for the TrustGraph system, - including model configurations, API keys, flow definitions, and system parameters. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - ConfigResponse containing the full configuration as a nested dictionary - with all system settings, organized by category (e.g., models, flows, storage). - - Use this for: - - Inspecting current system configuration - - Debugging configuration issues - - Understanding available models and settings - - Auditing system setup and parameters + ConfigResponse containing the full configuration as a nested dictionary. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get config all request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving all configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving all configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) break - + return ConfigResponse(config=config) async def get_config( self, keys: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> ConfigGetResponse: """ Retrieve specific configuration values by key. - - This tool allows you to fetch specific configuration settings without - retrieving the entire configuration. Useful for checking particular - settings or API keys. - + Args: - keys: List of configuration keys to retrieve. Each key should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings', 'storage') - - 'key': Specific setting name within that category - + keys: List of configuration keys to retrieve. Each key should be a dict with + 'type' and 'key' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: ConfigGetResponse containing the requested configuration values. - - Example keys: - - {'type': 'llm', 'key': 'openai.model'} - - {'type': 'embeddings', 'key': 'default.model'} - - {'type': 'storage', 'key': 'database.url'} - - Use this for: - - Checking specific model configurations - - Validating API key settings - - Inspecting individual system parameters """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving specific configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving specific configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get", "keys": keys } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: values = response.get("values", []) break - + return ConfigGetResponse(values=values) async def put_config( self, values: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> PutConfigResponse: """ Update system configuration values. - - This tool allows you to modify TrustGraph system settings, such as - model parameters, API keys, and system behavior configurations. - + Args: - values: List of configuration updates. Each update should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings') - - 'key': Specific setting name to update - - 'value': New value for the setting - + values: List of configuration updates. Each should be a dict with + 'type', 'key', and 'value' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: PutConfigResponse confirming the configuration update. - - Example updates: - - {'type': 'llm', 'key': 'openai.model', 'value': 'gpt-4'} - - {'type': 'embeddings', 'key': 'batch_size', 'value': '100'} - - Use this for: - - Switching between different models - - Updating API credentials - - Modifying system behavior parameters - - Configuring processing settings - - Note: Configuration changes may require system restart to take effect. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Put config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Updating configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Updating configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "put", "values": values } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: return PutConfigResponse() @@ -869,97 +1035,73 @@ class McpServer: async def delete_config( self, keys: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> DeleteConfigResponse: """ Delete specific configuration entries from the system. - - This tool removes configuration settings, reverting them to system defaults - or disabling specific features. - + Args: - keys: List of configuration keys to delete. Each key should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings') - - 'key': Specific setting name to remove - + keys: List of configuration keys to delete. Each should be a dict with + 'type' and 'key' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: DeleteConfigResponse confirming the deletion. - - Use this for: - - Removing custom model configurations - - Clearing API credentials - - Resetting settings to defaults - - Cleaning up obsolete configurations - - Warning: Deleting essential configuration may cause system functionality - to be disabled until properly reconfigured. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Delete config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Deleting configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Deleting configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "delete", "keys": keys } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: return DeleteConfigResponse() async def get_prompts( self, + workspace: str | None = None, ctx: Context = None, ) -> GetPromptsResponse: """ List all available prompt templates in the system. - - Prompt templates are reusable prompts that can be used with language models - for consistent behavior across different queries and use cases. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: GetPromptsResponse containing a list of available prompt template IDs. - Each ID can be used with get_prompt to retrieve the full template. - - Use this for: - - Discovering available prompt templates - - Exploring pre-configured prompts for different tasks - - Finding templates for specific use cases - - Understanding what prompt options are available """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get prompts request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving prompt templates via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving prompt templates via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -971,49 +1113,36 @@ class McpServer: async def get_prompt( self, prompt_id: str, + workspace: str | None = None, ctx: Context = None, ) -> GetPromptResponse: """ Retrieve a specific prompt template by ID. - - Prompt templates contain structured prompts with placeholders, instructions, - and metadata for specific tasks or domains. - + Args: prompt_id: The unique identifier of the prompt template to retrieve. - Use get_prompts to see available template IDs. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - GetPromptResponse containing the complete prompt template with its - structure, placeholders, and usage instructions. - - Use this for: - - Examining prompt template structure - - Understanding how to use specific templates - - Copying or modifying existing prompts - - Learning prompt engineering patterns + GetPromptResponse containing the complete prompt template. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get prompt request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving prompt template '{prompt_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving prompt template '{prompt_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -1025,44 +1154,35 @@ class McpServer: async def get_system_prompt( self, + workspace: str | None = None, ctx: Context = None, ) -> GetSystemPromptResponse: """ Retrieve the current system prompt configuration. - - The system prompt defines the default behavior, personality, and instructions - for language models across the TrustGraph system. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - GetSystemPromptResponse containing the system prompt text and configuration. - - Use this for: - - Understanding default AI behavior settings - - Checking current system-wide prompt configuration - - Auditing AI personality and instruction settings - - Debugging unexpected AI responses + GetSystemPromptResponse containing the system prompt text. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get system prompt request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving system prompt via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving system prompt via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -1073,51 +1193,39 @@ class McpServer: async def get_token_costs( self, + workspace: str | None = None, ctx: Context = None, ) -> ConfigTokenCostsResponse: """ Retrieve token pricing information for all configured AI models. - - This tool provides cost information for input and output tokens across - different language models, helping with budget planning and cost optimization. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - ConfigTokenCostsResponse containing pricing data for each model including: - - Model name/identifier - - Input token cost (per token) - - Output token cost (per token) - - Use this for: - - Estimating costs for different models - - Choosing cost-effective models for tasks - - Budget planning and cost analysis - - Monitoring and optimizing AI spending + ConfigTokenCostsResponse containing pricing data for each model. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get token costs request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving token costs via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving token costs via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "getvalues", "type": "token-costs" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: values = response.get("values", []) - # Transform to match TypeScript API format costs = [] for item in values: try: @@ -1130,106 +1238,89 @@ class McpServer: except (json.JSONDecodeError, AttributeError): continue break - + return ConfigTokenCostsResponse(costs=costs) async def get_knowledge_cores( self, + workspace: str | None = None, ctx: Context = None, ) -> KnowledgeCoresResponse: """ List all available knowledge graph cores in the current workspace. - Knowledge cores are packaged collections of structured knowledge that can - be loaded into the system for querying and reasoning. They contain entities, - relationships, and facts organized as knowledge graphs. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: KnowledgeCoresResponse containing a list of available knowledge core IDs. - - Use this for: - - Discovering available knowledge collections - - Understanding what knowledge domains are accessible - - Planning which cores to load for specific tasks - - Managing knowledge resources """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get knowledge cores request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving knowledge graph cores via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving knowledge graph cores via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-kg-cores", } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: ids = response.get("ids", []) break - + return KnowledgeCoresResponse(ids=ids) async def delete_kg_core( self, core_id: str, + workspace: str | None = None, ctx: Context = None, ) -> DeleteKgCoreResponse: """ Permanently delete a knowledge graph core. - This operation removes a knowledge core from storage. Use with caution - as this action cannot be undone. - Args: core_id: Unique identifier of the knowledge core to delete. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: DeleteKgCoreResponse confirming the deletion. - - Use this for: - - Cleaning up obsolete knowledge cores - - Removing test or experimental data - - Managing storage space - - Maintaining organized knowledge collections - - Warning: This permanently deletes the knowledge core and all its data. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Delete KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Deleting knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Deleting knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "delete-kg-core", "id": core_id, } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: break - + return DeleteKgCoreResponse() async def load_kg_core( @@ -1237,46 +1328,34 @@ class McpServer: core_id: str, flow: str, collection: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> LoadKgCoreResponse: """ Load a knowledge graph core into the active system for querying. - This operation makes a knowledge core available for GraphRAG queries, - triple searches, and other knowledge-based operations. - Args: core_id: Unique identifier of the knowledge core to load. - flow: Processing flow to use for loading the core. Different flows - may apply different processing, indexing, or optimization steps. - collection: Target collection name (default: "default"). The loaded - knowledge will be available under this collection name. + flow: Processing flow to use for loading the core. + collection: Target collection name (default: "default"). + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: LoadKgCoreResponse confirming the core has been loaded. - - Use this for: - - Making knowledge cores available for queries - - Switching between different knowledge domains - - Loading domain-specific knowledge for tasks - - Preparing knowledge for GraphRAG operations """ if collection is None: collection = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Load KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Loading knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Loading knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "load-kg-core", @@ -1285,292 +1364,241 @@ class McpServer: "collection": collection } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: break - + return LoadKgCoreResponse() async def get_kg_core( self, core_id: str, + workspace: str | None = None, ctx: Context = None, ) -> GetKgCoreResponse: """ Download and retrieve the complete content of a knowledge graph core. - This tool streams the entire content of a knowledge core, returning all - entities, relationships, and metadata. Due to potentially large data sizes, - the content is streamed in chunks. - Args: core_id: Unique identifier of the knowledge core to retrieve. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: GetKgCoreResponse containing all chunks of the knowledge core data. - Each chunk contains part of the knowledge graph structure. - - Use this for: - - Examining knowledge core content and structure - - Debugging knowledge graph data - - Exporting knowledge for backup or analysis - - Understanding the scope and quality of knowledge - - Note: Large knowledge cores may take significant time to download. - Progress updates are provided through log messages during streaming. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-kg-core", "id": core_id, } - # Collect all streaming responses chunks = [] - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: - # Check for end of stream if response.get("eos", False): - await ctx.session.send_log_message( - level="info", - data=f"Completed streaming KG core data", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Completed streaming KG core data", + logger="notification_stream", + related_request_id=ctx.request_id, + ) break else: chunks.append(response) - await ctx.session.send_log_message( - level="info", - data=f"Received KG core chunk ({len(chunks)} chunks so far)", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Received KG core chunk ({len(chunks)} chunks so far)", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + return GetKgCoreResponse(chunks=chunks) async def get_flows( self, + workspace: str | None = None, ctx: Context = None, ) -> FlowsResponse: """ List all available processing flows in the system. - - Flows define processing pipelines for different types of operations - (e.g., document processing, knowledge extraction, query handling). - Each flow encapsulates a specific workflow with configured steps. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: FlowsResponse containing a list of available flow identifiers. - - Use this for: - - Discovering available processing workflows - - Understanding what processing options are available - - Choosing appropriate flows for specific tasks - - Planning workflow-based operations """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flows request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving available flows via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving available flows via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-flows" } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: flow_ids = response.get("flow-ids", []) break - + return FlowsResponse(flow_ids=flow_ids) async def get_flow( self, flow_id: str, + workspace: str | None = None, ctx: Context = None, ) -> FlowResponse: """ Retrieve the complete definition of a specific processing flow. - - This tool returns the detailed configuration, steps, and parameters - of a processing flow, showing how it processes data and what operations it performs. - + Args: flow_id: Unique identifier of the flow to retrieve. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - FlowResponse containing the complete flow definition including: - - Flow configuration and parameters - - Processing steps and their order - - Input/output specifications - - Dependencies and requirements - - Use this for: - - Understanding how specific flows work - - Debugging flow processing issues - - Learning flow configuration patterns - - Customizing or duplicating flows + FlowResponse containing the complete flow definition. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow definition for '{flow_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving flow definition for '{flow_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-flow", "flow-id": flow_id, } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: flow_data = response.get("flow", "{}") - # Parse JSON flow definition as done in TypeScript flow = json.loads(flow_data) if isinstance(flow_data, str) else flow_data break - + return FlowResponse(flow=flow) async def get_flow_classes( self, + workspace: str | None = None, ctx: Context = None, ) -> FlowClassesResponse: """ List all available flow class templates. - - Flow classes are templates that define types of processing workflows. - They serve as blueprints for creating specific flow instances with - customized parameters. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: FlowClassesResponse containing a list of available flow class names. - - Use this for: - - Discovering available flow templates - - Understanding what types of processing are supported - - Planning new flow creation - - Exploring system capabilities """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow classes request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow classes via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving flow classes via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-classes" } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: class_names = response.get("class-names", []) break - + return FlowClassesResponse(class_names=class_names) async def get_flow_class( self, class_name: str, + workspace: str | None = None, ctx: Context = None, ) -> FlowClassResponse: """ Retrieve the definition of a specific flow class template. - - Flow classes define the structure, parameters, and capabilities of - flow types. This tool returns the class specification including - configurable parameters and processing logic. - + Args: class_name: Name of the flow class to retrieve. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - FlowClassResponse containing the flow class definition with: - - Class parameters and configuration options - - Processing capabilities and requirements - - Usage instructions and examples - - Use this for: - - Understanding flow class capabilities - - Learning how to configure new flows - - Troubleshooting flow creation issues - - Exploring advanced flow features + FlowClassResponse containing the flow class definition. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow class request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow class definition for '{class_name}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving flow class definition for '{class_name}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-class", "class-name": class_name } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: class_def_data = response.get("class-definition", "{}") - # Parse JSON class definition as done in TypeScript class_definition = json.loads(class_def_data) if isinstance(class_def_data, str) else class_def_data break - + return FlowClassResponse(class_definition=class_definition) async def start_flow( @@ -1578,43 +1606,32 @@ class McpServer: flow_id: str, class_name: str, description: str, + workspace: str | None = None, ctx: Context = None, ) -> StartFlowResponse: """ Create and start a new processing flow instance. - - This tool creates a new flow based on a flow class template and starts - it running. The flow will begin processing according to its configuration. - + Args: flow_id: Unique identifier for the new flow instance. class_name: Flow class template to use for creating the flow. - Use get_flow_classes to see available classes. description: Human-readable description of the flow's purpose. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: StartFlowResponse confirming the flow has been started. - - Use this for: - - Creating new processing workflows - - Starting automated processing tasks - - Launching background operations - - Initiating data processing pipelines """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Start flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "start-flow", @@ -1623,162 +1640,135 @@ class McpServer: "description": description } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: break - + return StartFlowResponse() async def stop_flow( self, flow_id: str, + workspace: str | None = None, ctx: Context = None, ) -> StopFlowResponse: """ Stop a running flow instance. - - This tool gracefully stops a running flow, allowing it to complete - current operations before shutting down. - + Args: flow_id: Unique identifier of the flow instance to stop. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: StopFlowResponse confirming the flow has been stopped. - - Use this for: - - Stopping unwanted or completed flows - - Managing system resources - - Interrupting long-running processes - - Maintaining flow lifecycle """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Stop flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Stopping flow '{flow_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Stopping flow '{flow_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "stop-flow", "flow-id": flow_id } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: break - + return StopFlowResponse() async def get_documents( self, + workspace: str | None = None, ctx: Context = None, ) -> DocumentsResponse: """ List all documents stored in the TrustGraph document library. - This tool returns metadata for all documents that have been uploaded - to the system, including their processing status and properties. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: - DocumentsResponse containing metadata for each document including: - - Document ID and title - - Upload timestamp - - MIME type and size information - - Tags and custom metadata - - Processing status - - Use this for: - - Browsing available documents - - Managing document collections - - Finding documents by metadata - - Auditing document storage + DocumentsResponse containing metadata for each document. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get documents request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving documents list via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving documents list via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-documents", } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: document_metadatas = response.get("document-metadatas", []) break - + return DocumentsResponse(document_metadatas=document_metadatas) async def get_processing( self, + workspace: str | None = None, ctx: Context = None, ) -> ProcessingResponse: """ List all documents currently in the processing queue. - This tool shows documents that are being processed or waiting to be - processed, along with their processing status and configuration. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: - ProcessingResponse containing processing metadata including: - - Processing job ID and document ID - - Processing flow and status - - Target collection - - Timestamp and progress information - - Use this for: - - Monitoring document processing progress - - Debugging processing issues - - Managing processing queues - - Understanding system workload + ProcessingResponse containing processing metadata. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get processing request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving processing list via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving processing list via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-processing", } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: processing_metadatas = response.get("processing-metadatas", []) break - + return ProcessingResponse(processing_metadatas=processing_metadatas) async def load_document( @@ -1790,50 +1780,39 @@ class McpServer: title: str = "", comments: str = "", tags: List[str] | None = None, + workspace: str | None = None, ctx: Context = None, ) -> LoadDocumentResponse: """ Upload a document to the TrustGraph document library. - This tool stores documents with rich metadata for later processing, - search, and knowledge extraction. Documents can be text files, PDFs, - or other supported formats. - Args: document: The document content as a string. For binary files, this should be base64-encoded content. document_id: Optional unique identifier. If not provided, one will be generated. metadata: Optional list of custom metadata key-value pairs. - mime_type: MIME type of the document (e.g., 'text/plain', 'application/pdf'). + mime_type: MIME type of the document. title: Human-readable title for the document. comments: Optional description or notes about the document. - tags: List of tags for categorizing and finding the document. + tags: List of tags for categorizing the document. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: LoadDocumentResponse confirming the document has been stored. - - Use this for: - - Adding new documents to the knowledge base - - Storing reference materials and data sources - - Building document collections for processing - - Importing external content for analysis """ if tags is None: tags = [] - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Load document request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Loading document to library via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Loading document to library via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) import time timestamp = int(time.time()) @@ -1852,63 +1831,55 @@ class McpServer: "content": document } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return LoadDocumentResponse() async def remove_document( self, document_id: str, + workspace: str | None = None, ctx: Context = None, ) -> RemoveDocumentResponse: """ Permanently remove a document from the library. - This operation deletes a document and all its associated metadata. - Use with caution as this action cannot be undone. - Args: document_id: Unique identifier of the document to remove. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: RemoveDocumentResponse confirming the document has been deleted. - - Use this for: - - Cleaning up obsolete or incorrect documents - - Managing storage space - - Removing sensitive or inappropriate content - - Maintaining organized document collections - - Warning: This permanently deletes the document and all its metadata. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Remove document request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Removing document '{document_id}' from library via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Removing document '{document_id}' from library via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "remove-document", "document-id": document_id, } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return RemoveDocumentResponse() async def add_processing( @@ -1918,53 +1889,37 @@ class McpServer: flow: str, collection: str | None = None, tags: List[str] | None = None, + workspace: str | None = None, ctx: Context = None, ) -> AddProcessingResponse: """ Queue a document for processing through a specific workflow. - This tool adds a document to the processing queue where it will be - processed by the specified flow to extract knowledge, create embeddings, - or perform other analysis operations. - Args: processing_id: Unique identifier for this processing job. document_id: ID of the document to process (must exist in library). - flow: Processing flow to use. Different flows perform different - types of analysis (e.g., knowledge extraction, summarization). + flow: Processing flow to use. collection: Target collection for processed knowledge (default: "default"). - Results will be stored under this collection name. tags: Optional tags for categorizing this processing job. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: AddProcessingResponse confirming the document has been queued. - - Use this for: - - Processing uploaded documents into knowledge - - Extracting entities and relationships from text - - Creating searchable embeddings - - Converting documents into structured knowledge - - Note: Processing may take time depending on document size and flow complexity. - Use get_processing to monitor progress. """ if collection is None: collection = "default" if tags is None: tags = [] - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Add processing request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Adding document '{document_id}' to processing queue via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Adding document '{document_id}' to processing queue via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) import time timestamp = int(time.time()) @@ -1981,38 +1936,61 @@ class McpServer: } } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return AddProcessingResponse() + def main(): parser = argparse.ArgumentParser(description='TrustGraph MCP Server') - parser.add_argument('--host', default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)') - parser.add_argument('--port', type=int, default=8000, help='Port to bind to (default: 8000)') - parser.add_argument('--websocket-url', default='ws://api-gateway:8088/api/v1/socket', help='WebSocket URL to connect to (default: ws://api-gateway:8088/api/v1/socket)') + parser.add_argument( + '--host', default='0.0.0.0', + help='Host to bind to (default: 0.0.0.0)', + ) + parser.add_argument( + '--port', type=int, default=8000, + help='Port to bind to (default: 8000)', + ) + parser.add_argument( + '--websocket-url', + default='ws://api-gateway:8088/api/v1/socket', + help='WebSocket URL for the TrustGraph gateway', + ) + parser.add_argument( + '--auth-issuer', + default=os.environ.get("AUTH_ISSUER", ""), + help='OAuth issuer URL for MCP auth metadata discovery', + ) + parser.add_argument( + '--auth-resource-url', + default=os.environ.get("AUTH_RESOURCE_URL", ""), + help='Resource server URL for OAuth protected resource metadata', + ) - # Add logging arguments add_logging_args(parser) args = parser.parse_args() - # Setup logging before creating server setup_logging(vars(args)) - # Read gateway auth token from environment - gateway_token = os.environ.get("GATEWAY_SECRET", "") - - # Create and run the MCP server - server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token) + server = McpServer( + host=args.host, + port=args.port, + websocket_url=args.websocket_url, + auth_issuer=args.auth_issuer, + auth_resource_url=args.auth_resource_url, + ) server.run() + def run(): - """Legacy function for backward compatibility""" main() + if __name__ == "__main__": main() - diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py index bff8ae75..9fbf7459 100644 --- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py +++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py @@ -1,49 +1,110 @@ -from dataclasses import dataclass from websockets.asyncio.client import connect -from urllib.parse import urlencode, urlparse, urlunparse, parse_qs import asyncio import logging import json import uuid -import time +import hashlib + +logger = logging.getLogger(__name__) + + +def _token_key(token): + """Derive a dict key from a token without storing the raw secret.""" + return hashlib.sha256(token.encode()).hexdigest()[:16] + class WebSocketManager: + """Manages an authenticated WebSocket connection to the TrustGraph + gateway on behalf of a single caller. - def __init__(self, url, token=None): + Each caller token gets its own WebSocketManager so that gateway-side + identity, workspace, and capability scoping are preserved end-to-end. + """ + + def __init__(self, url, token): self.url = url + # ── Security boundary: token storage ── + # This is the MCP caller's Bearer token, forwarded verbatim to + # the gateway. It MUST NOT be logged, persisted, or shared + # across callers. It is held only for the lifetime of this + # connection so that re-auth (e.g. after a reconnect) is + # possible. self.token = token self.socket = None - - # FIXME: authentication is broken. The /api/v1/socket endpoint uses - # in-band auth (first-frame protocol via the Mux dispatcher), not - # query-parameter tokens. This query-string token is silently ignored. - # Fix: after connect(), send an auth frame with the bearer token as - # the first message, matching the gateway's in-band auth protocol. - def _build_url(self): - if not self.token: - return self.url - parsed = urlparse(self.url) - params = parse_qs(parsed.query) - params["token"] = [self.token] - new_query = urlencode(params, doseq=True) - return urlunparse(parsed._replace(query=new_query)) + self.identity = None + self.last_used = None async def start(self): - self.socket = await connect(self._build_url()) + """Connect and authenticate via the gateway's in-band auth + protocol. Raises on auth failure.""" + + # ── Security boundary: MCP server → gateway ── + # The WebSocket connects to the gateway and authenticates using + # the caller's Bearer token via the in-band first-frame auth + # protocol. The token belongs to the MCP client — we forward + # it as-is and never interpret its contents. + self.socket = await connect(self.url) self.pending_requests = {} self.running = True + + await self._authenticate() + self.reader_task = asyncio.create_task(self.reader()) + async def _authenticate(self): + """Send in-band auth frame and wait for auth-ok / auth-failed. + + The gateway expects ``{"type": "auth", "token": "..."}`` as the + first frame on a new WebSocket. Any service frame sent before + auth-ok is rejected. + """ + await self.socket.send(json.dumps({ + "type": "auth", + "token": self.token, + })) + + response_text = await asyncio.wait_for(self.socket.recv(), 10) + response = json.loads(response_text) + + if response.get("type") == "auth-ok": + logger.info( + "WebSocket authenticated, default workspace: %s", + response.get("workspace"), + ) + return + + # Auth failed — close immediately, do not leave an + # unauthenticated socket open. + await self.socket.close() + self.socket = None + + if response.get("type") == "auth-failed": + raise RuntimeError( + "Gateway rejected the authentication token" + ) + + raise RuntimeError( + f"Unexpected auth response type: {response.get('type')}" + ) + + async def whoami(self): + """Verify the token by calling the gateway's whoami endpoint. + Returns the identity dict and caches it on ``self.identity``. + """ + gen = self.request("iam", {"operation": "whoami"}, flow_id=None) + async for response in gen: + self.identity = response + return response + async def stop(self): self.running = False - await self.reader_task + if hasattr(self, "reader_task"): + await self.reader_task async def reader(self): - """ - Background task to read websocket responses and route to correct - request - """ + """Background task: read WebSocket frames and route them to the + correct pending-request queue by ``id``.""" while self.running: try: @@ -59,23 +120,21 @@ class WebSocketManager: request_id = response.get("id") if request_id and request_id in self.pending_requests: - # Put the response in the queue queue = self.pending_requests[request_id] await queue.put(response) else: - logging.warning( - f"Response for unknown request ID: {request_id}" + logger.warning( + "Response for unknown request ID: %s", request_id ) except Exception as e: - logging.error(f"Error in websocket reader: {e}") + logger.error("Error in websocket reader: %s", e) - # Put error in all pending queues for queue in self.pending_requests.values(): try: await queue.put({"error": str(e)}) - except: + except Exception: pass self.pending_requests.clear() @@ -86,25 +145,29 @@ class WebSocketManager: async def request( self, service, request_data, flow_id="default", + workspace=None, ): - """ - Send a request via websocket and handle single or streaming responses + """Send a request via WebSocket and yield responses. + + Args: + service: Gateway service name (e.g. "graph-rag", "config"). + request_data: Inner request payload. + flow_id: Optional flow identifier. ``None`` omits the field + (workspace-level services don't use flows). + workspace: Optional workspace override. When ``None`` the + gateway uses the caller's default workspace. """ - # Generate unique request ID + import time + self.last_used = time.monotonic() + request_id = f"{uuid.uuid4()}" - # Determine if this service streams responses - streaming_services = {"agent"} - is_streaming = service in streaming_services - - # Create a queue for all responses (streaming and single) response_queue = asyncio.Queue() self.pending_requests[request_id] = response_queue try: - # Build request message message = { "id": request_id, "service": service, @@ -114,7 +177,16 @@ class WebSocketManager: if flow_id is not None: message["flow"] = flow_id - # Send request + # ── Security boundary: workspace scoping ── + # When the caller supplies a workspace, we set it on the + # message envelope. The gateway's enforce_workspace() + # validates that the authenticated identity is permitted + # to access the target workspace — we MUST NOT skip or + # override that check. When workspace is None, the + # gateway default-fills from the identity's bound workspace. + if workspace is not None: + message["workspace"] = workspace + await self.socket.send(json.dumps(message)) while self.running: @@ -127,19 +199,17 @@ class WebSocketManager: continue if "error" in response: - if "message" in response["error"]: - raise RuntimeError(response["error"]["text"]) + if isinstance(response["error"], dict): + raise RuntimeError( + response["error"].get("message", str(response["error"])) + ) else: raise RuntimeError(str(response["error"])) yield response["response"] - if "complete" in response: - if response["complete"]: - break + if response.get("complete"): + break - except Exception as e: - # Clean up on error + finally: self.pending_requests.pop(request_id, None) - raise e -