From 1f67fc23128edfabf5468a643ae55fa38bca9df3 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 25 Mar 2026 17:53:20 +0000 Subject: [PATCH 01/37] master -> release/v2.2 (#713) Merge doc updates from master into release branch --- README.md | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 61a03cc7..52df2720 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ trustgraph-ai%2Ftrustgraph | Trendshift -# The context backend for reliable AI +# The context development platform -LLMs alone hallucinate and diverge from ground truth. [TrustGraph](https://trustgraph.ai) is a context system that stores, enriches, and delivers context to LLMs to enable reliable AI agents. Think like [Supabase](https://github.com/supabase/supabase) but AI-native and powered by context graphs. +Building applications that need to know things requires more than a database. [TrustGraph](https://trustgraph.ai) is the context development platform: graph-native infrastructure for storing, enriching, and retrieving structured knowledge at any scale. Think like [Supabase](https://github.com/supabase/supabase) but built around context graphs: multi-model storage, semantic retrieval pipelines, portable [context cores](#context-cores), and a full developer toolkit out of the box. Deploy locally or in the cloud. No unnecessary API keys. Just context, engineered. -The context backend: +The platform: - [x] Multi-model and multimodal database system - [x] Tabular/relational, key-value - [x] Document, graph, and vectors @@ -82,8 +82,8 @@ For a browser based quickstart, try the [Configuration Terminal](https://config- Table of Contents
-- [**What is a Context Graph?**](#what-is-a-context-graph)
-- [**Why TrustGraph?**](#why-trustgraph)
+- [**What is a Context Graph?**](#watch-what-is-a-context-graph)
+- [**Context Graphs in Action**](#watch-context-graphs-in-action)
- [**Getting Started**](#getting-started-with-trustgraph)
- [**Context Cores**](#context-cores)
- [**Tech Stack**](#tech-stack)
@@ -94,13 +94,13 @@ For a browser based quickstart, try the [Configuration Terminal](https://config- -## What is a Context Graph? +## Watch What is a Context Graph? [![What is a Context Graph?](https://img.youtube.com/vi/gZjlt5WcWB4/maxresdefault.jpg)](https://www.youtube.com/watch?v=gZjlt5WcWB4) -## Why TrustGraph? +## Watch Context Graphs in Action -[![Why TrustGraph?](https://img.youtube.com/vi/Norboj8YP2M/maxresdefault.jpg)](https://www.youtube.com/watch?v=Norboj8YP2M) +[![Context Graphs in Action with TrustGraph](https://img.youtube.com/vi/sWc7mkhITIo/maxresdefault.jpg)](https://www.youtube.com/watch?v=sWc7mkhITIo) ## Getting Started with TrustGraph @@ -109,10 +109,6 @@ For a browser based quickstart, try the [Configuration Terminal](https://config- - [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference) - [**Deployment Guides**](https://docs.trustgraph.ai/deployment) -### Watch TrustGraph 101 - -[![TrustGraph 101](https://img.youtube.com/vi/rWYl_yhKCng/maxresdefault.jpg)](https://www.youtube.com/watch?v=rWYl_yhKCng) - ## Workbench The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default. From 97f5645ea0724425c0d5cf8babfa732e6224a920 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 26 Mar 2026 09:08:09 +0000 Subject: [PATCH 02/37] CLA (#716) Explanatory text for the CLA process --- docs/contributor-licence-agreement.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 docs/contributor-licence-agreement.md diff --git a/docs/contributor-licence-agreement.md b/docs/contributor-licence-agreement.md new file mode 100644 index 00000000..cc9f5929 --- /dev/null +++ b/docs/contributor-licence-agreement.md @@ -0,0 +1,17 @@ + +# Contributor Licence Agreement (CLA) + +We ask every contributor to sign a lightweight Contributor Licence +Agreement before we can merge a pull request. The CLA does **not** +transfer copyright — you keep full ownership of your work. It simply +grants the TrustGraph project a perpetual, royalty-free licence to +distribute your contribution under the project's +[Apache 2.0 licence](https://www.apache.org/licenses/LICENSE-2.0), and +confirms that you have the right to make the contribution. This +protects both the project and its users by ensuring every contribution +has a clear legal footing. + +When you open your first pull request, the +[CLA Assistant](https://cla-assistant.io/trustgraph-ai/trustgraph) bot +will post a comment asking you to review and sign the agreement — +it only takes a moment and you only need to do it once. From 4164ef1c471342cd5f3f8e2b65a1f15e73b8a022 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 26 Mar 2026 10:49:28 +0000 Subject: [PATCH 03/37] Add GATEWAY_SECRET support for MCP server to API gateway auth (#721) Pass bearer token from GATEWAY_SECRET environment variable as a URL query parameter on websocket connections to the API gateway. When unset or empty, no auth is applied (backwards compatible). --- trustgraph-mcp/trustgraph/mcp_server/mcp.py | 22 ++++++++++++------- .../trustgraph/mcp_server/tg_socket.py | 15 +++++++++++-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index e551ed5d..eadd841b 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -24,9 +24,10 @@ from . 
tg_socket import WebSocketManager class AppContext: sockets: dict[str, WebSocketManager] websocket_url: str + gateway_token: str @asynccontextmanager -async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]: """ Manage application lifecycle with type-safe context @@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8 sockets = {} try: - yield AppContext(sockets=sockets, websocket_url=websocket_url) + yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token) finally: # Cleanup on shutdown @@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user): lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url + gateway_token = lifespan_context.gateway_token if user in sockets: logging.info("Return existing socket manager") return sockets[user] logging.info(f"Opening socket to {websocket_url}...") - + # Create manager with empty pending requests - manager = WebSocketManager(websocket_url) + manager = WebSocketManager(websocket_url, token=gateway_token) # Start reader task with the proper manager await manager.start() @@ -193,13 +195,14 @@ class GetSystemPromptResponse: prompt: str class McpServer: - def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"): + def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""): self.host = host self.port = port self.websocket_url = websocket_url - + self.gateway_token = gateway_token + # Create a partial function to pass websocket_url to app_lifespan - lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url) + lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token) self.mcp = FastMCP( "TrustGraph", dependencies=["trustgraph-base"], @@ -2060,8 +2063,11 @@ def main(): # Setup logging before creating server setup_logging(vars(args)) + # Read gateway auth token from environment + gateway_token = os.environ.get("GATEWAY_SECRET", "") + # Create and run the MCP server - server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) + server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token) server.run() def run(): diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py index 44f1bf2e..d255ae14 100644 --- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py +++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from websockets.asyncio.client import connect +from urllib.parse import urlencode, urlparse, urlunparse, parse_qs import asyncio import logging import json @@ -9,12 +10,22 @@ import time class WebSocketManager: - def __init__(self, url): + def __init__(self, url, token=None): self.url = url + self.token = token self.socket = None + def _build_url(self): + if not self.token: + return self.url + parsed = urlparse(self.url) + params = parse_qs(parsed.query) + params["token"] = [self.token] + new_query = urlencode(params, doseq=True) + return urlunparse(parsed._replace(query=new_query)) + async 
def start(self): - self.socket = await connect(self.url) + self.socket = await connect(self._build_url()) self.pending_requests = {} self.running = True self.reader_task = asyncio.create_task(self.reader()) From f02bbdb442be5adb7e171ea37b181fb714703520 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Thu, 26 Mar 2026 14:08:41 +0000 Subject: [PATCH 04/37] New CLA workflow: Uses a github action in trustgraph-ai/contributor-license-agreement This blocks a PR until the commiter responds with a message of agreement with the CLA terms. --- .github/workflows/cla.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 .github/workflows/cla.yml diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml new file mode 100644 index 00000000..73f582cf --- /dev/null +++ b/.github/workflows/cla.yml @@ -0,0 +1,25 @@ +name: CLA Assistant + +on: + issue_comment: + types: [created] + pull_request_target: + types: [opened, synchronize, reopened] + +permissions: + actions: write + contents: write + pull-requests: write + statuses: write + +jobs: + CLAssistant: + runs-on: ubuntu-latest + steps: + - name: CLA Assistant + uses: trustgraph-ai/contributor-license-agreement/action@main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_ASSISTANT_PAT }} + with: + allowlist: 'dependabot,dependabot[bot],github-actions,github-actions[bot]' From 1ec081f42fdef508f312493d7ebdd83b5a4efc7f Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 26 Mar 2026 14:18:13 +0000 Subject: [PATCH 05/37] Update CLA notice in repo (#722) --- docs/contributor-licence-agreement.md | 28 ++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/contributor-licence-agreement.md b/docs/contributor-licence-agreement.md index cc9f5929..48314516 100644 --- a/docs/contributor-licence-agreement.md +++ b/docs/contributor-licence-agreement.md @@ -1,17 +1,19 @@ - # Contributor Licence Agreement (CLA) -We ask every contributor to sign a lightweight Contributor Licence -Agreement before we can merge a pull request. The CLA does **not** -transfer copyright — you keep full ownership of your work. It simply -grants the TrustGraph project a perpetual, royalty-free licence to -distribute your contribution under the project's +We ask every contributor to sign a Contributor Licence Agreement before +we can merge a pull request. The CLA does **not** transfer copyright — +you keep full ownership of your work. It simply grants the TrustGraph +project a perpetual, royalty-free licence to distribute your +contribution under the project's [Apache 2.0 licence](https://www.apache.org/licenses/LICENSE-2.0), and -confirms that you have the right to make the contribution. This -protects both the project and its users by ensuring every contribution -has a clear legal footing. +confirms that you have the right to make the contribution. This protects +both the project and its users by ensuring every contribution has a +clear legal footing. + +When you open a pull request, the CLA bot will post a comment asking you +to review and sign the appropriate agreement — it only takes a moment +and you only need to do it once across all TrustGraph repositories. + +- Contributing as an **individual**? Sign the [Individual CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Fiduciary-Contributor-License-Agreement.md) +- Contributing on behalf of a **company or organisation**? 
Sign the [Entity CLA](https://github.com/trustgraph-ai/contributor-license-agreement/blob/main/Entity-Fiduciary-Contributor-License-Agreement.md) -When you open your first pull request, the -[CLA Assistant](https://cla-assistant.io/trustgraph-ai/trustgraph) bot -will post a comment asking you to review and sign the agreement — -it only takes a moment and you only need to do it once. From 9c55a0a0ff12bdfb1d3ef41065190354830c102b Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 26 Mar 2026 16:46:28 +0000 Subject: [PATCH 06/37] Persistent websocket connections for socket clients and CLI tools (#723) Replace per-request websocket connections in SocketClient and AsyncSocketClient with a single persistent connection that multiplexes requests by ID via a background reader task. This eliminates repeated TCP+WS handshakes which caused significant latency over proxies. Convert show_flows, show_flow_blueprints, and show_parameter_types CLI tools from sequential HTTP requests to concurrent websocket requests using AsyncSocketClient, reducing round trips from O(N) sequential to a small number of parallel batches. Also fix describe_interfaces bug in show_flows where response queue was reading the request field instead of the response field. --- .../trustgraph/api/async_socket_client.py | 207 ++-- .../trustgraph/api/socket_client.py | 946 ++++-------------- .../trustgraph/cli/list_explain_traces.py | 22 +- .../trustgraph/cli/show_flow_blueprints.py | 111 +- trustgraph-cli/trustgraph/cli/show_flows.py | 207 ++-- .../trustgraph/cli/show_parameter_types.py | 228 +++-- 6 files changed, 654 insertions(+), 1067 deletions(-) diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 3279609b..7a239b07 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -1,5 +1,6 @@ import json +import asyncio import websockets from typing import Optional, Dict, Any, AsyncIterator, Union @@ -8,13 +9,29 @@ from . exceptions import ProtocolException, ApplicationException class AsyncSocketClient: - """Asynchronous WebSocket client""" + """Asynchronous WebSocket client with persistent connection. + + Maintains a single websocket connection and multiplexes requests + by ID, routing responses via a background reader task. + + Use as an async context manager for proper lifecycle management: + + async with AsyncSocketClient(url, timeout, token) as client: + result = await client._send_request(...) + + Or call connect()/aclose() manually. 
+ """ def __init__(self, url: str, timeout: int, token: Optional[str]): self.url = self._convert_to_ws_url(url) self.timeout = timeout self.token = token self._request_counter = 0 + self._socket = None + self._connect_cm = None + self._reader_task = None + self._pending = {} # request_id -> asyncio.Queue + self._connected = False def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -25,82 +42,123 @@ class AsyncSocketClient: elif url.startswith("ws://") or url.startswith("wss://"): return url else: - # Assume ws:// return f"ws://{url}" + def _build_ws_url(self): + ws_url = f"{self.url.rstrip('/')}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + return ws_url + + async def connect(self): + """Establish the persistent websocket connection.""" + if self._connected: + return + + ws_url = self._build_ws_url() + self._connect_cm = websockets.connect( + ws_url, ping_interval=20, ping_timeout=self.timeout + ) + self._socket = await self._connect_cm.__aenter__() + self._connected = True + self._reader_task = asyncio.create_task(self._reader()) + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + + async def _ensure_connected(self): + """Lazily connect if not already connected.""" + if not self._connected: + await self.connect() + + async def _reader(self): + """Background task to read responses and route by request ID.""" + try: + async for raw_message in self._socket: + response = json.loads(raw_message) + request_id = response.get("id") + + if request_id and request_id in self._pending: + await self._pending[request_id].put(response) + # Ignore messages for unknown request IDs + + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + # Signal error to all pending requests + for queue in self._pending.values(): + try: + await queue.put({"error": str(e)}) + except: + pass + finally: + self._connected = False + + def _next_request_id(self): + self._request_counter += 1 + return f"req-{self._request_counter}" + def flow(self, flow_id: str): """Get async flow instance for WebSocket operations""" return AsyncSocketFlowInstance(self, flow_id) async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]): - """Async WebSocket request implementation (non-streaming)""" - # Generate unique request ID - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Send a request and wait for a single response.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Wait for single response - raw_message = await websocket.recv() - response = json.loads(raw_message) - - if response.get("id") != request_id: - raise ProtocolException(f"Response ID mismatch") + response = await 
queue.get() if "error" in response: raise ApplicationException(response["error"]) if "response" not in response: - raise ProtocolException(f"Missing response in message") + raise ProtocolException("Missing response in message") return response["response"] + finally: + self._pending.pop(request_id, None) + async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]): - """Async WebSocket request implementation (streaming)""" - # Generate unique request ID - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Send a request and yield streaming response chunks.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Yield chunks as they arrive - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != request_id: - continue # Ignore messages for other requests + while True: + response = await queue.get() if "error" in response: raise ApplicationException(response["error"]) @@ -108,18 +166,16 @@ class AsyncSocketClient: if "response" in response: resp = response["response"] - # Parse different chunk types chunk = self._parse_chunk(resp) - if chunk is not None: # Skip provenance messages in streaming + if chunk is not None: yield chunk - # Check if this is the final message - # end_of_session indicates entire session is complete (including provenance) - # end_of_dialog is for agent dialogs - # complete is from the gateway envelope if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break + finally: + self._pending.pop(request_id, None) + def _parse_chunk(self, resp: Dict[str, Any]): """Parse response chunk into appropriate type. 
Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") @@ -127,7 +183,6 @@ class AsyncSocketClient: # Handle new GraphRAG message format with message_type if message_type == "provenance": - # Provenance messages are not yielded to user - they're metadata return None if chunk_type == "thought": @@ -147,25 +202,41 @@ class AsyncSocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) elif chunk_type == "action": - # Agent action chunks - treat as thoughts for display purposes return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) else: - # RAG-style chunk (or generic chunk with message_type="chunk") - # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None # Errors are always thrown, never stored + error=None ) async def aclose(self): - """Close WebSocket connection""" - # Cleanup handled by context manager - pass + """Close the persistent WebSocket connection cleanly.""" + # Wait for reader to finish (socket close will cause it to exit) + if self._reader_task: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + # Exit the websockets context manager — this cleanly shuts down + # the connection and its keepalive task + if self._connect_cm: + try: + await self._connect_cm.__aexit__(None, None, None) + except Exception: + pass + self._connect_cm = None + + self._socket = None + self._connected = False + self._pending.clear() class AsyncSocketFlowInstance: @@ -292,7 +363,6 @@ class AsyncSocketFlowInstance: async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" - # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] @@ -362,7 +432,6 @@ class AsyncSocketFlowInstance: limit: int = 10, **kwargs ): """Query row embeddings for semantic search on structured data""" - # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 91db8b69..40769fa0 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -3,6 +3,9 @@ TrustGraph Synchronous WebSocket Client This module provides synchronous WebSocket-based access to TrustGraph services with streaming support for real-time responses from agents, RAG queries, and text completions. + +Uses a persistent WebSocket connection with a background reader task that +multiplexes requests by ID. """ import json @@ -74,43 +77,27 @@ def build_term(value: Any, term_type: Optional[str] = None, class SocketClient: """ - Synchronous WebSocket client for streaming operations. + Synchronous WebSocket client with persistent connection. - Provides a synchronous interface to WebSocket-based TrustGraph services, - wrapping async websockets library with synchronous generators for ease of use. - Supports streaming responses from agents, RAG queries, and text completions. - - Note: This is a synchronous wrapper around async WebSocket operations. For - true async support, use AsyncSocketClient instead. 
+ Maintains a single websocket connection and multiplexes requests + by ID via a background reader task. Provides synchronous generators + for streaming responses. """ def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: - """ - Initialize synchronous WebSocket client. - - Args: - url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS) - timeout: WebSocket timeout in seconds - token: Optional bearer token for authentication - """ self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token - self._connection: Optional[Any] = None self._request_counter: int = 0 self._lock: Lock = Lock() self._loop: Optional[asyncio.AbstractEventLoop] = None + self._socket = None + self._connect_cm = None + self._reader_task = None + self._pending: Dict[str, asyncio.Queue] = {} + self._connected: bool = False def _convert_to_ws_url(self, url: str) -> str: - """ - Convert HTTP URL to WebSocket URL. - - Args: - url: HTTP/HTTPS or WS/WSS URL - - Returns: - str: WebSocket URL (ws:// or wss://) - """ if url.startswith("http://"): return url.replace("http://", "ws://", 1) elif url.startswith("https://"): @@ -118,29 +105,68 @@ class SocketClient: elif url.startswith("ws://") or url.startswith("wss://"): return url else: - # Assume ws:// return f"ws://{url}" + def _build_ws_url(self): + ws_url = f"{self.url.rstrip('/')}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + return ws_url + + def _get_loop(self): + """Get or create the event loop, reusing across calls.""" + if self._loop is None or self._loop.is_closed(): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._loop = loop + return self._loop + + async def _ensure_connected(self): + """Lazily establish the persistent websocket connection.""" + if self._connected: + return + + ws_url = self._build_ws_url() + self._connect_cm = websockets.connect( + ws_url, ping_interval=20, ping_timeout=self.timeout + ) + self._socket = await self._connect_cm.__aenter__() + self._connected = True + self._reader_task = asyncio.create_task(self._reader()) + + async def _reader(self): + """Background task to read responses and route by request ID.""" + try: + async for raw_message in self._socket: + response = json.loads(raw_message) + request_id = response.get("id") + + if request_id and request_id in self._pending: + await self._pending[request_id].put(response) + + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + for queue in self._pending.values(): + try: + await queue.put({"error": str(e)}) + except: + pass + finally: + self._connected = False + + def _next_request_id(self): + with self._lock: + self._request_counter += 1 + return f"req-{self._request_counter}" + def flow(self, flow_id: str) -> "SocketFlowInstance": - """ - Get a flow instance for WebSocket streaming operations. 
- - Args: - flow_id: Flow identifier - - Returns: - SocketFlowInstance: Flow instance with streaming methods - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Stream agent responses - for chunk in flow.agent(question="Hello", user="trustgraph", streaming=True): - print(chunk.content, end='', flush=True) - ``` - """ return SocketFlowInstance(self, flow_id) def _send_request_sync( @@ -152,34 +178,14 @@ class SocketClient: streaming_raw: bool = False, include_provenance: bool = False ) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]: - """Synchronous wrapper around async WebSocket communication. - - Args: - service: Service name - flow: Flow ID (optional) - request: Request payload - streaming: Use parsed streaming (for agent/RAG chunk types) - streaming_raw: Use raw streaming (for data batches like triples) - """ - # Create event loop if needed - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # If loop is running (e.g., in Jupyter), create new loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + """Synchronous wrapper around async WebSocket communication.""" + loop = self._get_loop() if streaming_raw: - # Raw streaming for data batches (triples, rows, etc.) return self._streaming_generator_raw(service, flow, request, loop) elif streaming: - # Parsed streaming for agent/RAG chunk types return self._streaming_generator(service, flow, request, loop, include_provenance) else: - # Non-streaming single response return loop.run_until_complete(self._send_request_async(service, flow, request)) def _streaming_generator( @@ -190,7 +196,7 @@ class SocketClient: loop: asyncio.AbstractEventLoop, include_provenance: bool = False ) -> Iterator[StreamingChunk]: - """Generator that yields streaming chunks (for agent/RAG responses)""" + """Generator that yields streaming chunks.""" async_gen = self._send_request_async_streaming(service, flow, request, include_provenance) try: @@ -201,7 +207,6 @@ class SocketClient: except StopAsyncIteration: break finally: - # Clean up async generator try: loop.run_until_complete(async_gen.aclose()) except: @@ -214,7 +219,7 @@ class SocketClient: request: Dict[str, Any], loop: asyncio.AbstractEventLoop ) -> Iterator[Dict[str, Any]]: - """Generator that yields raw response dicts (for data streaming like triples)""" + """Generator that yields raw response dicts.""" async_gen = self._send_request_async_streaming_raw(service, flow, request) try: @@ -236,35 +241,26 @@ class SocketClient: flow: Optional[str], request: Dict[str, Any] ) -> Iterator[Dict[str, Any]]: - """Async streaming that yields raw response dicts without parsing. + """Async streaming that yields raw response dicts.""" + await self._ensure_connected() - Used for data streaming (triples, rows, etc.) where responses are - just batches of data, not agent/RAG chunk types. 
- """ - with self._lock: - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + await self._socket.send(json.dumps(message)) - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != request_id: - continue + while True: + response = await queue.get() if "error" in response: raise_from_error_dict(response["error"]) @@ -275,51 +271,46 @@ class SocketClient: if response.get("complete"): break + finally: + self._pending.pop(request_id, None) + async def _send_request_async( self, service: str, flow: Optional[str], request: Dict[str, Any] ) -> Dict[str, Any]: - """Async implementation of WebSocket request (non-streaming)""" - # Generate unique request ID - with self._lock: - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Async non-streaming request over persistent connection.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - "request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Wait for single response - raw_message = await websocket.recv() - response = json.loads(raw_message) - - if response.get("id") != request_id: - raise ProtocolException(f"Response ID mismatch") + response = await queue.get() if "error" in response: raise_from_error_dict(response["error"]) if "response" not in response: - raise ProtocolException(f"Missing response in message") + raise ProtocolException("Missing response in message") return response["response"] + finally: + self._pending.pop(request_id, None) + async def _send_request_async_streaming( self, service: str, @@ -327,36 +318,26 @@ class SocketClient: request: Dict[str, Any], include_provenance: bool = False ) -> Iterator[StreamingChunk]: - """Async implementation of WebSocket request (streaming)""" - # Generate unique request ID - with self._lock: - self._request_counter += 1 - request_id = f"req-{self._request_counter}" + """Async streaming request over persistent connection.""" + await self._ensure_connected() - # Build WebSocket URL with optional token - ws_url = f"{self.url}/api/v1/socket" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + request_id = self._next_request_id() + queue = asyncio.Queue() + self._pending[request_id] = queue - # Build request message - message = { - "id": request_id, - "service": service, - 
"request": request - } - if flow: - message["flow"] = flow + try: + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow - # Connect and send request - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - await websocket.send(json.dumps(message)) + await self._socket.send(json.dumps(message)) - # Yield chunks as they arrive - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != request_id: - continue # Ignore messages for other requests + while True: + response = await queue.get() if "error" in response: raise_from_error_dict(response["error"]) @@ -364,22 +345,19 @@ class SocketClient: if "response" in response: resp = response["response"] - # Check for errors in response chunks if "error" in resp: raise_from_error_dict(resp["error"]) - # Parse different chunk types chunk = self._parse_chunk(resp, include_provenance=include_provenance) - if chunk is not None: # Skip provenance messages unless include_provenance + if chunk is not None: yield chunk - # Check if this is the final message - # end_of_session indicates entire session is complete (including provenance) - # end_of_dialog is for agent dialogs - # complete is from the gateway envelope if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"): break + finally: + self._pending.pop(request_id, None) + def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: """Parse response chunk into appropriate type. Returns None for non-content messages.""" chunk_type = resp.get("chunk_type") @@ -388,12 +366,10 @@ class SocketClient: # Handle GraphRAG/DocRAG message format with message_type if message_type == "explain": if include_provenance: - # Return provenance event for explainability return ProvenanceEvent( explain_id=resp.get("explain_id", ""), explain_graph=resp.get("explain_graph", "") ) - # Provenance messages are not yielded to user - they're metadata return None # Handle Agent message format with chunk_type="explain" @@ -422,7 +398,6 @@ class SocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) elif chunk_type == "action": - # Agent action chunks - treat as thoughts for display purposes return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) @@ -445,23 +420,43 @@ class SocketClient: end_of_dialog=resp.get("end_of_dialog", False) ) else: - # RAG-style chunk (or generic chunk with message_type="chunk") - # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None # Errors are always thrown, never stored + error=None ) def close(self) -> None: - """ - Close WebSocket connections. + """Close the persistent WebSocket connection.""" + if self._loop and not self._loop.is_closed(): + try: + self._loop.run_until_complete(self._close_async()) + except: + pass - Note: Cleanup is handled automatically by context managers in async code. 
- """ - # Cleanup handled by context manager in async code - pass + async def _close_async(self): + # Cancel reader task + if self._reader_task: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + # Exit the websockets context manager — this cleanly shuts down + # the connection and its keepalive task + if self._connect_cm: + try: + await self._connect_cm.__aexit__(None, None, None) + except Exception: + pass + self._connect_cm = None + + self._socket = None + self._connected = False + self._pending.clear() class SocketFlowInstance: @@ -469,18 +464,10 @@ class SocketFlowInstance: Synchronous WebSocket flow instance for streaming operations. Provides the same interface as REST FlowInstance but with WebSocket-based - streaming support for real-time responses. All methods support an optional - `streaming` parameter to enable incremental result delivery. + streaming support for real-time responses. """ def __init__(self, client: SocketClient, flow_id: str) -> None: - """ - Initialize socket flow instance. - - Args: - client: Parent SocketClient - flow_id: Flow identifier - """ self.client: SocketClient = client self.flow_id: str = flow_id @@ -494,44 +481,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]: - """ - Execute an agent operation with streaming support. - - Agents can perform multi-step reasoning with tool use. This method always - returns streaming chunks (thoughts, observations, answers) even when - streaming=False, to show the agent's reasoning process. - - Args: - question: User question or instruction - user: User identifier - state: Optional state dictionary for stateful conversations - group: Optional group identifier for multi-user contexts - history: Optional conversation history as list of message dicts - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the agent service - - Returns: - Iterator[StreamingChunk]: Stream of agent thoughts, observations, and answers - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Stream agent reasoning - for chunk in flow.agent( - question="What is quantum computing?", - user="trustgraph", - streaming=True - ): - if isinstance(chunk, AgentThought): - print(f"[Thinking] {chunk.content}") - elif isinstance(chunk, AgentObservation): - print(f"[Observation] {chunk.content}") - elif isinstance(chunk, AgentAnswer): - print(f"[Answer] {chunk.content}") - ``` - """ + """Execute an agent operation with streaming support.""" request = { "question": question, "user": user, @@ -545,8 +495,6 @@ class SocketFlowInstance: request["history"] = history request.update(kwargs) - # Agents always use multipart messaging (multiple complete messages) - # regardless of streaming flag, so always use the streaming code path return self.client._send_request_sync("agent", self.flow_id, request, streaming=True) def agent_explain( @@ -559,70 +507,12 @@ class SocketFlowInstance: history: Optional[List[Dict[str, Any]]] = None, **kwargs: Any ) -> Iterator[Union[StreamingChunk, ProvenanceEvent]]: - """ - Execute an agent operation with explainability support. - - Streams both content chunks (AgentThought, AgentObservation, AgentAnswer) - and provenance events (ProvenanceEvent). Provenance events contain URIs - that can be fetched using ExplainabilityClient to get detailed information - about the agent's reasoning process. 
- - Agent trace consists of: - - Session: The initial question and session metadata - - Iterations: Each thought/action/observation cycle - - Conclusion: The final answer - - Args: - question: User question or instruction - user: User identifier - collection: Collection identifier for provenance storage - state: Optional state dictionary for stateful conversations - group: Optional group identifier for multi-user contexts - history: Optional conversation history as list of message dicts - **kwargs: Additional parameters passed to the agent service - - Yields: - Union[StreamingChunk, ProvenanceEvent]: Agent chunks and provenance events - - Example: - ```python - from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent - from trustgraph.api import AgentThought, AgentObservation, AgentAnswer - - socket = api.socket() - flow = socket.flow("default") - explain_client = ExplainabilityClient(flow) - - provenance_ids = [] - for item in flow.agent_explain( - question="What is the capital of France?", - user="trustgraph", - collection="default" - ): - if isinstance(item, AgentThought): - print(f"[Thought] {item.content}") - elif isinstance(item, AgentObservation): - print(f"[Observation] {item.content}") - elif isinstance(item, AgentAnswer): - print(f"[Answer] {item.content}") - elif isinstance(item, ProvenanceEvent): - provenance_ids.append(item.explain_id) - - # Fetch session trace after completion - if provenance_ids: - trace = explain_client.fetch_agent_trace( - provenance_ids[0], # Session URI is first - graph="urn:graph:retrieval", - user="trustgraph", - collection="default" - ) - ``` - """ + """Execute an agent operation with explainability support.""" request = { "question": question, "user": user, "collection": collection, - "streaming": True # Always streaming for explain + "streaming": True } if state is not None: request["state"] = state @@ -632,47 +522,13 @@ class SocketFlowInstance: request["history"] = history request.update(kwargs) - # Use streaming with provenance enabled return self.client._send_request_sync( "agent", self.flow_id, request, streaming=True, include_provenance=True ) def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: - """ - Execute text completion with optional streaming. 
- - Args: - system: System prompt defining the assistant's behavior - prompt: User prompt/question - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Non-streaming - response = flow.text_completion( - system="You are helpful", - prompt="Explain quantum computing", - streaming=False - ) - print(response) - - # Streaming - for chunk in flow.text_completion( - system="You are helpful", - prompt="Explain quantum computing", - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute text completion with optional streaming.""" request = { "system": system, "prompt": prompt, @@ -683,13 +539,11 @@ class SocketFlowInstance: result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming) if streaming: - # For text completion, return generator that yields content return self._text_completion_generator(result) else: return result.get("response", "") def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: - """Generator for text completion streaming""" for chunk in result: if hasattr(chunk, 'content'): yield chunk.content @@ -708,43 +562,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: - """ - Execute graph-based RAG query with optional streaming. - - Uses knowledge graph structure to find relevant context, then generates - a response using an LLM. Streaming mode delivers results incrementally. - - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - entity_limit: Maximum entities to retrieve (default: 50) - triple_limit: Maximum triples per entity (default: 30) - max_subgraph_size: Maximum total triples in subgraph (default: 1000) - max_path_length: Maximum traversal depth (default: 2) - edge_score_limit: Max edges for semantic pre-filter (default: 50) - edge_limit: Max edges after LLM scoring (default: 25) - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Streaming graph RAG - for chunk in flow.graph_rag( - query="Tell me about Marie Curie", - user="trustgraph", - collection="scientists", - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute graph-based RAG query with optional streaming.""" request = { "query": query, "user": user, @@ -779,61 +597,7 @@ class SocketFlowInstance: edge_limit: int = 25, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: - """ - Execute graph-based RAG query with explainability support. - - Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). - Provenance events contain URIs that can be fetched using ExplainabilityClient - to get detailed information about how the response was generated. 
- - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - entity_limit: Maximum entities to retrieve (default: 50) - triple_limit: Maximum triples per entity (default: 30) - max_subgraph_size: Maximum total triples in subgraph (default: 1000) - max_path_length: Maximum traversal depth (default: 2) - edge_score_limit: Max edges for semantic pre-filter (default: 50) - edge_limit: Max edges after LLM scoring (default: 25) - **kwargs: Additional parameters passed to the service - - Yields: - Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events - - Example: - ```python - from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent - - socket = api.socket() - flow = socket.flow("default") - explain_client = ExplainabilityClient(flow) - - provenance_ids = [] - response_text = "" - - for item in flow.graph_rag_explain( - query="Tell me about Marie Curie", - user="trustgraph", - collection="scientists" - ): - if isinstance(item, RAGChunk): - response_text += item.content - print(item.content, end='', flush=True) - elif isinstance(item, ProvenanceEvent): - provenance_ids.append(item.provenance_id) - - # Fetch explainability details - for prov_id in provenance_ids: - entity = explain_client.fetch_entity( - prov_id, - graph="urn:graph:retrieval", - user="trustgraph", - collection="scientists" - ) - print(f"Entity: {entity}") - ``` - """ + """Execute graph-based RAG query with explainability support.""" request = { "query": query, "user": user, @@ -849,7 +613,6 @@ class SocketFlowInstance: } request.update(kwargs) - # Use streaming with provenance events included return self.client._send_request_sync( "graph-rag", self.flow_id, request, streaming=True, include_provenance=True @@ -864,39 +627,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: - """ - Execute document-based RAG query with optional streaming. - - Uses vector embeddings to find relevant document chunks, then generates - a response using an LLM. Streaming mode delivers results incrementally. - - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - doc_limit: Maximum document chunks to retrieve (default: 10) - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Streaming document RAG - for chunk in flow.document_rag( - query="Summarize the key findings", - user="trustgraph", - collection="research-papers", - doc_limit=5, - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute document-based RAG query with optional streaming.""" request = { "query": query, "user": user, @@ -921,55 +652,7 @@ class SocketFlowInstance: doc_limit: int = 10, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: - """ - Execute document-based RAG query with explainability support. - - Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent). - Provenance events contain URIs that can be fetched using ExplainabilityClient - to get detailed information about how the response was generated. 
- - Document RAG trace consists of: - - Question: The user's query - - Exploration: Chunks retrieved from document store (chunk_count) - - Synthesis: The generated answer - - Args: - query: Natural language query - user: User/keyspace identifier - collection: Collection identifier - doc_limit: Maximum document chunks to retrieve (default: 10) - **kwargs: Additional parameters passed to the service - - Yields: - Union[RAGChunk, ProvenanceEvent]: Content chunks and provenance events - - Example: - ```python - from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent - - socket = api.socket() - flow = socket.flow("default") - explain_client = ExplainabilityClient(flow) - - for item in flow.document_rag_explain( - query="Summarize the key findings", - user="trustgraph", - collection="research-papers", - doc_limit=5 - ): - if isinstance(item, RAGChunk): - print(item.content, end='', flush=True) - elif isinstance(item, ProvenanceEvent): - # Fetch entity details - entity = explain_client.fetch_entity( - item.explain_id, - graph=item.explain_graph, - user="trustgraph", - collection="research-papers" - ) - print(f"Event: {entity}", file=sys.stderr) - ``` - """ + """Execute document-based RAG query with explainability support.""" request = { "query": query, "user": user, @@ -980,14 +663,12 @@ class SocketFlowInstance: } request.update(kwargs) - # Use streaming with provenance events included return self.client._send_request_sync( "document-rag", self.flow_id, request, streaming=True, include_provenance=True ) def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: - """Generator for RAG streaming (graph-rag and document-rag)""" for chunk in result: if hasattr(chunk, 'content'): yield chunk.content @@ -999,32 +680,7 @@ class SocketFlowInstance: streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: - """ - Execute a prompt template with optional streaming. - - Args: - id: Prompt template identifier - variables: Dictionary of variable name to value mappings - streaming: Enable streaming mode (default: False) - **kwargs: Additional parameters passed to the service - - Returns: - Union[str, Iterator[str]]: Complete response or stream of text chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Streaming prompt execution - for chunk in flow.prompt( - id="summarize-template", - variables={"topic": "quantum computing", "length": "brief"}, - streaming=True - ): - print(chunk, end='', flush=True) - ``` - """ + """Execute a prompt template with optional streaming.""" request = { "id": id, "variables": variables, @@ -1047,33 +703,7 @@ class SocketFlowInstance: limit: int = 10, **kwargs: Any ) -> Dict[str, Any]: - """ - Query knowledge graph entities using semantic similarity. 
- - Args: - text: Query text for semantic search - user: User/keyspace identifier - collection: Collection identifier - limit: Maximum number of results (default: 10) - **kwargs: Additional parameters passed to the service - - Returns: - dict: Query results with similar entities - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - results = flow.graph_embeddings_query( - text="physicist who discovered radioactivity", - user="trustgraph", - collection="scientists", - limit=5 - ) - ``` - """ - # First convert text to embedding vector + """Query knowledge graph entities using semantic similarity.""" emb_result = self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] @@ -1095,34 +725,7 @@ class SocketFlowInstance: limit: int = 10, **kwargs: Any ) -> Dict[str, Any]: - """ - Query document chunks using semantic similarity. - - Args: - text: Query text for semantic search - user: User/keyspace identifier - collection: Collection identifier - limit: Maximum number of results (default: 10) - **kwargs: Additional parameters passed to the service - - Returns: - dict: Query results with chunk_ids of matching document chunks - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - results = flow.document_embeddings_query( - text="machine learning algorithms", - user="trustgraph", - collection="research-papers", - limit=5 - ) - # results contains {"chunks": [{"chunk_id": "...", "score": 0.95}, ...]} - ``` - """ - # First convert text to embedding vector + """Query document chunks using semantic similarity.""" emb_result = self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] @@ -1137,25 +740,7 @@ class SocketFlowInstance: return self.client._send_request_sync("document-embeddings", self.flow_id, request, False) def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]: - """ - Generate vector embeddings for one or more texts. - - Args: - texts: List of input texts to embed - **kwargs: Additional parameters passed to the service - - Returns: - dict: Response containing vectors (one set per input text) - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - result = flow.embeddings(["quantum computing"]) - vectors = result.get("vectors", []) - ``` - """ + """Generate vector embeddings for one or more texts.""" request = {"texts": texts} request.update(kwargs) @@ -1172,46 +757,9 @@ class SocketFlowInstance: limit: int = 100, **kwargs: Any ) -> List[Dict[str, Any]]: - """ - Query knowledge graph triples using pattern matching. 
- - Args: - s: Subject filter - URI string, Term dict, or None for wildcard - p: Predicate filter - URI string, Term dict, or None for wildcard - o: Object filter - URI/literal string, Term dict, or None for wildcard - g: Named graph filter - URI string or None for all graphs - user: User/keyspace identifier (optional) - collection: Collection identifier (optional) - limit: Maximum results to return (default: 100) - **kwargs: Additional parameters passed to the service - - Returns: - List[Dict]: List of matching triples in wire format - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Find all triples about a specific subject - triples = flow.triples_query( - s="http://example.org/person/marie-curie", - user="trustgraph", - collection="scientists" - ) - - # Query with named graph filter - triples = flow.triples_query( - s="urn:trustgraph:session:abc123", - g="urn:graph:retrieval", - user="trustgraph", - collection="default" - ) - ``` - """ + """Query knowledge graph triples using pattern matching.""" request = {"limit": limit} - # Build Term dicts for s/p/o (auto-converts strings) s_term = build_term(s) p_term = build_term(p) o_term = build_term(o) @@ -1231,7 +779,6 @@ class SocketFlowInstance: request.update(kwargs) result = self.client._send_request_sync("triples", self.flow_id, request, False) - # Return the triples list from the response if isinstance(result, dict) and "response" in result: return result["response"] return result @@ -1248,46 +795,13 @@ class SocketFlowInstance: batch_size: int = 20, **kwargs: Any ) -> Iterator[List[Dict[str, Any]]]: - """ - Query knowledge graph triples with streaming batches. - - Yields batches of triples as they arrive, reducing time-to-first-result - and memory overhead for large result sets. - - Args: - s: Subject filter - URI string, Term dict, or None for wildcard - p: Predicate filter - URI string, Term dict, or None for wildcard - o: Object filter - URI/literal string, Term dict, or None for wildcard - g: Named graph filter - URI string or None for all graphs - user: User/keyspace identifier (optional) - collection: Collection identifier (optional) - limit: Maximum results to return (default: 100) - batch_size: Triples per batch (default: 20) - **kwargs: Additional parameters passed to the service - - Yields: - List[Dict]: Batches of triples in wire format - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - for batch in flow.triples_query_stream( - user="trustgraph", - collection="default" - ): - for triple in batch: - print(triple["s"], triple["p"], triple["o"]) - ``` - """ + """Query knowledge graph triples with streaming batches.""" request = { "limit": limit, "streaming": True, "batch-size": batch_size, } - # Build Term dicts for s/p/o (auto-converts strings) s_term = build_term(s) p_term = build_term(p) o_term = build_term(o) @@ -1306,9 +820,7 @@ class SocketFlowInstance: request["collection"] = collection request.update(kwargs) - # Use raw streaming - yields response dicts directly without parsing for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True): - # Response is {"response": [...triples...]} from translator if isinstance(response, dict) and "response" in response: yield response["response"] else: @@ -1323,41 +835,7 @@ class SocketFlowInstance: operation_name: Optional[str] = None, **kwargs: Any ) -> Dict[str, Any]: - """ - Execute a GraphQL query against structured rows. 
- - Args: - query: GraphQL query string - user: User/keyspace identifier - collection: Collection identifier - variables: Optional query variables dictionary - operation_name: Optional operation name for multi-operation documents - **kwargs: Additional parameters passed to the service - - Returns: - dict: GraphQL response with data, errors, and/or extensions - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - query = ''' - { - scientists(limit: 10) { - name - field - discoveries - } - } - ''' - result = flow.rows_query( - query=query, - user="trustgraph", - collection="scientists" - ) - ``` - """ + """Execute a GraphQL query against structured rows.""" request = { "query": query, "user": user, @@ -1377,28 +855,7 @@ class SocketFlowInstance: parameters: Dict[str, Any], **kwargs: Any ) -> Dict[str, Any]: - """ - Execute a Model Context Protocol (MCP) tool. - - Args: - name: Tool name/identifier - parameters: Tool parameters dictionary - **kwargs: Additional parameters passed to the service - - Returns: - dict: Tool execution result - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - result = flow.mcp_tool( - name="search-web", - parameters={"query": "latest AI news", "limit": 5} - ) - ``` - """ + """Execute a Model Context Protocol (MCP) tool.""" request = { "name": name, "parameters": parameters @@ -1417,50 +874,7 @@ class SocketFlowInstance: limit: int = 10, **kwargs: Any ) -> Dict[str, Any]: - """ - Query row data using semantic similarity on indexed fields. - - Finds rows whose indexed field values are semantically similar to the - input text, using vector embeddings. This enables fuzzy/semantic matching - on structured data. - - Args: - text: Query text for semantic search - schema_name: Schema name to search within - user: User/keyspace identifier (default: "trustgraph") - collection: Collection identifier (default: "default") - index_name: Optional index name to filter search to specific index - limit: Maximum number of results (default: 10) - **kwargs: Additional parameters passed to the service - - Returns: - dict: Query results with matches containing index_name, index_value, - text, and score - - Example: - ```python - socket = api.socket() - flow = socket.flow("default") - - # Search for customers by name similarity - results = flow.row_embeddings_query( - text="John Smith", - schema_name="customers", - user="trustgraph", - collection="sales", - limit=5 - ) - - # Filter to specific index - results = flow.row_embeddings_query( - text="machine learning engineer", - schema_name="employees", - index_name="job_title", - limit=10 - ) - ``` - """ - # First convert text to embedding vector + """Query row data using semantic similarity on indexed fields.""" emb_result = self.embeddings(texts=[text]) vector = emb_result.get("vectors", [[]])[0] diff --git a/trustgraph-cli/trustgraph/cli/list_explain_traces.py b/trustgraph-cli/trustgraph/cli/list_explain_traces.py index f545c53f..e6d1e075 100644 --- a/trustgraph-cli/trustgraph/cli/list_explain_traces.py +++ b/trustgraph-cli/trustgraph/cli/list_explain_traces.py @@ -58,6 +58,14 @@ def print_json(sessions): print(json.dumps(sessions, indent=2)) +# Map type names for display +TYPE_DISPLAY = { + "graphrag": "GraphRAG", + "docrag": "DocRAG", + "agent": "Agent", +} + + def main(): parser = argparse.ArgumentParser( prog='tg-list-explain-traces', @@ -118,7 +126,7 @@ def main(): explain_client = ExplainabilityClient(flow) try: - # List all sessions using the API + # List all sessions — 
uses persistent websocket via SocketClient questions = explain_client.list_sessions( graph=RETRIEVAL_GRAPH, user=args.user, @@ -126,7 +134,8 @@ def main(): limit=args.limit, ) - # Convert to output format + # detect_session_type is mostly a fast URI pattern check, + # only falls back to network calls for unrecognised URIs sessions = [] for q in questions: session_type = explain_client.detect_session_type( @@ -136,16 +145,9 @@ def main(): collection=args.collection ) - # Map type names - type_display = { - "graphrag": "GraphRAG", - "docrag": "DocRAG", - "agent": "Agent", - }.get(session_type, session_type.title()) - sessions.append({ "id": q.uri, - "type": type_display, + "type": TYPE_DISPLAY.get(session_type, session_type.title()), "question": q.query, "time": q.timestamp, }) diff --git a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py index ca1f5c83..8d16d098 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py @@ -3,31 +3,27 @@ Shows all defined flow blueprints. """ import argparse +import asyncio import os import tabulate -from trustgraph.api import Api, ConfigKey +from trustgraph.api import AsyncSocketClient import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def format_parameters(params_metadata, config_api): +def format_parameters(params_metadata, param_type_defs): """ - Format parameter metadata for display + Format parameter metadata for display. - Args: - params_metadata: Parameter definitions from flow blueprint - config_api: API client to get parameter type information - - Returns: - Formatted string describing parameters + param_type_defs is a dict of type_name -> parsed type definition, + pre-fetched concurrently. 
""" if not params_metadata: return "None" param_list = [] - # Sort parameters by order if available sorted_params = sorted( params_metadata.items(), key=lambda x: x[1].get("order", 999) @@ -37,41 +33,89 @@ def format_parameters(params_metadata, config_api): description = param_meta.get("description", param_name) param_type = param_meta.get("type", "unknown") - # Get type information if available type_info = param_type - if config_api: - try: - key = ConfigKey("parameter-type", param_type) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - - # Add default value if available - default = param_type_def.get("default") - if default is not None: - type_info = f"{param_type} (default: {default})" - - except: - # If we can't get type definition, just show the type name - pass + if param_type in param_type_defs: + param_type_def = param_type_defs[param_type] + default = param_type_def.get("default") + if default is not None: + type_info = f"{param_type} (default: {default})" param_list.append(f" {param_name}: {description} [{type_info}]") return "\n".join(param_list) +async def fetch_data(client): + """Fetch all data needed for show_flow_blueprints concurrently.""" + + # Round 1: list blueprints + resp = await client._send_request("flow", None, { + "operation": "list-blueprints", + }) + blueprint_names = resp.get("blueprint-names", []) + + if not blueprint_names: + return [], {}, {} + + # Round 2: get all blueprints in parallel + blueprint_tasks = [ + client._send_request("flow", None, { + "operation": "get-blueprint", + "blueprint-name": name, + }) + for name in blueprint_names + ] + blueprint_results = await asyncio.gather(*blueprint_tasks) + + blueprints = {} + for name, resp in zip(blueprint_names, blueprint_results): + bp_data = resp.get("blueprint-definition", "{}") + blueprints[name] = json.loads(bp_data) if isinstance(bp_data, str) else bp_data + + # Round 3: get all parameter type definitions in parallel + param_types_needed = set() + for bp in blueprints.values(): + for param_meta in bp.get("parameters", {}).values(): + pt = param_meta.get("type", "") + if pt: + param_types_needed.add(pt) + + param_type_defs = {} + if param_types_needed: + param_type_tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": pt}], + }) + for pt in param_types_needed + ] + param_type_results = await asyncio.gather(*param_type_tasks) + + for pt, resp in zip(param_types_needed, param_type_results): + values = resp.get("values", []) + if values: + try: + param_type_defs[pt] = json.loads(values[0].get("value", "{}")) + except (json.JSONDecodeError, AttributeError): + pass + + return blueprint_names, blueprints, param_type_defs + +async def _show_flow_blueprints_async(url, token=None): + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_data(client) + def show_flow_blueprints(url, token=None): - api = Api(url, token=token) - flow_api = api.flow() - config_api = api.config() + blueprint_names, blueprints, param_type_defs = asyncio.run( + _show_flow_blueprints_async(url, token=token) + ) - blueprint_names = flow_api.list_blueprints() - - if len(blueprint_names) == 0: + if not blueprint_names: print("No flow blueprints.") return for blueprint_name in blueprint_names: - cls = flow_api.get_blueprint(blueprint_name) + cls = blueprints[blueprint_name] table = [] table.append(("name", blueprint_name)) @@ -81,10 +125,9 @@ def show_flow_blueprints(url, token=None): if tags: 
table.append(("tags", ", ".join(tags))) - # Show parameters if they exist parameters = cls.get("parameters", {}) if parameters: - param_str = format_parameters(parameters, config_api) + param_str = format_parameters(parameters, param_type_defs) table.append(("parameters", param_str)) print(tabulate.tabulate( diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py index 828c18f1..d1abf984 100644 --- a/trustgraph-cli/trustgraph/cli/show_flows.py +++ b/trustgraph-cli/trustgraph/cli/show_flows.py @@ -3,22 +3,15 @@ Shows configured flows. """ import argparse +import asyncio import os import tabulate -from trustgraph.api import Api, ConfigKey +from trustgraph.api import Api, AsyncSocketClient import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def get_interface(config_api, i): - - key = ConfigKey("interface-description", i) - - value = config_api.get([key])[0].value - - return json.loads(value) - def describe_interfaces(intdefs, flow): intfs = flow.get("interfaces", {}) @@ -34,7 +27,7 @@ def describe_interfaces(intdefs, flow): if kind == "request-response": req = intfs[k]["request"] - resp = intfs[k]["request"] + resp = intfs[k]["response"] lst.append(f"{k} request: {req}") lst.append(f"{k} response: {resp}") @@ -49,17 +42,9 @@ def describe_interfaces(intdefs, flow): def get_enum_description(param_value, param_type_def): """ Get the human-readable description for an enum value - - Args: - param_value: The actual parameter value (e.g., "gpt-4") - param_type_def: The parameter type definition containing enum objects - - Returns: - Human-readable description or the original value if not found """ enum_list = param_type_def.get("enum", []) - # Handle both old format (strings) and new format (objects with id/description) for enum_item in enum_list: if isinstance(enum_item, dict): if enum_item.get("id") == param_value: @@ -67,27 +52,20 @@ def get_enum_description(param_value, param_type_def): elif enum_item == param_value: return param_value - # If not found in enum, return original value return param_value -def format_parameters(flow_params, blueprint_params_metadata, config_api): +def format_parameters(flow_params, blueprint_params_metadata, param_type_defs): """ - Format flow parameters with their human-readable descriptions + Format flow parameters with their human-readable descriptions. - Args: - flow_params: The actual parameter values used in the flow - blueprint_params_metadata: The parameter metadata from the flow blueprint definition - config_api: API client to retrieve parameter type definitions - - Returns: - Formatted string of parameters with descriptions + param_type_defs is a dict of type_name -> parsed type definition, + pre-fetched concurrently. 
""" if not flow_params: return "None" param_list = [] - # Sort parameters by order if available sorted_params = sorted( blueprint_params_metadata.items(), key=lambda x: x[1].get("order", 999) @@ -100,80 +78,165 @@ def format_parameters(flow_params, blueprint_params_metadata, config_api): param_type = param_meta.get("type", "") controlled_by = param_meta.get("controlled-by", None) - # Try to get enum description if this parameter has a type definition display_value = value - if param_type and config_api: - try: - from trustgraph.api import ConfigKey - key = ConfigKey("parameter-type", param_type) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - display_value = get_enum_description(value, param_type_def) - except: - # If we can't get the type definition, just use the original value - display_value = value + if param_type and param_type in param_type_defs: + display_value = get_enum_description( + value, param_type_defs[param_type] + ) - # Format the parameter line line = f"• {description}: {display_value}" - # Add controlled-by indicator if present if controlled_by: line += f" (controlled by {controlled_by})" param_list.append(line) - # Add any parameters that aren't in the blueprint metadata (shouldn't happen normally) for param_name, value in flow_params.items(): if param_name not in blueprint_params_metadata: param_list.append(f"• {param_name}: {value} (undefined)") return "\n".join(param_list) if param_list else "None" +async def fetch_show_flows(client): + """Fetch all data needed for show_flows concurrently.""" + + # Round 1: list interfaces and list flows in parallel + interface_names_resp, flow_ids_resp = await asyncio.gather( + client._send_request("config", None, { + "operation": "list", + "type": "interface-description", + }), + client._send_request("flow", None, { + "operation": "list-flows", + }), + ) + + interface_names = interface_names_resp.get("directory", []) + flow_ids = flow_ids_resp.get("flow-ids", []) + + if not flow_ids: + return {}, [], {}, {} + + # Round 2: get all interfaces + all flows in parallel + interface_tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "interface-description", "key": name}], + }) + for name in interface_names + ] + + flow_tasks = [ + client._send_request("flow", None, { + "operation": "get-flow", + "flow-id": fid, + }) + for fid in flow_ids + ] + + results = await asyncio.gather(*interface_tasks, *flow_tasks) + + # Split results + interface_results = results[:len(interface_names)] + flow_results = results[len(interface_names):] + + # Parse interfaces + interface_defs = {} + for name, resp in zip(interface_names, interface_results): + values = resp.get("values", []) + if values: + interface_defs[name] = json.loads(values[0].get("value", "{}")) + + # Parse flows + flows = {} + for fid, resp in zip(flow_ids, flow_results): + flow_data = resp.get("flow", "{}") + flows[fid] = json.loads(flow_data) if isinstance(flow_data, str) else flow_data + + # Round 3: get all blueprints in parallel + blueprint_names = set() + for flow in flows.values(): + bp = flow.get("blueprint-name", "") + if bp: + blueprint_names.add(bp) + + blueprint_tasks = [ + client._send_request("flow", None, { + "operation": "get-blueprint", + "blueprint-name": bp_name, + }) + for bp_name in blueprint_names + ] + + blueprint_results = await asyncio.gather(*blueprint_tasks) + + blueprints = {} + for bp_name, resp in zip(blueprint_names, blueprint_results): + bp_data = 
resp.get("blueprint-definition", "{}") + blueprints[bp_name] = json.loads(bp_data) if isinstance(bp_data, str) else bp_data + + # Round 4: get all parameter type definitions in parallel + param_types_needed = set() + for bp in blueprints.values(): + for param_meta in bp.get("parameters", {}).values(): + pt = param_meta.get("type", "") + if pt: + param_types_needed.add(pt) + + param_type_tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": pt}], + }) + for pt in param_types_needed + ] + + param_type_results = await asyncio.gather(*param_type_tasks) + + param_type_defs = {} + for pt, resp in zip(param_types_needed, param_type_results): + values = resp.get("values", []) + if values: + try: + param_type_defs[pt] = json.loads(values[0].get("value", "{}")) + except (json.JSONDecodeError, AttributeError): + pass + + return interface_defs, flow_ids, flows, blueprints, param_type_defs + +async def _show_flows_async(url, token=None): + + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_show_flows(client) + def show_flows(url, token=None): - api = Api(url, token=token) - config_api = api.config() - flow_api = api.flow() + result = asyncio.run(_show_flows_async(url, token=token)) - interface_names = config_api.list("interface-description") + interface_defs, flow_ids, flows, blueprints, param_type_defs = result - interface_defs = { - i: get_interface(config_api, i) - for i in interface_names - } - - flow_ids = flow_api.list() - - if len(flow_ids) == 0: + if not flow_ids: print("No flows.") return - flows = [] + for fid in flow_ids: - for id in flow_ids: - - flow = flow_api.get(id) + flow = flows[fid] table = [] - table.append(("id", id)) + table.append(("id", fid)) table.append(("blueprint", flow.get("blueprint-name", ""))) table.append(("desc", flow.get("description", ""))) - # Display parameters with human-readable descriptions parameters = flow.get("parameters", {}) if parameters: - # Try to get the flow blueprint definition for parameter metadata blueprint_name = flow.get("blueprint-name", "") - if blueprint_name: - try: - flow_blueprint = flow_api.get_blueprint(blueprint_name) - blueprint_params_metadata = flow_blueprint.get("parameters", {}) - param_str = format_parameters(parameters, blueprint_params_metadata, config_api) - except Exception as e: - # Fallback to JSON if we can't get the blueprint definition - param_str = json.dumps(parameters, indent=2) + if blueprint_name and blueprint_name in blueprints: + blueprint_params_metadata = blueprints[blueprint_name].get("parameters", {}) + param_str = format_parameters( + parameters, blueprint_params_metadata, param_type_defs + ) else: - # No blueprint name, fallback to JSON param_str = json.dumps(parameters, indent=2) table.append(("parameters", param_str)) @@ -220,4 +283,4 @@ def main(): print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/show_parameter_types.py b/trustgraph-cli/trustgraph/cli/show_parameter_types.py index 2e0f1be3..67d6e823 100644 --- a/trustgraph-cli/trustgraph/cli/show_parameter_types.py +++ b/trustgraph-cli/trustgraph/cli/show_parameter_types.py @@ -7,9 +7,10 @@ valid enums, and validation rules. 
""" import argparse +import asyncio import os import tabulate -from trustgraph.api import Api, ConfigKey +from trustgraph.api import AsyncSocketClient import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') @@ -17,13 +18,7 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def format_enum_values(enum_list): """ - Format enum values for display, handling both old and new formats - - Args: - enum_list: List of enum values (strings or objects with id/description) - - Returns: - Formatted string describing enum options + Format enum values for display, handling both old and new formats. """ if not enum_list: return "Any value" @@ -31,7 +26,6 @@ def format_enum_values(enum_list): enum_items = [] for item in enum_list: if isinstance(item, dict): - # New format: objects with id and description enum_id = item.get("id", "") description = item.get("description", "") if description: @@ -39,99 +33,146 @@ def format_enum_values(enum_list): else: enum_items.append(enum_id) else: - # Old format: simple strings enum_items.append(str(item)) return "\n".join(f"• {item}" for item in enum_items) def format_constraints(param_type_def): """ - Format validation constraints for display - - Args: - param_type_def: Parameter type definition - - Returns: - Formatted string describing constraints + Format validation constraints for display. """ constraints = [] - # Handle numeric constraints if "minimum" in param_type_def: constraints.append(f"min: {param_type_def['minimum']}") if "maximum" in param_type_def: constraints.append(f"max: {param_type_def['maximum']}") - - # Handle string constraints if "minLength" in param_type_def: constraints.append(f"min length: {param_type_def['minLength']}") if "maxLength" in param_type_def: constraints.append(f"max length: {param_type_def['maxLength']}") if "pattern" in param_type_def: constraints.append(f"pattern: {param_type_def['pattern']}") - - # Handle required field if param_type_def.get("required", False): constraints.append("required") return ", ".join(constraints) if constraints else "None" +def format_param_type(param_type_name, param_type_def): + """Format a single parameter type for display.""" + table = [] + table.append(("name", param_type_name)) + table.append(("description", param_type_def.get("description", ""))) + table.append(("type", param_type_def.get("type", "unknown"))) + + default = param_type_def.get("default") + if default is not None: + table.append(("default", str(default))) + + enum_list = param_type_def.get("enum") + if enum_list: + enum_str = format_enum_values(enum_list) + table.append(("valid values", enum_str)) + + constraints = format_constraints(param_type_def) + if constraints != "None": + table.append(("constraints", constraints)) + + return table + +async def fetch_all_param_types(client): + """Fetch all parameter types concurrently.""" + + # Round 1: list parameter types + resp = await client._send_request("config", None, { + "operation": "list", + "type": "parameter-type", + }) + param_type_names = resp.get("directory", []) + + if not param_type_names: + return [], {} + + # Round 2: get all parameter types in parallel + tasks = [ + client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": name}], + }) + for name in param_type_names + ] + results = await asyncio.gather(*tasks) + + param_type_defs = {} + for name, resp in zip(param_type_names, results): + values = resp.get("values", []) + if values: + try: + param_type_defs[name] = json.loads(values[0].get("value", 
"{}")) + except (json.JSONDecodeError, AttributeError): + pass + + return param_type_names, param_type_defs + +async def fetch_single_param_type(client, param_type_name): + """Fetch a single parameter type.""" + resp = await client._send_request("config", None, { + "operation": "get", + "keys": [{"type": "parameter-type", "key": param_type_name}], + }) + values = resp.get("values", []) + if values: + return json.loads(values[0].get("value", "{}")) + return None + def show_parameter_types(url, token=None): - """ - Show all parameter type definitions - """ - api = Api(url, token=token) - config_api = api.config() + """Show all parameter type definitions.""" - # Get list of all parameter types - try: - param_type_names = config_api.list("parameter-type") - except Exception as e: - print(f"Error retrieving parameter types: {e}") - return + async def _fetch(): + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_all_param_types(client) - if len(param_type_names) == 0: + param_type_names, param_type_defs = asyncio.run(_fetch()) + + if not param_type_names: print("No parameter types defined.") return - for param_type_name in param_type_names: - try: - # Get the parameter type definition - key = ConfigKey("parameter-type", param_type_name) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - - table = [] - table.append(("name", param_type_name)) - table.append(("description", param_type_def.get("description", ""))) - table.append(("type", param_type_def.get("type", "unknown"))) - - # Show default value if present - default = param_type_def.get("default") - if default is not None: - table.append(("default", str(default))) - - # Show enum values if present - enum_list = param_type_def.get("enum") - if enum_list: - enum_str = format_enum_values(enum_list) - table.append(("valid values", enum_str)) - - # Show constraints - constraints = format_constraints(param_type_def) - if constraints != "None": - table.append(("constraints", constraints)) - - print(tabulate.tabulate( - table, - tablefmt="pretty", - stralign="left", - )) + for name in param_type_names: + if name not in param_type_defs: + print(f"Error retrieving parameter type '{name}'") print() + continue - except Exception as e: - print(f"Error retrieving parameter type '{param_type_name}': {e}") - print() + table = format_param_type(name, param_type_defs[name]) + + print(tabulate.tabulate( + table, + tablefmt="pretty", + stralign="left", + )) + print() + +def show_specific_parameter_type(url, param_type_name, token=None): + """Show a specific parameter type definition.""" + + async def _fetch(): + async with AsyncSocketClient(url, timeout=60, token=token) as client: + return await fetch_single_param_type(client, param_type_name) + + param_type_def = asyncio.run(_fetch()) + + if param_type_def is None: + print(f"Error retrieving parameter type '{param_type_name}'") + return + + table = format_param_type(param_type_name, param_type_def) + + print(tabulate.tabulate( + table, + tablefmt="pretty", + stralign="left", + )) def main(): parser = argparse.ArgumentParser( @@ -161,57 +202,12 @@ def main(): try: if args.type: - # Show specific parameter type show_specific_parameter_type(args.api_url, args.type, args.token) else: - # Show all parameter types show_parameter_types(args.api_url, args.token) except Exception as e: print("Exception:", e, flush=True) -def show_specific_parameter_type(url, param_type_name, token=None): - """ - Show a specific parameter type 
definition - """ - api = Api(url, token=token) - config_api = api.config() - - try: - # Get the parameter type definition - key = ConfigKey("parameter-type", param_type_name) - type_def_value = config_api.get([key])[0].value - param_type_def = json.loads(type_def_value) - - table = [] - table.append(("name", param_type_name)) - table.append(("description", param_type_def.get("description", ""))) - table.append(("type", param_type_def.get("type", "unknown"))) - - # Show default value if present - default = param_type_def.get("default") - if default is not None: - table.append(("default", str(default))) - - # Show enum values if present - enum_list = param_type_def.get("enum") - if enum_list: - enum_str = format_enum_values(enum_list) - table.append(("valid values", enum_str)) - - # Show constraints - constraints = format_constraints(param_type_def) - if constraints != "None": - table.append(("constraints", constraints)) - - print(tabulate.tabulate( - table, - tablefmt="pretty", - stralign="left", - )) - - except Exception as e: - print(f"Error retrieving parameter type '{param_type_name}': {e}") - if __name__ == "__main__": - main() \ No newline at end of file + main() From ea33620fb2cd6ce92783c271a33056d45a1de29b Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 26 Mar 2026 16:58:30 +0000 Subject: [PATCH 07/37] Fix missing auth header in verify_system_status (#724) Fix missing auth header in verify_system_status processor check The check_processors function received the token parameter but did not include it in the Authorization header when calling the metrics endpoint, causing 401 errors when gateway auth is enabled. --- trustgraph-cli/trustgraph/cli/verify_system_status.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index 8cebc83f..5fea1bb0 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -178,7 +178,11 @@ def check_processors(url: str, min_processors: int, timeout: int, token: Optiona url += '/' metrics_url = f"{url}api/metrics/query?query=processor_info" - resp = requests.get(metrics_url, timeout=timeout) + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.get(metrics_url, timeout=timeout, headers=headers) if resp.status_code == 200: data = resp.json() processor_count = len(data.get("data", {}).get("result", [])) From a63452050976f1f125ee4d9c242b01a5c3e32547 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 28 Mar 2026 10:58:28 +0000 Subject: [PATCH 08/37] Fix websocket error responses in Mux dispatcher (#726) Error responses from the websocket multiplexer were missing the request ID and using a bare string format instead of the structured error protocol. This caused clients to hang when a request failed (e.g. unsupported service for a flow) because the error could not be routed to the waiting caller. Include request ID in all error paths, use structured error format ({message, type}) with complete flag, and extract the ID early in receive() so even malformed requests get a routable error when possible. 
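For example, a request that fails dispatch (such as an unsupported service
for the flow) now receives a reply over the websocket of roughly this shape;
the field values here are illustrative:

```json
{
  "id": "<original request id>",
  "error": {"message": "Unsupported service", "type": "error"},
  "complete": true
}
```

The "id" field is included whenever it can be extracted from the incoming
message, so the client can route the error to the waiting caller, and
"complete" tells the client the request is finished rather than leaving it
waiting for further responses.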
Updated tests - tests were coded against invalid protocol messages --- tests/unit/test_gateway/test_dispatch_mux.py | 32 +++++++++++------- .../trustgraph/gateway/dispatch/mux.py | 33 +++++++++++++++---- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/tests/unit/test_gateway/test_dispatch_mux.py b/tests/unit/test_gateway/test_dispatch_mux.py index b623a1b6..a0bc9460 100644 --- a/tests/unit/test_gateway/test_dispatch_mux.py +++ b/tests/unit/test_gateway/test_dispatch_mux.py @@ -121,7 +121,11 @@ class TestMux: # Based on the code, it seems to catch exceptions await mux.receive(mock_msg) - mock_ws.send_json.assert_called_once_with({"error": "Bad message"}) + mock_ws.send_json.assert_called_once_with({ + "error": {"message": "Bad message", "type": "error"}, + "complete": True, + "id": "test-id-123", + }) @pytest.mark.asyncio async def test_mux_receive_message_without_id(self): @@ -129,23 +133,26 @@ class TestMux: mock_dispatcher_manager = MagicMock() mock_ws = AsyncMock() mock_running = MagicMock() - + mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, running=mock_running ) - + # Mock message without id field mock_msg = MagicMock() mock_msg.json.return_value = { "request": {"type": "test"} } - + # receive method should handle the RuntimeError internally await mux.receive(mock_msg) - - mock_ws.send_json.assert_called_once_with({"error": "Bad message"}) + + mock_ws.send_json.assert_called_once_with({ + "error": {"message": "Bad message", "type": "error"}, + "complete": True, + }) @pytest.mark.asyncio async def test_mux_receive_invalid_json(self): @@ -153,19 +160,22 @@ class TestMux: mock_dispatcher_manager = MagicMock() mock_ws = AsyncMock() mock_running = MagicMock() - + mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, running=mock_running ) - + # Mock message with invalid JSON mock_msg = MagicMock() mock_msg.json.side_effect = ValueError("Invalid JSON") - + # receive method should handle the ValueError internally await mux.receive(mock_msg) - + mock_msg.json.assert_called_once() - mock_ws.send_json.assert_called_once_with({"error": "Invalid JSON"}) \ No newline at end of file + mock_ws.send_json.assert_called_once_with({ + "error": {"message": "Invalid JSON", "type": "error"}, + "complete": True, + }) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index ddaa8ddf..fabd5c44 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -33,9 +33,12 @@ class Mux: async def receive(self, msg): + request_id = None + try: data = msg.json() + request_id = data.get("id") if "request" not in data: raise RuntimeError("Bad message") @@ -51,7 +54,13 @@ class Mux: except Exception as e: logger.error(f"Receive exception: {str(e)}", exc_info=True) - await self.ws.send_json({"error": str(e)}) + error_resp = { + "error": {"message": str(e), "type": "error"}, + "complete": True, + } + if request_id: + error_resp["id"] = request_id + await self.ws.send_json(error_resp) async def maybe_tidy_workers(self, workers): @@ -97,12 +106,12 @@ class Mux: }) worker = asyncio.create_task( - self.request_task(request, responder, flow, svc) + self.request_task(id, request, responder, flow, svc) ) workers.append(worker) - async def request_task(self, request, responder, flow, svc): + async def request_task(self, id, request, responder, flow, svc): try: @@ -119,7 +128,11 @@ class Mux: ) except Exception as e: - await 
self.ws.send_json({"error": str(e)}) + await self.ws.send_json({ + "id": id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) async def run(self): @@ -143,7 +156,11 @@ class Mux: except Exception as e: # This is an internal working error, may not be recoverable logger.error(f"Run prepare exception: {e}", exc_info=True) - await self.ws.send_json({"id": id, "error": str(e)}) + await self.ws.send_json({ + "id": id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) self.running.stop() if self.ws: @@ -160,7 +177,11 @@ class Mux: except Exception as e: logger.error(f"Exception in mux: {e}", exc_info=True) - await self.ws.send_json({"error": str(e)}) + await self.ws.send_json({ + "id": id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) self.running.stop() From 20204d87c3d345360f59611638461aff3dfc9e4e Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sat, 28 Mar 2026 11:19:45 +0000 Subject: [PATCH 09/37] Fix OpenAI compatibility issues for newer models and Azure config (#727) Use max_completion_tokens for OpenAI and Azure OpenAI providers: The OpenAI API deprecated max_tokens in favor of max_completion_tokens for chat completions. Newer models (gpt-4o, o1, o3) reject the old parameter with a 400 error. AZURE_API_VERSION env var now overrides the default API version: (falls back to 2024-12-01-preview). Update tests to test for expected structures --- tests/integration/test_text_completion_integration.py | 8 ++++---- .../test_text_completion_streaming_integration.py | 2 +- .../test_text_completion/test_azure_openai_processor.py | 4 ++-- tests/unit/test_text_completion/test_openai_processor.py | 4 ++-- .../trustgraph/model/text_completion/azure_openai/llm.py | 6 +++--- .../trustgraph/model/text_completion/openai/llm.py | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/integration/test_text_completion_integration.py b/tests/integration/test_text_completion_integration.py index 08e2a995..26fa0ab3 100644 --- a/tests/integration/test_text_completion_integration.py +++ b/tests/integration/test_text_completion_integration.py @@ -93,7 +93,7 @@ class TestTextCompletionIntegration: assert call_args.kwargs['model'] == "gpt-3.5-turbo" assert call_args.kwargs['temperature'] == 0.7 - assert call_args.kwargs['max_tokens'] == 1024 + assert call_args.kwargs['max_completion_tokens'] == 1024 assert len(call_args.kwargs['messages']) == 1 assert call_args.kwargs['messages'][0]['role'] == "user" assert "You are a helpful assistant." 
in call_args.kwargs['messages'][0]['content'][0]['text'] @@ -134,7 +134,7 @@ class TestTextCompletionIntegration: call_args = mock_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == config['model'] assert call_args.kwargs['temperature'] == config['temperature'] - assert call_args.kwargs['max_tokens'] == config['max_output'] + assert call_args.kwargs['max_completion_tokens'] == config['max_output'] # Reset mock for next iteration mock_openai_client.reset_mock() @@ -286,7 +286,7 @@ class TestTextCompletionIntegration: # were removed in #561 as unnecessary parameters assert 'model' in call_args.kwargs assert 'temperature' in call_args.kwargs - assert 'max_tokens' in call_args.kwargs + assert 'max_completion_tokens' in call_args.kwargs # Verify result structure assert hasattr(result, 'text') @@ -362,7 +362,7 @@ class TestTextCompletionIntegration: call_args = mock_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == "gpt-4" assert call_args.kwargs['temperature'] == 0.8 - assert call_args.kwargs['max_tokens'] == 2048 + assert call_args.kwargs['max_completion_tokens'] == 2048 # Note: top_p, frequency_penalty, and presence_penalty # were removed in #561 as unnecessary parameters diff --git a/tests/integration/test_text_completion_streaming_integration.py b/tests/integration/test_text_completion_streaming_integration.py index a70afb4c..6968affa 100644 --- a/tests/integration/test_text_completion_streaming_integration.py +++ b/tests/integration/test_text_completion_streaming_integration.py @@ -201,7 +201,7 @@ class TestTextCompletionStreaming: call_args = mock_streaming_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == "gpt-4" assert call_args.kwargs['temperature'] == 0.5 - assert call_args.kwargs['max_tokens'] == 2048 + assert call_args.kwargs['max_completion_tokens'] == 2048 assert call_args.kwargs['stream'] is True # Verify chunks have correct model diff --git a/tests/unit/test_text_completion/test_azure_openai_processor.py b/tests/unit/test_text_completion/test_azure_openai_processor.py index 02c85d54..b00531ad 100644 --- a/tests/unit/test_text_completion/test_azure_openai_processor.py +++ b/tests/unit/test_text_completion/test_azure_openai_processor.py @@ -108,7 +108,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase): }] }], temperature=0.0, - max_tokens=4192, + max_completion_tokens=4192, top_p=1 ) @@ -399,7 +399,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase): # Verify other parameters assert call_args[1]['model'] == 'gpt-4' assert call_args[1]['temperature'] == 0.5 - assert call_args[1]['max_tokens'] == 1024 + assert call_args[1]['max_completion_tokens'] == 1024 assert call_args[1]['top_p'] == 1 @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') diff --git a/tests/unit/test_text_completion/test_openai_processor.py b/tests/unit/test_text_completion/test_openai_processor.py index 352af062..514d35da 100644 --- a/tests/unit/test_text_completion/test_openai_processor.py +++ b/tests/unit/test_text_completion/test_openai_processor.py @@ -102,7 +102,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): }] }], temperature=0.0, - max_tokens=4096 + max_completion_tokens=4096 ) @patch('trustgraph.model.text_completion.openai.llm.OpenAI') @@ -380,7 +380,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): # Verify other parameters assert call_args[1]['model'] == 'gpt-3.5-turbo' assert call_args[1]['temperature'] == 0.5 - assert 
call_args[1]['max_tokens'] == 1024 + assert call_args[1]['max_completion_tokens'] == 1024 @patch('trustgraph.model.text_completion.openai.llm.OpenAI') diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py index 4ab0b302..def44bd4 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py @@ -20,7 +20,7 @@ default_ident = "text-completion" default_temperature = 0.0 default_max_output = 4192 -default_api = "2024-12-01-preview" +default_api = os.getenv("AZURE_API_VERSION", "2024-12-01-preview") default_endpoint = os.getenv("AZURE_ENDPOINT", None) default_token = os.getenv("AZURE_TOKEN", None) default_model = os.getenv("AZURE_MODEL", None) @@ -90,7 +90,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, top_p=1, ) @@ -159,7 +159,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, top_p=1, stream=True, stream_options={"include_usage": True} diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index d65e27bf..cdc8602a 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -86,7 +86,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, ) inputtokens = resp.usage.prompt_tokens @@ -152,7 +152,7 @@ class Processor(LlmService): } ], temperature=effective_temperature, - max_tokens=self.max_output, + max_completion_tokens=self.max_output, stream=True, stream_options={"include_usage": True} ) From 413f9176769c186309cd70164884711866d3a7b5 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sun, 29 Mar 2026 20:22:06 +0100 Subject: [PATCH 10/37] Add missing pdf extra to unstructured dependency (#728) * Fix PDF processing deps so that PDF processing works --- containers/Containerfile.unstructured | 7 ++++++- trustgraph-unstructured/pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/containers/Containerfile.unstructured b/containers/Containerfile.unstructured index 22ee05b2..7284901e 100644 --- a/containers/Containerfile.unstructured +++ b/containers/Containerfile.unstructured @@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.13 && \ +RUN dnf install -y python3.13 libxcb mesa-libGL && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ @@ -38,6 +38,11 @@ RUN ls /root/wheels FROM base +# Pre-install CPU-only PyTorch so that unstructured[pdf]'s torch +# dependency is satisfied without pulling in CUDA (~190MB vs ~2GB+) +RUN pip3 install --no-cache-dir torch==2.11.0+cpu \ + --index-url https://download.pytorch.org/whl/cpu + COPY --from=build /root/wheels /root/wheels RUN \ diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml index 35597398..33265edb 100644 --- a/trustgraph-unstructured/pyproject.toml +++ b/trustgraph-unstructured/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pulsar-client", "prometheus-client", "python-magic", 
- "unstructured[csv,docx,epub,md,odt,pptx,rst,rtf,tsv,xlsx]", + "unstructured[csv,docx,epub,md,odt,pdf,pptx,rst,rtf,tsv,xlsx]", ] classifiers = [ "Programming Language :: Python :: 3", From 687a9e08fe4d7ef404210f11087deef995589599 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Sun, 29 Mar 2026 20:26:26 +0100 Subject: [PATCH 11/37] master -> release/v2.2 (#732) --- README.md | 2 +- SECURITY.md | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 SECURITY.md diff --git a/README.md b/README.md index 52df2720..94cb4ddc 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ -[![PyPI version](https://img.shields.io/pypi/v/trustgraph.svg)](https://pypi.org/project/trustgraph/) ![E2E Tests](https://github.com/trustgraph-ai/trustgraph/actions/workflows/release.yaml/badge.svg) +[![PyPI version](https://img.shields.io/pypi/v/trustgraph.svg)](https://pypi.org/project/trustgraph/) [![License](https://img.shields.io/github/license/trustgraph-ai/trustgraph?color=blue)](LICENSE) ![E2E Tests](https://github.com/trustgraph-ai/trustgraph/actions/workflows/release.yaml/badge.svg) [![Discord](https://img.shields.io/discord/1251652173201149994 )](https://discord.gg/sQMwkRz5GX) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/trustgraph-ai/trustgraph) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..a7093091 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,98 @@ +# Security Policy + +TrustGraph is an open-source AI graph processing pipeline. We take security +seriously and appreciate the responsible disclosure of vulnerabilities by the +community. + +## Supported Versions + +Security updates are only provided for the latest stable release line. +Versions prior to 2.0.0 are end-of-life and will not receive security patches. + +| Version | Supported | +| --------- | ------------------ | +| >= 2.0.0 | :white_check_mark: | +| < 2.0.0 | :x: | + +If you are running a version older than 2.0.0, we strongly recommend upgrading +to the latest 2.x release before reporting a vulnerability, as the issue may +already be resolved. + +## Reporting a Vulnerability + +**Please do not report security vulnerabilities through public GitHub issues, +pull requests, or discussions.** Doing so could expose users of TrustGraph to +risk before a fix is available. + +Instead, report security vulnerabilities by emailing the TrustGraph security +team directly: + +📧 **[info@trustgraph.ai](mailto:info@trustgraph.ai)** + +Please include as much of the following information as possible to help us +triage and resolve the issue quickly: + +- A clear description of the vulnerability and its potential impact +- The affected component(s) and version(s) of TrustGraph +- Step-by-step instructions to reproduce the issue +- Proof-of-concept code or a minimal reproducing example (if applicable) +- Any suggested mitigations or patches (if you have them) + +## What to Expect + +After submitting a report, you can expect the following process: + +1. **Acknowledgement** — We will acknowledge receipt of your report within + **72 hours**. +2. **Assessment** — We will investigate and assess the severity of the + vulnerability within **7 days** of acknowledgement. +3. **Updates** — We will keep you informed of our progress. If we need + additional information, we will reach out to you directly. +4. **Resolution** — Once a fix is developed and validated, we will coordinate + with you on the disclosure timeline before publishing a patch release. 
+5. **Credit** — With your permission, we will publicly credit your responsible + disclosure in the release notes accompanying the fix. + +## Severity Assessment + +We evaluate vulnerabilities using the +[CVSS v3.1](https://www.first.org/cvss/v3-1/) scoring system as a guide: + +| Severity | CVSS Score | Target Response Time | +| -------- | ---------- | -------------------- | +| Critical | 9.0 – 10.0 | 48 hours | +| High | 7.0 – 8.9 | 7 days | +| Medium | 4.0 – 6.9 | 30 days | +| Low | 0.1 – 3.9 | 90 days | + +## Scope + +The following are in scope for security reports: + +- Core TrustGraph Python packages (`trustgraph`, `trustgraph-base`, etc.) +- The TrustGraph REST/graph gateway and processing pipeline components +- Docker and Kubernetes deployment configurations shipped in this repository +- Authentication, authorization, or data isolation issues + +The following are **out of scope**: + +- Third-party services or infrastructure not maintained by TrustGraph +- Issues in upstream dependencies (please report those to the respective + project maintainers) +- Denial-of-service attacks requiring significant resources +- Social engineering attacks + +## Preferred Languages + +We prefer all security communications in **English**. + +## Policy + +TrustGraph follows the principle of +[Coordinated Vulnerability Disclosure (CVD)](https://vuls.cert.org/confluence/display/CVD). +We ask that you give us a reasonable amount of time to investigate and address +a reported vulnerability before any public disclosure. + +--- + +_Last reviewed: March 2026_ From 5a9db2da508e5d71597a99abe77ed6fc9911064f Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 30 Mar 2026 16:08:46 +0100 Subject: [PATCH 12/37] Add tg-monitor-prompts CLI tool for prompt queue monitoring (#737) Subscribes to prompt request/response Pulsar queues, correlates messages by ID, and logs a summary with template name, truncated terms, and elapsed time. Streaming responses are accumulated and shown at completion. Supports prompt and prompt-rag queue types. --- trustgraph-cli/pyproject.toml | 1 + .../trustgraph/cli/monitor_prompts.py | 344 ++++++++++++++++++ 2 files changed, 345 insertions(+) create mode 100644 trustgraph-cli/trustgraph/cli/monitor_prompts.py diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 42b9fb9b..9fd6bed7 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -35,6 +35,7 @@ tg-delete-kg-core = "trustgraph.cli.delete_kg_core:main" tg-delete-tool = "trustgraph.cli.delete_tool:main" tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main" tg-dump-queues = "trustgraph.cli.dump_queues:main" +tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main" tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" diff --git a/trustgraph-cli/trustgraph/cli/monitor_prompts.py b/trustgraph-cli/trustgraph/cli/monitor_prompts.py new file mode 100644 index 00000000..c412b643 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/monitor_prompts.py @@ -0,0 +1,344 @@ +""" +Monitor prompt request/response queues and log activity with timing. + +Subscribes to prompt request and response Pulsar queues, correlates +them by message ID, and logs a summary of each request/response with +elapsed time. Streaming responses are accumulated and shown once at +completion. 
+ +Examples: + tg-monitor-prompts + tg-monitor-prompts --flow default --max-lines 5 + tg-monitor-prompts --queue-type prompt-rag +""" + +import json +import asyncio +import sys +import argparse +from datetime import datetime +from collections import OrderedDict + +import pulsar +from pulsar.schema import BytesSchema + + +default_flow = "default" +default_queue_type = "prompt" +default_max_lines = 3 +default_max_width = 80 + + +def truncate_text(text, max_lines, max_width): + """Truncate text to max_lines lines, each at most max_width chars.""" + if not text: + return "(empty)" + + lines = text.splitlines() + result = [] + for line in lines[:max_lines]: + if len(line) > max_width: + result.append(line[:max_width - 3] + "...") + else: + result.append(line) + + if len(lines) > max_lines: + result.append(f" ... ({len(lines) - max_lines} more lines)") + + return "\n".join(result) + + +def summarise_value(value, max_width): + """Summarise a term value — show type and size for large values.""" + # Try to parse JSON + try: + parsed = json.loads(value) + except (json.JSONDecodeError, TypeError): + parsed = value + + if isinstance(parsed, list): + return f"[{len(parsed)} items]" + elif isinstance(parsed, dict): + return f"{{{len(parsed)} keys}}" + elif isinstance(parsed, str): + if len(parsed) > max_width: + return parsed[:max_width - 3] + "..." + return parsed + else: + s = str(parsed) + if len(s) > max_width: + return s[:max_width - 3] + "..." + return s + + +def format_terms(terms, max_lines, max_width): + """Format prompt terms for display — concise summary.""" + if not terms: + return "" + + parts = [] + for key, value in terms.items(): + summary = summarise_value(value, max_width - len(key) - 4) + parts.append(f" {key}: {summary}") + + return "\n".join(parts) + + +def parse_raw_message(msg): + """Parse a raw Pulsar message into (correlation_id, body_dict).""" + try: + props = msg.properties() + corr_id = props.get("id", "") + except Exception: + corr_id = "" + + try: + value = msg.value() + if isinstance(value, bytes): + value = value.decode("utf-8") + body = json.loads(value) if isinstance(value, str) else {} + except Exception: + body = {} + + return corr_id, body + + +def receive_with_timeout(consumer, timeout_ms=500): + """Receive a message with timeout, returning None on timeout.""" + try: + return consumer.receive(timeout_millis=timeout_ms) + except Exception: + return None + + +async def monitor(flow, queue_type, max_lines, max_width, + pulsar_host, listener_name): + + request_queue = f"non-persistent://tg/request/{queue_type}:{flow}" + response_queue = f"non-persistent://tg/response/{queue_type}:{flow}" + + print(f"Monitoring prompt queues:") + print(f" Request: {request_queue}") + print(f" Response: {response_queue}") + print(f"Press Ctrl+C to stop\n") + + client = pulsar.Client( + pulsar_host, + listener_name=listener_name, + ) + + req_consumer = client.subscribe( + request_queue, + subscription_name="prompt-monitor-req", + consumer_type=pulsar.ConsumerType.Shared, + schema=BytesSchema(), + initial_position=pulsar.InitialPosition.Latest, + ) + + resp_consumer = client.subscribe( + response_queue, + subscription_name="prompt-monitor-resp", + consumer_type=pulsar.ConsumerType.Shared, + schema=BytesSchema(), + initial_position=pulsar.InitialPosition.Latest, + ) + + # Track in-flight requests: corr_id -> (timestamp, template_id) + in_flight = OrderedDict() + + # Accumulate streaming responses: corr_id -> list of text chunks + streaming_chunks = {} + + print("Listening...\n") + + try: + 
while True: + got_message = False + + # Poll request queue + msg = receive_with_timeout(req_consumer, 100) + if msg: + got_message = True + timestamp = datetime.now() + corr_id, body = parse_raw_message(msg) + time_str = timestamp.strftime("%H:%M:%S.%f")[:-3] + + template_id = body.get("id", "(unknown)") + terms = body.get("terms", {}) + streaming = body.get("streaming", False) + + in_flight[corr_id] = (timestamp, template_id) + + # Limit size + while len(in_flight) > 1000: + in_flight.popitem(last=False) + + stream_flag = " [streaming]" if streaming else "" + id_display = corr_id[:8] if corr_id else "--------" + print(f"[{time_str}] REQ {id_display} " + f"template={template_id}{stream_flag}") + + if terms: + print(format_terms(terms, max_lines, max_width)) + + req_consumer.acknowledge(msg) + + # Poll response queue + msg = receive_with_timeout(resp_consumer, 100) + if msg: + got_message = True + timestamp = datetime.now() + corr_id, body = parse_raw_message(msg) + time_str = timestamp.strftime("%H:%M:%S.%f")[:-3] + id_display = corr_id[:8] if corr_id else "--------" + + error = body.get("error") + text = body.get("text", "") + obj = body.get("object", "") + eos = body.get("end_of_stream", False) + + if error: + # Error — show immediately + elapsed_str = "" + if corr_id in in_flight: + req_timestamp, _ = in_flight.pop(corr_id) + elapsed = (timestamp - req_timestamp).total_seconds() + elapsed_str = f" ({elapsed:.3f}s)" + streaming_chunks.pop(corr_id, None) + + err_msg = error + if isinstance(error, dict): + err_msg = error.get("message", str(error)) + print(f"[{time_str}] ERR {id_display} " + f"{err_msg}{elapsed_str}") + + elif eos: + # End of stream — show accumulated text + timing + elapsed_str = "" + if corr_id in in_flight: + req_timestamp, _ = in_flight.pop(corr_id) + elapsed = (timestamp - req_timestamp).total_seconds() + elapsed_str = f" ({elapsed:.3f}s)" + + accumulated = streaming_chunks.pop(corr_id, []) + if text: + accumulated.append(text) + + full_text = "".join(accumulated) + if full_text: + truncated = truncate_text( + full_text, max_lines, max_width + ) + print(f"[{time_str}] RESP {id_display}" + f"{elapsed_str}") + print(f" {truncated}") + else: + print(f"[{time_str}] RESP {id_display}" + f"{elapsed_str} (empty)") + + elif text or obj: + # Streaming chunk or non-streaming response + if corr_id in streaming_chunks or ( + corr_id in in_flight + ): + # Accumulate streaming chunk + if corr_id not in streaming_chunks: + streaming_chunks[corr_id] = [] + streaming_chunks[corr_id].append(text or obj) + else: + # Non-streaming single response + elapsed_str = "" + if corr_id in in_flight: + req_timestamp, _ = in_flight.pop(corr_id) + elapsed = ( + timestamp - req_timestamp + ).total_seconds() + elapsed_str = f" ({elapsed:.3f}s)" + + content = text or obj + label = "" if text else " (object)" + truncated = truncate_text( + content, max_lines, max_width + ) + print(f"[{time_str}] RESP {id_display}" + f"{label}{elapsed_str}") + print(f" {truncated}") + + resp_consumer.acknowledge(msg) + + if not got_message: + await asyncio.sleep(0.05) + + except KeyboardInterrupt: + print("\nStopping...") + finally: + req_consumer.close() + resp_consumer.close() + client.close() + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-monitor-prompts", + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "-f", "--flow", + default=default_flow, + help=f"Flow ID (default: {default_flow})", + ) + + parser.add_argument( + "-q", "--queue-type", 
+ default=default_queue_type, + help=f"Queue type: prompt or prompt-rag (default: {default_queue_type})", + ) + + parser.add_argument( + "-l", "--max-lines", + type=int, + default=default_max_lines, + help=f"Max lines of text per term/response (default: {default_max_lines})", + ) + + parser.add_argument( + "-w", "--max-width", + type=int, + default=default_max_width, + help=f"Max width per line (default: {default_max_width})", + ) + + parser.add_argument( + "--pulsar-host", + default="pulsar://localhost:6650", + help="Pulsar host URL (default: pulsar://localhost:6650)", + ) + + parser.add_argument( + "--listener-name", + default="localhost", + help="Pulsar listener name (default: localhost)", + ) + + args = parser.parse_args() + + try: + asyncio.run(monitor( + flow=args.flow, + queue_type=args.queue_type, + max_lines=args.max_lines, + max_width=args.max_width, + pulsar_host=args.pulsar_host, + listener_name=args.listener_name, + )) + except KeyboardInterrupt: + pass + except Exception as e: + print(f"Fatal error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() From 7af1d60db8ae20a6779e9f470352a1da5ff2bb76 Mon Sep 17 00:00:00 2001 From: CommitHu502Craft <139929317+CommitHu502Craft@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:58:58 +0800 Subject: [PATCH 13/37] fix(gateway): accept raw utf-8 text in text-load (#729) Co-authored-by: nanqinhu <139929317+nanqinhu@users.noreply.github.com> --- .../schemas/loading/TextLoadRequest.yaml | 3 +- specs/api/paths/flow/text-load.yaml | 15 ++++-- .../test_text_document_translator.py | 54 +++++++++++++++++++ .../messaging/translators/document_loading.py | 28 ++++++++-- 4 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_gateway/test_text_document_translator.py diff --git a/specs/api/components/schemas/loading/TextLoadRequest.yaml b/specs/api/components/schemas/loading/TextLoadRequest.yaml index 4ded87d5..447308d4 100644 --- a/specs/api/components/schemas/loading/TextLoadRequest.yaml +++ b/specs/api/components/schemas/loading/TextLoadRequest.yaml @@ -8,8 +8,7 @@ required: properties: text: type: string - description: Text content (base64 encoded) - format: byte + description: Text content, either raw text or base64 encoded for compatibility with older clients example: VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg== id: type: string diff --git a/specs/api/paths/flow/text-load.yaml b/specs/api/paths/flow/text-load.yaml index 5f918a3a..08bfe47b 100644 --- a/specs/api/paths/flow/text-load.yaml +++ b/specs/api/paths/flow/text-load.yaml @@ -8,7 +8,7 @@ post: ## Text Load Overview Fire-and-forget document loading: - - **Input**: Text content (base64 encoded) + - **Input**: Text content (raw UTF-8 or base64 encoded) - **Process**: Chunk, embed, store - **Output**: None (202 Accepted) @@ -26,7 +26,14 @@ post: ## Text Format - Text must be base64 encoded: + Text may be sent as raw UTF-8 text: + ``` + { + "text": "Cancer survival: 2.74× higher hazard ratio" + } + ``` + + Older clients may still send base64 encoded text: ``` text_content = "This is the document..." encoded = base64.b64encode(text_content.encode('utf-8')) @@ -78,12 +85,12 @@ post: simpleLoad: summary: Load text document value: - text: VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg== + text: This is the document text... 
id: doc-123 user: alice collection: research withMetadata: - summary: Load with RDF metadata + summary: Load with RDF metadata using base64 text value: text: UXVhbnR1bSBjb21wdXRpbmcgdXNlcyBxdWFudHVtIG1lY2hhbmljcyBwcmluY2lwbGVzLi4u id: doc-456 diff --git a/tests/unit/test_gateway/test_text_document_translator.py b/tests/unit/test_gateway/test_text_document_translator.py new file mode 100644 index 00000000..f836eb2b --- /dev/null +++ b/tests/unit/test_gateway/test_text_document_translator.py @@ -0,0 +1,54 @@ +""" +Unit tests for text document gateway translation compatibility. +""" + +import base64 + +from trustgraph.messaging.translators.document_loading import TextDocumentTranslator + + +class TestTextDocumentTranslator: + def test_to_pulsar_decodes_base64_text(self): + translator = TextDocumentTranslator() + payload = "Cancer survival: 2.74× higher hazard ratio" + + msg = translator.to_pulsar( + { + "id": "doc-1", + "user": "alice", + "collection": "research", + "charset": "utf-8", + "text": base64.b64encode(payload.encode("utf-8")).decode("ascii"), + } + ) + + assert msg.metadata.id == "doc-1" + assert msg.metadata.user == "alice" + assert msg.metadata.collection == "research" + assert msg.text == payload.encode("utf-8") + + def test_to_pulsar_accepts_raw_utf8_text(self): + translator = TextDocumentTranslator() + payload = "Cancer survival: 2.74× higher hazard ratio" + + msg = translator.to_pulsar( + { + "charset": "utf-8", + "text": payload, + } + ) + + assert msg.text == payload.encode("utf-8") + + def test_to_pulsar_falls_back_to_raw_non_base64_ascii(self): + translator = TextDocumentTranslator() + payload = "plain-text payload" + + msg = translator.to_pulsar( + { + "charset": "utf-8", + "text": payload, + } + ) + + assert msg.text == payload.encode("utf-8") diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index 7c2a013f..51cda697 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -4,6 +4,29 @@ from ...schema import Document, TextDocument, Chunk, DocumentEmbeddings, ChunkEm from .base import SendTranslator +def _decode_text_payload(payload: str | bytes, charset: str) -> str: + """ + Decode text-load payloads. + + Historical clients send base64-encoded text, but direct REST callers may + send raw UTF-8 text. Support both so Unicode text-load requests do not fail + at the gateway translation layer. 
+ """ + if isinstance(payload, bytes): + if not payload.isascii(): + return payload.decode(charset) + candidate = payload.decode("ascii") + else: + if not payload.isascii(): + return payload + candidate = payload + + try: + return base64.b64decode(candidate, validate=True).decode(charset) + except (ValueError, UnicodeDecodeError): + return candidate + + class DocumentTranslator(SendTranslator): """Translator for Document schema objects (PDF docs etc.)""" @@ -49,8 +72,7 @@ class TextDocumentTranslator(SendTranslator): def to_pulsar(self, data: Dict[str, Any]) -> TextDocument: charset = data.get("charset", "utf-8") - # Text is base64 encoded in input - text = base64.b64decode(data["text"]).decode(charset) + text = _decode_text_payload(data["text"], charset) from ...schema import Metadata return TextDocument( @@ -169,4 +191,4 @@ class DocumentEmbeddingsTranslator(SendTranslator): result["metadata"] = metadata_dict - return result \ No newline at end of file + return result From 849987f0e686b4f24652e50ebdd7140573df1e72 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 00:32:49 +0100 Subject: [PATCH 14/37] Add multi-pattern orchestrator with plan-then-execute and supervisor (#739) Introduce an agent orchestrator service that supports three execution patterns (ReAct, plan-then-execute, supervisor) with LLM-based meta-routing to select the appropriate pattern and task type per request. Update the agent schema to support orchestration fields (correlation, sub-agents, plan steps) and remove legacy response fields (answer, thought, observation). --- docs/tech-specs/agent-orchestration.md | 939 ++++++++++++++++++ tests/contract/conftest.py | 7 +- tests/contract/test_message_contracts.py | 7 +- .../test_translator_completion_flags.py | 62 +- .../test_agent_service_non_streaming.py | 13 +- .../trustgraph/api/socket_client.py | 17 - .../trustgraph/base/agent_client.py | 3 +- .../trustgraph/base/agent_service.py | 3 - .../trustgraph/messaging/translators/agent.py | 49 +- .../trustgraph/schema/services/agent.py | 29 +- trustgraph-flow/pyproject.toml | 1 + .../trustgraph/agent/orchestrator/__init__.py | 2 + .../trustgraph/agent/orchestrator/__main__.py | 6 + .../agent/orchestrator/aggregator.py | 157 +++ .../agent/orchestrator/meta_router.py | 168 ++++ .../agent/orchestrator/pattern_base.py | 428 ++++++++ .../agent/orchestrator/plan_pattern.py | 349 +++++++ .../agent/orchestrator/react_pattern.py | 134 +++ .../trustgraph/agent/orchestrator/service.py | 511 ++++++++++ .../agent/orchestrator/supervisor_pattern.py | 214 ++++ .../trustgraph/agent/react/service.py | 79 +- 21 files changed, 3006 insertions(+), 172 deletions(-) create mode 100644 docs/tech-specs/agent-orchestration.md create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/__init__.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/__main__.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/service.py create mode 100644 trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py diff --git a/docs/tech-specs/agent-orchestration.md b/docs/tech-specs/agent-orchestration.md new file mode 
100644 index 00000000..c93388ed --- /dev/null +++ b/docs/tech-specs/agent-orchestration.md @@ -0,0 +1,939 @@ +# TrustGraph Agent Orchestration — Technical Specification + +## Overview + +This specification describes the extension of TrustGraph's agent architecture +from a single ReACT execution pattern to a multi-pattern orchestration +model. The existing Pulsar-based self-queuing loop is pattern-agnostic — the +same infrastructure supports ReACT, Plan-then-Execute, Supervisor/Subagent +fan-out, and other execution strategies without changes to the message +transport. The extension adds a routing layer that selects the appropriate +pattern for each task, a set of pattern implementations that share common +iteration infrastructure, and a fan-out/fan-in mechanism for multi-agent +coordination. + +The central design principle is that +**trust and explainability are structural properties of the architecture**, +achieved by constraining LLM decisions to +graph-defined option sets and recording those constraints in the execution +trace. + +--- + +## Background + +### Existing Architecture + +The current agent manager is built on the ReACT pattern (Reasoning + Acting) +with these properties: + +- **Self-queuing loop**: Each iteration emits a new Pulsar message carrying + the accumulated history. The agent manager picks this up and runs the next + iteration. +- **Stateless agent manager**: No in-process state. All state lives in the + message payload. +- **Natural parallelism**: Multiple independent agent requests are handled + concurrently across Pulsar consumers. +- **Durability**: Crash recovery is inherent — the message survives process + failure. +- **Real-time feedback**: Streaming thought, action, observation and answer + chunks are emitted as iterations complete. +- **Tool calling and MCP invocation**: Tool calls into knowledge graphs, + external services, and MCP-connected systems. +- **Decision traces written to the knowledge graph**: Every iteration records + PROV-O triples — session, analysis, and conclusion entities — forming the + basis of explainability. + +### Current Message Flow + +``` +AgentRequest arrives (question, history=[], state, group, session_id) + │ + ▼ + Filter tools by group/state + │ + ▼ + AgentManager.react() → LLM call → parse → Action or Final + │ │ + │ [Action] │ [Final] + ▼ ▼ + Execute tool, capture observation Emit conclusion triples + Emit iteration triples Send AgentResponse + Append to history (end_of_dialog=True) + Emit new AgentRequest → "next" topic + │ + └── (picked up again by consumer, loop continues) +``` + +The key insight is that this loop structure is not ReACT-specific. The +plumbing — receive message, do work, emit next message — is the same +regardless of what the "work" step does. The payload and the pattern logic +define the behaviour; the infrastructure remains constant. + +### Current Limitations + +- Only one execution pattern (ReACT) is available regardless of task + characteristics. +- No mechanism for one agent to spawn and coordinate subagents. +- Pattern selection is implicit — every task gets the same treatment. +- The provenance model assumes a linear iteration chain (analysis N derives + from analysis N-1), with no support for parallel branches. + +--- + +## Design Goals + +- **Pattern-agnostic iteration infrastructure**: The self-queuing loop, tool + filtering, provenance emission, and streaming feedback should be shared + across all patterns. 
+- **Graph-constrained pattern selection**: The LLM selects patterns from a + graph-defined set, not from unconstrained reasoning. This makes the + selection auditable and explainable. +- **Genuinely parallel fan-out**: Subagent tasks execute concurrently on the + Pulsar queue, not sequentially in a single process. +- **Stateless coordination**: Fan-in uses the knowledge graph as coordination + substrate. The agent manager remains stateless. +- **Additive change**: The existing ReACT flow continues to work + unchanged. New patterns are added alongside it, not in place of it. + +--- + +## Patterns + +### ReACT as One Pattern Among Many + +ReACT is one point in a wider space of agent execution strategies: + +| Pattern | Structure | Strengths | +|---|---|---| +| **ReACT** | Interleaved reasoning and action | Adaptive, good for open-ended tasks | +| **Plan-then-Execute** | Decompose into a step DAG, then execute | More predictable, auditable plan | +| **Reflexion** | ReACT + self-critique after each action | Agents improve within the episode | +| **Supervisor/Subagent** | One agent orchestrates others | Parallel decomposition, synthesis | +| **Debate/Ensemble** | Multiple agents reason independently | Diverse perspectives, reconciliation | +| **LLM-as-router** | No reasoning loop, pure dispatch | Fast classification and routing | + +Not all of these need to be implemented at once. The architecture should +support them; the initial implementation delivers ReACT (already exists), +Plan-then-Execute, and Supervisor/Subagent. + +### Pattern Storage + +Patterns are stored as configuration items via the config API. They are +finite in number, mechanically well-defined, have enumerable properties, +and change slowly. Each pattern is a JSON object stored under the +`agent-pattern` config type. + +```json +Config type: "agent-pattern" +Config key: "react" +Value: { + "name": "react", + "description": "ReACT — Reasoning + Acting", + "when_to_use": "Adaptive, good for open-ended tasks" +} +``` + +These are written at deployment time and change rarely. If the architecture +later benefits from graph-based pattern storage (e.g. for richer ontological +relationships), the config items can be migrated to graph nodes — the +meta-router's selection logic is the same regardless of backend. + +--- + +## Task Types + +### What a Task Type Represents + +A **task type** characterises the problem domain — what the agent is being +asked to accomplish, and how a domain expert would frame it analytically. + +- Carries domain-specific methodology (e.g. "intelligence analysis always + applies structured analytic techniques") +- Pre-populates initial reasoning context via a framing prompt +- Constrains which patterns are valid for this class of problem +- Can define domain-specific termination criteria + +### Identification + +Task types are identified from plain-text task descriptions by the +LLM. Building a formal ontology over task descriptions is premature — natural +language is too varied and context-dependent. The LLM reads the description; +the graph provides the structure downstream. + +### Task Type Storage + +Task types are stored as configuration items via the config API under the +`agent-task-type` config type. Each task type is a JSON object that +references valid patterns by name. 
+ +```json +Config type: "agent-task-type" +Config key: "risk-assessment" +Value: { + "name": "risk-assessment", + "description": "Due Diligence / Risk Assessment", + "framing_prompt": "Analyse across financial, reputational, legal and operational dimensions using structured analytic techniques.", + "valid_patterns": ["supervisor", "plan-then-execute", "react"], + "when_to_use": "Multi-dimensional analysis requiring structured assessment" +} +``` + +The `valid_patterns` list defines the constrained decision space — the LLM +can only select patterns that the task type's configuration says are valid. +This is the many-to-many relationship between task types and patterns, +expressed as configuration rather than graph edges. + +### Selection Flow + +``` +Task Description (plain text, from AgentRequest.question) + │ + │ [LLM interprets, constrained by available task types from config] + ▼ +Task Type (config item — domain framing and methodology) + │ + │ [config lookup — valid_patterns list] + ▼ +Pattern Candidates (config items) + │ + │ [LLM selects within constrained set, + │ informed by task description signals: + │ complexity, urgency, scope] + ▼ +Selected Pattern +``` + +The task description may carry modulating signals (complexity, urgency, scope) +that influence which pattern is selected within the constrained set. But the +raw description never directly selects a pattern — it always passes through +the task type layer first. + +--- + +## Explainability Through Constrained Decision Spaces + +A central principle of TrustGraph's explainability architecture is that +**explainability comes from constrained decision spaces**. + +When a decision is made from an unconstrained space — a raw LLM call with no +guardrails — the reasoning is opaque even if the LLM produces a rationale, +because that rationale is post-hoc and unverifiable. + +When a decision is made from a **constrained set defined in configuration**, +you can always answer: +- What valid options were available +- What criteria narrowed the set +- What signal made the final selection within that set + +This principle already governs the existing decision trace architecture and +extends naturally to pattern selection. The routing decision — which task type +and which pattern — is itself recorded as a provenance node, making the first +decision in the execution trace auditable. + +**Trust becomes a structural property of the architecture, not a claimed +property of the model.** + +--- + +## Orchestration Architecture + +### The Meta-Router + +The meta-router is the entry point for all agent requests. It runs as a +pre-processing step before the pattern-specific iteration loop begins. Its +job is to determine the task type and select the execution pattern. + +**When it runs**: On receipt of an `AgentRequest` with empty history (i.e. a +new task, not a continuation). Requests with non-empty history are already +mid-iteration and bypass the meta-router. + +**What it does**: + +1. Lists all available task types from the config API + (`config.list("agent-task-type")`). +2. Presents these to the LLM alongside the task description. The LLM + identifies which task type applies (or "general" as a fallback). +3. Reads the selected task type's configuration to get the `valid_patterns` + list. +4. Loads the candidate pattern definitions from config and presents them to + the LLM. The LLM selects one, influenced by signals in the task + description (complexity, number of independent dimensions, urgency). +5. 
Records the routing decision as a provenance node (see Provenance Model + below). +6. Populates the `AgentRequest` with the selected pattern, task type framing + prompt, and any pattern-specific configuration, then emits it onto the + queue. + +**Where it lives**: The meta-router is a phase within the agent-orchestrator, +not a separate service. The agent-orchestrator is a new executable that +uses the same service identity as the existing agent-manager-react, making +it a drop-in replacement on the same Pulsar queues. It includes the full +ReACT implementation alongside the new orchestration patterns. The +distinction between "route" and "iterate" is determined by whether the +request already has a pattern set. + +### Pattern Dispatch + +Once the meta-router has annotated the request with a pattern, the agent +manager dispatches to the appropriate pattern implementation. This is a +straightforward branch on the pattern field: + +``` +request arrives + │ + ├── history is empty → meta-router → annotate with pattern → re-emit + │ + └── history is non-empty (or pattern is set) + │ + ├── pattern = "react" → ReACT iteration + ├── pattern = "plan-then-execute" → PtE iteration + ├── pattern = "supervisor" → Supervisor iteration + └── (no pattern) → ReACT iteration (default) +``` + +Each pattern implementation follows the same contract: receive a request, do +one iteration of work, then either emit a "next" message (continue) or emit a +response (done). The self-queuing loop doesn't change. + +### Pattern Implementations + +#### ReACT (Existing) + +No changes. The existing `AgentManager.react()` path continues to work +as-is. + +#### Plan-then-Execute + +Two-phase pattern: + +**Planning phase** (first iteration): +- LLM receives the question plus task type framing. +- Produces a structured plan: an ordered list of steps, each with a goal, + expected tool, and dependencies on prior steps. +- The plan is recorded in the history as a special "plan" step. +- Emits a "next" message to begin execution. + +**Execution phase** (subsequent iterations): +- Reads the plan from history. +- Identifies the next unexecuted step. +- Executes that step (tool call + observation), similar to a single ReACT + action. +- Records the result against the plan step. +- If all steps complete, synthesises a final answer. +- If a step fails or produces unexpected results, the LLM can revise the + remaining plan (bounded re-planning, not a full restart). + +The plan lives in the history, so it travels with the message. No external +state is needed. + +#### Supervisor/Subagent + +The supervisor pattern introduces fan-out and fan-in. This is the most +architecturally significant addition. + +**Supervisor planning iteration**: +- LLM receives the question plus task type framing. +- Decomposes the task into independent subagent goals. +- For each subagent, emits a new `AgentRequest` with: + - A focused question (the subagent's specific goal) + - A shared correlation ID tying it to the parent task + - The subagent's own pattern (typically ReACT, but could be anything) + - Relevant context sliced from the parent request + +**Subagent execution**: +- Each subagent request is picked up by an agent manager instance and runs its + own independent iteration loop. +- Subagents are ordinary agent executions — they self-queue, use tools, emit + provenance, stream feedback. +- When a subagent reaches a Final answer, it writes a completion record to the + knowledge graph under the shared correlation ID. 
+ +**Fan-in and synthesis**: +- An aggregator detects when all sibling subagents for a correlation ID have + completed. +- It emits a synthesis request to the supervisor carrying the correlation ID. +- The supervisor queries the graph for subagent results, reasons across them, + and decides whether to emit a final answer or iterate again. + +**Supervisor re-iteration**: +- After synthesis, the supervisor may determine that the results are + incomplete, contradictory, or reveal gaps requiring further investigation. +- Rather than emitting a final answer, it can fan out again with new or + refined subagent goals under a new correlation ID. This is the same + self-queuing loop — the supervisor emits new subagent requests and stops, + the aggregator detects completion, and synthesis runs again. +- The supervisor's iteration count (planning + synthesis rounds) is bounded + to prevent unbounded looping. + +This is detailed further in the Fan-Out / Fan-In section below. + +--- + +## Message Schema Evolution + +### Shared Schema Principle + +The `AgentRequest` and `AgentResponse` schemas are the shared contract +between the agent-manager (existing ReACT execution) and the +agent-orchestrator (meta-routing, supervisor, plan-then-execute). Both +services consume from the same *agent request* topic using the same +schema. Any schema changes must be reflected in both — the schema is +the integration point, not the service implementation. + +This means the orchestrator does not introduce separate message types for +its own use. Subagent requests, synthesis triggers, and meta-router +outputs are all `AgentRequest` messages with different field values. The +agent-manager ignores orchestration fields it doesn't use. + +### New Fields + +The `AgentRequest` schema needs new fields to carry orchestration +metadata. 
+ +```python +@dataclass +class AgentRequest: + # Existing fields (unchanged) + question: str = "" + state: str = "" + group: list[str] | None = None + history: list[AgentStep] = field(default_factory=list) + user: str = "" + collection: str = "default" + streaming: bool = False + session_id: str = "" + + # New orchestration fields + conversation_id: str = "" # Optional caller-generated ID grouping related requests + pattern: str = "" # "react", "plan-then-execute", "supervisor", "" + task_type: str = "" # Identified task type name + framing: str = "" # Task type framing prompt injected into LLM context + correlation_id: str = "" # Shared ID linking subagents to parent + parent_session_id: str = "" # Parent's session_id (for subagents) + subagent_goal: str = "" # Focused goal for this subagent + expected_siblings: int = 0 # How many sibling subagents exist +``` + +The `AgentStep` schema also extends to accommodate non-ReACT iteration types: + +```python +@dataclass +class AgentStep: + # Existing fields (unchanged) + thought: str = "" + action: str = "" + arguments: dict[str, str] = field(default_factory=dict) + observation: str = "" + user: str = "" + + # New fields + step_type: str = "" # "react", "plan", "execute", "supervise", "synthesise" + plan: list[PlanStep] | None = None # For plan-then-execute: the structured plan + subagent_results: dict | None = None # For supervisor: collected subagent outputs +``` + +The `PlanStep` structure for Plan-then-Execute: + +```python +@dataclass +class PlanStep: + goal: str = "" # What this step should accomplish + tool_hint: str = "" # Suggested tool (advisory, not binding) + depends_on: list[int] = field(default_factory=list) # Indices of prerequisite steps + status: str = "pending" # "pending", "complete", "failed", "revised" + result: str = "" # Observation from execution +``` + +--- + +## Fan-Out and Fan-In + +### Why This Matters + +Fan-out is the mechanism that makes multi-agent coordination genuinely +parallel rather than simulated. With Pulsar, emitting multiple messages means +multiple consumers can pick them up concurrently. This is not threading or +async simulation — it is real distributed parallelism across agent manager +instances. + +### Fan-Out: Supervisor Emits Subagent Requests + +When a supervisor iteration decides to decompose a task, it: + +1. Generates a **correlation ID** — a UUID that groups the sibling subagents. +2. For each subagent, constructs a new `AgentRequest`: + - `question` = the subagent's focused goal (from `subagent_goal`) + - `correlation_id` = the shared correlation ID + - `parent_session_id` = the supervisor's session_id + - `pattern` = typically "react", but the supervisor can specify any pattern + - `session_id` = a new unique ID for this subagent's own provenance chain + - `expected_siblings` = total number of sibling subagents + - `history` = empty (fresh start, but framing context inherited) + - `group`, `user`, `collection` = inherited from parent +3. Emits each subagent request onto the agent request topic. +4. Records the fan-out decision in the provenance graph (see below). + +The supervisor then **stops**. It does not wait. It does not poll. It has +emitted its messages and its iteration is complete. The graph and the +aggregator handle the rest. 
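+
+To make the fan-out step concrete, here is a minimal sketch, assuming the
+`AgentRequest` fields defined above and a hypothetical `emit()` helper that
+publishes a message onto the agent request topic:
+
+```python
+import uuid
+
+def fan_out(parent: AgentRequest, subagent_goals: list[str], emit) -> str:
+    """Emit one AgentRequest per subagent goal, linked by a correlation ID.
+
+    `parent` is the supervisor's own request; `emit` is an assumed helper
+    that publishes onto the agent request topic. AgentRequest is the
+    dataclass sketched above.
+    """
+    correlation_id = str(uuid.uuid4())
+
+    for goal in subagent_goals:
+        emit(AgentRequest(
+            question=goal,                         # the subagent's focused goal
+            subagent_goal=goal,
+            pattern="react",                       # supervisor may pick per subagent
+            correlation_id=correlation_id,
+            parent_session_id=parent.session_id,
+            session_id=str(uuid.uuid4()),          # subagent's own provenance chain
+            expected_siblings=len(subagent_goals),
+            history=[],                            # fresh start
+            framing=parent.framing,                # inherited domain framing
+            group=parent.group,                    # inherited tool group
+            user=parent.user,
+            collection=parent.collection,
+        ))
+
+    return correlation_id
+```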
+
+### Fan-In: Graph-Based Completion Detection
+
+When a subagent reaches its Final answer, it writes a **completion node** to
+the knowledge graph:
+
+```
+Completion node:
+  rdf:type tg:SubagentCompletion
+  tg:correlationId
+  tg:subagentSessionId
+  tg:parentSessionId
+  tg:subagentGoal
+  tg:result → result document URI in librarian
+  prov:wasGeneratedBy → the subagent's session activity
+```
+
+The **aggregator** is a component that watches for completion nodes. When it
+detects that all expected siblings for a correlation ID have written
+completion nodes, it:
+
+1. Collects all sibling results from the graph and librarian.
+2. Constructs a **synthesis request** — a new `AgentRequest` addressed to the supervisor flow:
+   - `session_id` = the original supervisor's session_id
+   - `pattern` = "supervisor"
+   - `step_type` = "synthesise" (carried in history)
+   - `subagent_results` = the collected findings
+   - `history` = the supervisor's history up to the fan-out point, plus the synthesis step
+3. Emits this onto the agent request topic.
+
+The supervisor picks this up, reasons across the aggregated findings, and
+produces its final answer.
+
+### Aggregator Design
+
+The aggregator is event-driven, consistent with TrustGraph's Pulsar-based
+architecture. Polling would be an anti-pattern in a system where all
+coordination is message-driven.
+
+**Mechanism**: The aggregator is a Pulsar consumer on the explainability
+topic. Subagent completion nodes are emitted as triples on this topic as
+part of the existing provenance flow. When the aggregator receives a
+`tg:SubagentCompletion` triple, it:
+
+1. Extracts the `tg:correlationId` from the completion node.
+2. Queries the graph to count how many siblings for that correlation ID
+   have completed.
+3. If all `expected_siblings` are present, triggers fan-in immediately —
+   collects results and emits the synthesis request.
+
+**State**: The aggregator is stateless in the same sense as the agent
+manager — it holds no essential in-memory state. The graph is the source
+of truth for completion counts. If the aggregator restarts, it can
+re-process unacknowledged completion messages from Pulsar and re-check the
+graph. No coordination state is lost.
+
+**Consistency**: Because the completion check queries the graph rather than
+relying on an in-memory counter, the aggregator is tolerant of duplicate
+messages, out-of-order delivery, and restarts. The graph query is
+idempotent — asking "are all siblings complete?" gives the same answer
+regardless of how many times or in what order the events arrive.
+
+### Timeout and Failure
+
+- **Subagent timeout**: The aggregator records the timestamp of the first
+  sibling completion (from the graph). A periodic timeout check (the one
+  concession to polling — but over local state, not the graph) detects
+  stalled correlation IDs. If `expected_siblings` completions are not
+  reached within a configurable timeout, the aggregator emits a partial
+  synthesis request with whatever results are available, flagging the
+  incomplete subagents.
+- **Subagent failure**: If a subagent errors out, it writes an error
+  completion node (with `tg:status = "error"` and an error message). The
+  aggregator treats this as a completion — the supervisor receives the error
+  in its synthesis input and can reason about partial results.
+- **Supervisor iteration limit**: The supervisor's own iteration count
+  (planning + synthesis) is bounded by `max_iterations` just like any other
+  pattern.
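+
+A minimal sketch of the aggregator's completion check follows. The three
+callables it takes are hypothetical placeholders: two graph queries and a
+publisher onto the agent request topic.
+
+```python
+def check_fan_in(correlation_id, count_completions, expected_siblings_for,
+                 emit_synthesis_request):
+    """Run the idempotent completion check for one correlation ID.
+
+    The callables are hypothetical: `expected_siblings_for` and
+    `count_completions` query the graph, `emit_synthesis_request` publishes
+    onto the agent request topic. Because the counts come from the graph
+    rather than in-memory state, duplicate or re-ordered events and
+    aggregator restarts all converge on the same decision.
+    """
+    expected = expected_siblings_for(correlation_id)   # from the fan-out node
+    completed = count_completions(correlation_id)      # graph count, not a counter
+
+    if expected and completed >= expected:
+        # All siblings complete: trigger fan-in. The supervisor will fetch
+        # the subagent results from the graph and librarian itself.
+        emit_synthesis_request(correlation_id)
+        return True
+    return False
+```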
+ +--- + +## Provenance Model Extensions + +### Routing Decision + +The meta-router's task type and pattern selection is recorded as the first +provenance node in the session: + +``` +Routing node: + rdf:type prov:Entity, tg:RoutingDecision + prov:wasGeneratedBy → session (Question) activity + tg:taskType → TaskType node URI + tg:selectedPattern → Pattern node URI + tg:candidatePatterns → [Pattern node URIs] (what was available) + tg:routingRationale → document URI in librarian (LLM's reasoning) +``` + +This captures the constrained decision space: what candidates existed, which +was selected, and why. The candidates are graph-derived; the rationale is +LLM-generated but verifiable against the candidates. + +### Fan-Out Provenance + +When a supervisor fans out, the provenance records the decomposition: + +``` +FanOut node: + rdf:type prov:Entity, tg:FanOut + prov:wasDerivedFrom → supervisor's routing or planning iteration + tg:correlationId + tg:subagentGoals → [document URIs for each subagent goal] + tg:expectedSiblings +``` + +Each subagent's provenance chain is independent (its own session, iterations, +conclusion) but linked back to the parent via: + +``` +Subagent session: + rdf:type prov:Activity, tg:Question, tg:AgentQuestion + tg:parentCorrelationId + tg:parentSessionId +``` + +### Fan-In Provenance + +The synthesis step links back to all subagent conclusions: + +``` +Synthesis node: + rdf:type prov:Entity, tg:Synthesis + prov:wasDerivedFrom → [all subagent Conclusion entities] + tg:correlationId +``` + +This creates a DAG in the provenance graph: the supervisor's routing fans out +to N parallel subagent chains, which fan back in to a synthesis node. The +entire multi-agent execution is traceable from a single correlation ID. + +### URI Scheme + +Extending the existing `urn:trustgraph:agent:{session_id}` pattern: + +| Entity | URI Pattern | +|---|---| +| Session (existing) | `urn:trustgraph:agent:{session_id}` | +| Iteration (existing) | `urn:trustgraph:agent:{session_id}/i{n}` | +| Conclusion (existing) | `urn:trustgraph:agent:{session_id}/answer` | +| Routing decision | `urn:trustgraph:agent:{session_id}/routing` | +| Fan-out record | `urn:trustgraph:agent:{session_id}/fanout/{correlation_id}` | +| Subagent completion | `urn:trustgraph:agent:{session_id}/completion` | + +--- + +## Storage Responsibilities + +Pattern and task type definitions live in the config API. Runtime state and +provenance live in the knowledge graph. The division is: + +| Role | Storage | When Written | Content | +|---|---|---|---| +| Pattern definitions | Config API | At design time | Pattern properties, descriptions | +| Task type definitions | Config API | At design time | Domain framing, valid pattern lists | +| Routing decision trace | Knowledge graph | At request arrival | Why this task type and pattern were selected | +| Iteration decision trace | Knowledge graph | During execution | Each think/act/observe cycle, per existing model | +| Fan-out coordination | Knowledge graph | During fan-out | Subagent goals, correlation ID, expected count | +| Subagent completion | Knowledge graph | During fan-in | Per-subagent results under shared correlation ID | +| Execution audit trail | Knowledge graph | Post-execution | Full multi-agent reasoning trace as a DAG | + +The config API holds the definitions that constrain decisions. The knowledge +graph holds the runtime decisions and their provenance. 
The fan-in +coordination state is part of the provenance automatically — subagent +completion nodes are both coordination signals and audit trail entries. + +--- + +## Worked Example: Partner Risk Assessment + +**Request**: "Assess the risk profile of Company X as a potential partner" + +**1. Request arrives** on the *agent request* topic with empty history. +The agent manager picks it up. + +**2. Meta-router**: +- Queries config API, finds task types: *Risk Assessment*, *Research*, + *Summarisation*, *General*. +- LLM identifies *Risk Assessment*. Framing prompt loaded: "analyse across + financial, reputational, legal and operational dimensions using structured + analytic techniques." +- Valid patterns for *Risk Assessment*: [*Supervisor/Subagent*, + *Plan-then-Execute*, *ReACT*]. +- LLM selects *Supervisor/Subagent* — task has four independent investigative + dimensions, well-suited to parallel decomposition. +- Routing decision written to graph. Request re-emitted on the + *agent request* topic with `pattern="supervisor"`, framing populated. + +**3. Supervisor iteration** (picked up from *agent request* topic): +- LLM receives question + framing. Reasons that four independent investigative + threads are required. +- Generates correlation ID `corr-abc123`. +- Emits four subagent requests on the *agent request* topic: + - Financial analysis (pattern="react", subagent_goal="Analyse financial + health and stability of Company X") + - Legal analysis (pattern="react", subagent_goal="Review regulatory filings, + sanctions, and legal exposure for Company X") + - Reputational analysis (pattern="react", subagent_goal="Analyse news + sentiment and public reputation of Company X") + - Operational analysis (pattern="react", subagent_goal="Assess supply chain + dependencies and operational risks for Company X") +- Fan-out node written to graph. + +**4. Four subagents run in parallel** (each picked up from the *agent +request* topic by agent manager instances), each as an independent ReACT +loop: +- Financial — queries financial data services and knowledge graph + relationships +- Legal — searches regulatory filings and sanctions lists +- Reputational — searches news, analyses sentiment +- Operational — queries supply chain databases + +Each self-queues its iterations on the *agent request* topic. Each writes +its own decision trace to the graph as it progresses. Each completes +independently. + +**5. Fan-in**: +- Each subagent writes a `tg:SubagentCompletion` node to the graph on + completion, emitted on the *explainability* topic. The completion node + references the subagent's result document in the librarian. +- Aggregator (consuming the *explainability* topic) sees each completion + event. It queries the graph for the fan-out node to get the expected + sibling count, then checks how many completions exist for + `corr-abc123`. +- When all four siblings are complete, the aggregator emits a synthesis + request on the *agent request* topic with the correlation ID. It does + not fetch or bundle subagent results — the supervisor will query the + graph for those. + +**6. Supervisor synthesis** (picked up from *agent request* topic): +- Receives the synthesis trigger carrying the correlation ID. +- Queries the graph for `tg:SubagentCompletion` nodes under + `corr-abc123`, retrieving each subagent's goal and result document + reference. +- Fetches the result documents from the librarian. +- Reasons across all four dimensions, produces a structured risk + assessment with confidence scores. 
+- Emits final answer on the *agent response* topic and writes conclusion + provenance to the graph. + +**7. Response delivered** — the supervisor's synthesis streams on the +*agent response* topic as the LLM generates it, with `end_of_dialog` +on the final chunk. The collated answer is saved to the librarian and +referenced from conclusion provenance in the graph. The graph now holds +a complete, human-readable trace of the entire multi-agent execution — +from pattern selection through four parallel investigations to final +synthesis. + +--- + +## Class Hierarchy + +The agent-orchestrator executable (`agent-orchestrator`) uses the same +service identity as agent-manager-react, making it a drop-in replacement. +The pattern dispatch model suggests a class hierarchy where shared iteration +infrastructure lives in a base class and pattern-specific logic is in +subclasses: + +``` +AgentService (base — Pulsar consumer/producer specs, request handling) + │ + └── Processor (agent-orchestrator service) + │ + ├── MetaRouter — task type identification, pattern selection + │ + ├── PatternBase — shared: tool filtering, provenance, streaming, history + │ ├── ReactPattern — existing ReACT logic (extract from current AgentManager) + │ ├── PlanThenExecutePattern — plan phase + execute phase + │ └── SupervisorPattern — fan-out, synthesis + │ + └── Aggregator — fan-in completion detection +``` + +`PatternBase` captures what is currently spread across `Processor` and +`AgentManager`: tool filtering, LLM invocation, provenance triple emission, +streaming callbacks, history management. The pattern subclasses implement only +the decision logic specific to their execution strategy — what to do with the +LLM output, when to terminate, whether to fan out. + +This refactoring is not strictly necessary for the first iteration — the +meta-router and pattern dispatch could be added as branches within the +existing `Processor.agent_request()` method. But the class hierarchy clarifies +where shared vs. pattern-specific logic lives and will prevent duplication as +more patterns are added. + +--- + +## Configuration + +### Config API Seeding + +Pattern and task type definitions are stored via the config API and need to +be seeded at deployment time. This is analogous to how flow blueprints and +parameter types are loaded — a bootstrap step that writes the initial +configuration. + +The initial seed includes: + +**Patterns** (config type `agent-pattern`): +- `react` — interleaved reasoning and action +- `plan-then-execute` — structured plan followed by step execution +- `supervisor` — decomposition, fan-out to subagents, synthesis + +**Task types** (config type `agent-task-type`, initial set, expected to grow): +- `general` — no specific domain framing, all patterns valid +- `research` — open-ended investigation, valid patterns: react, plan-then-execute +- `risk-assessment` — multi-dimensional analysis, valid patterns: supervisor, + plan-then-execute, react +- `summarisation` — condense information, valid patterns: react + +The seed data is configuration, not code. It can be extended via the config +API (or the configuration UI) without redeploying the agent manager. + +### Migration Path + +The config API provides a practical starting point. If richer ontological +relationships between patterns, task types, and domain knowledge become +valuable, the definitions can be migrated to graph storage. 
The meta-router's +selection logic queries an abstract set of task types and patterns — the +storage backend is an implementation detail. + +### Fallback Behaviour + +If the config contains no patterns or task types: +- Task type defaults to `general`. +- Pattern defaults to `react`. +- The system degrades gracefully to existing behaviour. + +--- + +## Design Decisions + +| Decision | Resolution | Rationale | +|---|---|---| +| Task type identification | LLM interprets from plain text | Natural language too varied to formalise prematurely | +| Pattern/task type storage | Config API initially, graph later if needed | Avoids graph model complexity upfront; config API already has UI support; migration path is straightforward | +| Meta-router location | Phase within agent manager, not separate service | Avoids an extra network hop; routing is fast | +| Fan-in mechanism | Event-driven via explainability topic | Consistent with Pulsar-based architecture; graph query for completion count is idempotent and restart-safe | +| Aggregator deployment | Separate lightweight process | Decoupled from agent manager lifecycle | +| Subagent pattern selection | Supervisor specifies per-subagent | Supervisor has task context to make this choice | +| Plan storage | In message history | No external state needed; plan travels with message | +| Default pattern | Empty pattern field → ReACT | Sensible default when meta-router is not configured | + +--- + +## Streaming Protocol + +### Current Model + +The existing agent response schema has two levels: + +- **`end_of_message`** — marks the end of a complete thought, observation, + or answer. Chunks belonging to the same message arrive sequentially. +- **`end_of_dialog`** — marks the end of the entire agent execution. No + more messages will follow. + +This works because the current system produces messages serially — one +thought at a time, one agent at a time. + +### Fan-Out Breaks Serial Assumptions + +With supervisor/subagent fan-out, multiple subagents stream chunks +concurrently on the same *agent response* topic. The caller receives +interleaved chunks from different sources and needs to demultiplex them. + +### Resolution: Message ID + +Each chunk carries a `message_id` — a per-message UUID generated when +the agent begins streaming a new thought, observation, or answer. The +caller groups chunks by `message_id` and assembles each message +independently. + +``` +Response chunk fields: + message_id UUID for this message (groups chunks) + session_id Which agent session produced this chunk + chunk_type "thought" | "observation" | "answer" | ... + content The chunk text + end_of_message True on the final chunk of this message + end_of_dialog True on the final message of the entire execution +``` + +A single subagent emits multiple messages (thought, observation, thought, +answer), each with a distinct `message_id`. The `session_id` identifies +which subagent the message belongs to. The caller can display, group, or +filter by either. + +### Provenance Trigger + +`end_of_message` is the trigger for provenance storage. When a complete +message has been assembled from its chunks: + +1. The collated text is saved to the librarian as a single document. +2. A provenance node is written to the graph referencing the document URI. + +This follows the pattern established by GraphRAG, where streaming synthesis +chunks are delivered live but the stored provenance references the collated +answer text. Streaming is for the caller; provenance needs complete messages. 
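+
+A sketch of how a caller might reassemble interleaved chunks, assuming each
+chunk arrives as a dict carrying the fields listed above:
+
+```python
+from collections import defaultdict
+
+def assemble_messages(chunks):
+    """Group streamed chunks by message_id and yield completed messages.
+
+    `chunks` is any iterable of chunk dicts; chunks from different subagents
+    (different session_ids) may be interleaved arbitrarily.
+    """
+    buffers = defaultdict(list)
+
+    for chunk in chunks:
+        key = chunk["message_id"]
+        buffers[key].append(chunk.get("content", ""))
+
+        if chunk.get("end_of_message"):
+            yield {
+                "message_id": key,
+                "session_id": chunk.get("session_id"),
+                "chunk_type": chunk.get("chunk_type"),
+                "content": "".join(buffers.pop(key)),
+                "end_of_dialog": chunk.get("end_of_dialog", False),
+            }
+```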
+ +--- + +## Open Questions + +- **Re-planning depth** (resolved): Runtime parameter on the + agent-orchestrator executable, default 2. Bounds how many times + Plan-then-Execute can revise its plan before forcing termination. +- **Nested fan-out** (phase B): A subagent can itself be a supervisor + that fans out further. The architecture supports this — correlation IDs + are independent and the aggregator is stateless. The protocols and + message schema should not preclude nested fan-out, but implementation + is deferred. Depth limits will need to be enforced to prevent runaway + decomposition. +- **Task type evolution** (resolved): Manually curated for now. See + Future Directions below for automated discovery. +- **Cost attribution** (deferred): Costs are measured at the + text-completion queue level as they are today. Per-request attribution + across subagents is not yet implemented and is not a blocker for + orchestration. +- **Conversation ID** (resolved): An optional `conversation_id` field on + `AgentRequest`, generated by the caller. When present, all objects + created during the execution (provenance nodes, librarian documents, + subagent completion records) are tagged with the conversation ID. This + enables querying all interactions in a conversation with a single + lookup, and provides the foundation for conversation-scoped memory. + No explicit open/close — the first request with a new conversation ID + implicitly starts the conversation. Omit for one-shot queries. +- **Tool scoping per subagent** (resolved): Subagents inherit the + parent's tool group by default. The supervisor can optionally override + the group per subagent to constrain capabilities (e.g. financial + subagent gets only financial tools). The `group` field on + `AgentRequest` already supports this — the supervisor just sets it + when constructing subagent requests. + +--- + +## Future Directions + +### Automated Task Type Discovery + +Task types are manually curated in the initial implementation. However, +the architecture is well-suited to automated discovery because all agent +requests and their execution traces flow through Pulsar topics. A +learning service could consume these messages and analyse patterns in +how tasks are framed, which patterns are selected, and how successfully +they execute. Over time, it could propose new task types based on +clusters of similar requests that don't map well to existing types, or +suggest refinements to framing prompts based on which framings lead to +better outcomes. This service would write proposed task types to the +config API for human review — automated discovery, manual approval. The +agent-orchestrator does not need to change; it always reads task types +from config regardless of how they got there. 
diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index c474af29..15082437 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -87,10 +87,11 @@ def sample_message_data(): "history": [] }, "AgentResponse": { - "answer": "Machine learning is a subset of AI.", + "chunk_type": "answer", + "content": "Machine learning is a subset of AI.", + "end_of_message": True, + "end_of_dialog": True, "error": None, - "thought": "I need to provide information about machine learning.", - "observation": None }, "Metadata": { "id": "test-doc-123", diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index 695fef14..bc5bece1 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -212,10 +212,11 @@ class TestAgentMessageContracts: # Test required fields response = AgentResponse(**response_data) - assert hasattr(response, 'answer') + assert hasattr(response, 'chunk_type') + assert hasattr(response, 'content') + assert hasattr(response, 'end_of_message') + assert hasattr(response, 'end_of_dialog') assert hasattr(response, 'error') - assert hasattr(response, 'thought') - assert hasattr(response, 'observation') def test_agent_step_schema_contract(self): """Test AgentStep schema contract""" diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index dc7d5748..a22e1c41 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -188,12 +188,10 @@ class TestAgentTranslatorCompletionFlags: # Arrange translator = TranslatorRegistry.get_response_translator("agent") response = AgentResponse( - answer="4", - error=None, - thought=None, - observation=None, + chunk_type="answer", + content="4", end_of_message=True, - end_of_dialog=True + end_of_dialog=True, ) # Act @@ -201,7 +199,7 @@ class TestAgentTranslatorCompletionFlags: # Assert assert is_final is True, "is_final must be True when end_of_dialog=True" - assert response_dict["answer"] == "4" + assert response_dict["content"] == "4" assert response_dict["end_of_dialog"] is True def test_agent_translator_is_final_with_end_of_dialog_false(self): @@ -212,12 +210,10 @@ class TestAgentTranslatorCompletionFlags: # Arrange translator = TranslatorRegistry.get_response_translator("agent") response = AgentResponse( - answer=None, - error=None, - thought="I need to solve this.", - observation=None, + chunk_type="thought", + content="I need to solve this.", end_of_message=True, - end_of_dialog=False + end_of_dialog=False, ) # Act @@ -225,31 +221,9 @@ class TestAgentTranslatorCompletionFlags: # Assert assert is_final is False, "is_final must be False when end_of_dialog=False" - assert response_dict["thought"] == "I need to solve this." + assert response_dict["content"] == "I need to solve this." assert response_dict["end_of_dialog"] is False - def test_agent_translator_is_final_fallback_with_answer(self): - """ - Test that AgentResponseTranslator returns is_final=True - when answer is present (fallback for legacy responses). 
- """ - # Arrange - translator = TranslatorRegistry.get_response_translator("agent") - # Legacy response without end_of_dialog flag - response = AgentResponse( - answer="4", - error=None, - thought=None, - observation=None - ) - - # Act - response_dict, is_final = translator.from_response_with_completion(response) - - # Assert - assert is_final is True, "is_final must be True when answer is present (legacy fallback)" - assert response_dict["answer"] == "4" - def test_agent_translator_intermediate_message_is_not_final(self): """ Test that intermediate messages (thought/observation) return is_final=False. @@ -259,12 +233,10 @@ class TestAgentTranslatorCompletionFlags: # Test thought message thought_response = AgentResponse( - answer=None, - error=None, - thought="Processing...", - observation=None, + chunk_type="thought", + content="Processing...", end_of_message=True, - end_of_dialog=False + end_of_dialog=False, ) # Act @@ -275,12 +247,10 @@ class TestAgentTranslatorCompletionFlags: # Test observation message observation_response = AgentResponse( - answer=None, - error=None, - thought=None, - observation="Result found", + chunk_type="observation", + content="Result found", end_of_message=True, - end_of_dialog=False + end_of_dialog=False, ) # Act @@ -302,10 +272,6 @@ class TestAgentTranslatorCompletionFlags: content="", end_of_message=True, end_of_dialog=True, - answer=None, - error=None, - thought=None, - observation=None ) # Act diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index 2ef64e96..ff630325 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -82,16 +82,16 @@ class TestAgentServiceNonStreaming: # Check thought message thought_response = sent_responses[0] assert isinstance(thought_response, AgentResponse) - assert thought_response.thought == "I need to solve this." - assert thought_response.answer is None + assert thought_response.chunk_type == "thought" + assert thought_response.content == "I need to solve this." assert thought_response.end_of_message is True, "Thought message must have end_of_message=True" assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False" # Check observation message observation_response = sent_responses[1] assert isinstance(observation_response, AgentResponse) - assert observation_response.observation == "The answer is 4." - assert observation_response.answer is None + assert observation_response.chunk_type == "observation" + assert observation_response.content == "The answer is 4." 
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True" assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False" @@ -161,9 +161,8 @@ class TestAgentServiceNonStreaming: # Check final answer message answer_response = sent_responses[0] assert isinstance(answer_response, AgentResponse) - assert answer_response.answer == "4" - assert answer_response.thought is None - assert answer_response.observation is None + assert answer_response.chunk_type == "answer" + assert answer_response.content == "4" assert answer_response.end_of_message is True, "Final answer must have end_of_message=True" assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True" diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 40769fa0..e5f63c79 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -402,23 +402,6 @@ class SocketClient: content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) - # Non-streaming agent format: chunk_type is empty but has thought/observation/answer fields - elif resp.get("thought"): - return AgentThought( - content=resp.get("thought", ""), - end_of_message=resp.get("end_of_message", False) - ) - elif resp.get("observation"): - return AgentObservation( - content=resp.get("observation", ""), - end_of_message=resp.get("end_of_message", False) - ) - elif resp.get("answer"): - return AgentAnswer( - content=resp.get("answer", ""), - end_of_message=resp.get("end_of_message", False), - end_of_dialog=resp.get("end_of_dialog", False) - ) else: content = resp.get("response", resp.get("chunk", resp.get("text", ""))) return RAGChunk( diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py index f48fd024..d73d03b9 100644 --- a/trustgraph-base/trustgraph/base/agent_client.py +++ b/trustgraph-base/trustgraph/base/agent_client.py @@ -57,8 +57,7 @@ class AgentClient(RequestResponse): await self.request( AgentRequest( question = question, - plan = plan, - state = state, + state = state or "", history = history, ), recipient=recipient, diff --git a/trustgraph-base/trustgraph/base/agent_service.py b/trustgraph-base/trustgraph/base/agent_service.py index 0e5524fe..cbb15183 100644 --- a/trustgraph-base/trustgraph/base/agent_service.py +++ b/trustgraph-base/trustgraph/base/agent_service.py @@ -90,9 +90,6 @@ class AgentService(FlowProcessor): type = "agent-error", message = str(e), ), - thought = None, - observation = None, - answer = None, end_of_message = True, end_of_dialog = True, ), diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 378bdb41..b245a83e 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -16,6 +16,14 @@ class AgentRequestTranslator(MessageTranslator): collection=data.get("collection", "default"), streaming=data.get("streaming", False), session_id=data.get("session_id", ""), + conversation_id=data.get("conversation_id", ""), + pattern=data.get("pattern", ""), + task_type=data.get("task_type", ""), + framing=data.get("framing", ""), + correlation_id=data.get("correlation_id", ""), + parent_session_id=data.get("parent_session_id", ""), + subagent_goal=data.get("subagent_goal", ""), + 
expected_siblings=data.get("expected_siblings", 0), ) def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]: @@ -28,6 +36,14 @@ class AgentRequestTranslator(MessageTranslator): "collection": getattr(obj, "collection", "default"), "streaming": getattr(obj, "streaming", False), "session_id": getattr(obj, "session_id", ""), + "conversation_id": getattr(obj, "conversation_id", ""), + "pattern": getattr(obj, "pattern", ""), + "task_type": getattr(obj, "task_type", ""), + "framing": getattr(obj, "framing", ""), + "correlation_id": getattr(obj, "correlation_id", ""), + "parent_session_id": getattr(obj, "parent_session_id", ""), + "subagent_goal": getattr(obj, "subagent_goal", ""), + "expected_siblings": getattr(obj, "expected_siblings", 0), } @@ -40,24 +56,15 @@ class AgentResponseTranslator(MessageTranslator): def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]: result = {} - # Check if this is a streaming response (has chunk_type) - if hasattr(obj, 'chunk_type') and obj.chunk_type: + if obj.chunk_type: result["chunk_type"] = obj.chunk_type - if obj.content: - result["content"] = obj.content - result["end_of_message"] = getattr(obj, "end_of_message", False) - result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) - else: - # Legacy format (non-streaming) - if obj.answer: - result["answer"] = obj.answer - if obj.thought: - result["thought"] = obj.thought - if obj.observation: - result["observation"] = obj.observation - # Include completion flags for legacy format too - result["end_of_message"] = getattr(obj, "end_of_message", False) - result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) + if obj.content: + result["content"] = obj.content + result["end_of_message"] = getattr(obj, "end_of_message", False) + result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) + + if getattr(obj, "message_id", ""): + result["message_id"] = obj.message_id # Include explainability fields if present explain_id = getattr(obj, "explain_id", None) @@ -76,11 +83,5 @@ class AgentResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - # For streaming responses, check end_of_dialog - if hasattr(obj, 'chunk_type') and obj.chunk_type: - is_final = getattr(obj, 'end_of_dialog', False) - else: - # For legacy responses, check if answer is present - is_final = (obj.answer is not None) - + is_final = getattr(obj, 'end_of_dialog', False) return self.from_pulsar(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 91179047..fdb9e391 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field +from typing import Optional from ..core.topic import topic from ..core.primitives import Error @@ -8,6 +9,14 @@ from ..core.primitives import Error # Prompt services, abstract the prompt generation +@dataclass +class PlanStep: + goal: str = "" + tool_hint: str = "" # Suggested tool for this step + depends_on: list[int] = field(default_factory=list) # Indices of prerequisite steps + status: str = "pending" # pending, running, completed, failed + result: str = "" # Result of step execution + @dataclass class AgentStep: thought: str = "" @@ -15,6 +24,9 @@ class AgentStep: arguments: dict[str, str] = field(default_factory=dict) observation: str = "" user: str = "" 
# User context for the step + step_type: str = "" # "react", "plan", "execute", "decompose", "synthesise" + plan: list[PlanStep] = field(default_factory=list) # Plan steps (for plan-then-execute) + subagent_results: dict[str, str] = field(default_factory=dict) # Subagent results keyed by goal @dataclass class AgentRequest: @@ -27,6 +39,16 @@ class AgentRequest: streaming: bool = False # Enable streaming response delivery (default false) session_id: str = "" # For provenance tracking across iterations + # Orchestration fields + conversation_id: str = "" # Groups related requests into a conversation + pattern: str = "" # Selected pattern: "react", "plan-then-execute", "supervisor" + task_type: str = "" # Task type from config: "general", "research", etc. + framing: str = "" # Domain framing text injected into prompts + correlation_id: str = "" # Links fan-out subagents to parent for fan-in + parent_session_id: str = "" # Session ID of the supervisor that spawned this subagent + subagent_goal: str = "" # Specific goal for a subagent (set by supervisor) + expected_siblings: int = 0 # Number of sibling subagents in this fan-out + @dataclass class AgentResponse: # Streaming-first design @@ -39,11 +61,10 @@ class AgentResponse: explain_id: str | None = None # Provenance URI (announced as created) explain_graph: str | None = None # Named graph where explain was stored - # Legacy fields (deprecated but kept for backward compatibility) - answer: str = "" + # Orchestration fields + message_id: str = "" # Unique ID for this response message + error: Error | None = None - thought: str = "" - observation: str = "" ############################################################################ diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 4aeb9199..66363305 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -56,6 +56,7 @@ Homepage = "https://github.com/trustgraph-ai/trustgraph" [project.scripts] agent-manager-react = "trustgraph.agent.react:run" +agent-orchestrator = "trustgraph.agent.orchestrator:run" api-gateway = "trustgraph.gateway:run" chunker-recursive = "trustgraph.chunking.recursive:run" chunker-token = "trustgraph.chunking.token:run" diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/__init__.py b/trustgraph-flow/trustgraph/agent/orchestrator/__init__.py new file mode 100644 index 00000000..214f7272 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/__init__.py @@ -0,0 +1,2 @@ + +from . service import * diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/__main__.py b/trustgraph-flow/trustgraph/agent/orchestrator/__main__.py new file mode 100644 index 00000000..da5a9021 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py new file mode 100644 index 00000000..bff8822c --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py @@ -0,0 +1,157 @@ +""" +Aggregator — monitors the explainability topic for subagent completions +and triggers synthesis when all siblings in a fan-out have completed. + +The aggregator watches for tg:Conclusion triples that carry a +correlation_id. When it detects that all expected siblings have +completed, it emits a synthesis AgentRequest on the agent request topic. 
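As a rough illustration of the intended fan-in sequence, the sketch below strings together the Aggregator methods defined in this module; the production wiring drives them from the explainability topic rather than calling them directly, and the correlation ID, goals, user and import path are illustrative assumptions.

    from trustgraph.agent.orchestrator.aggregator import Aggregator

    agg = Aggregator()
    agg.register_fanout("corr-1", parent_session_id="sess-9", expected_siblings=2)

    # Each subagent completion is recorded against the correlation ID
    agg.record_completion("corr-1", "Summarise report A", "Report A is on track")
    done = agg.record_completion("corr-1", "Summarise report B", "Report B has slipped")

    if done:
        # Builds the AgentRequest that triggers the synthesis phase and
        # removes the tracking entry for this correlation
        req = agg.build_synthesis_request(
            "corr-1",
            original_question="Compare the status of reports A and B",
            user="alice",
            collection="default",
        )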
+""" + +import asyncio +import json +import logging +import time +import uuid + +from ... schema import AgentRequest, AgentStep + +logger = logging.getLogger(__name__) + +# How long to wait for stalled correlations before giving up (seconds) +DEFAULT_TIMEOUT = 300 + + +class Aggregator: + """ + Tracks in-flight fan-out correlations and triggers synthesis + when all subagents complete. + + State is held in-memory; if the process restarts, in-flight + correlations are lost (acceptable for v1). + """ + + def __init__(self, timeout=DEFAULT_TIMEOUT): + self.timeout = timeout + + # correlation_id -> { + # "parent_session_id": str, + # "expected": int, + # "results": {goal: answer}, + # "request_template": AgentRequest or None, + # "created_at": float, + # } + self.correlations = {} + + def register_fanout(self, correlation_id, parent_session_id, + expected_siblings, request_template=None): + """ + Register a new fan-out. Called by the supervisor after emitting + subagent requests. + """ + self.correlations[correlation_id] = { + "parent_session_id": parent_session_id, + "expected": expected_siblings, + "results": {}, + "request_template": request_template, + "created_at": time.time(), + } + logger.info( + f"Aggregator: registered fan-out {correlation_id}, " + f"expecting {expected_siblings} subagents" + ) + + def record_completion(self, correlation_id, subagent_goal, result): + """ + Record a subagent completion. + + Returns: + True if all siblings are now complete, False otherwise. + Returns None if the correlation_id is unknown. + """ + if correlation_id not in self.correlations: + logger.warning( + f"Aggregator: unknown correlation_id {correlation_id}" + ) + return None + + entry = self.correlations[correlation_id] + entry["results"][subagent_goal] = result + + completed = len(entry["results"]) + expected = entry["expected"] + + logger.info( + f"Aggregator: {correlation_id} — " + f"{completed}/{expected} subagents complete" + ) + + return completed >= expected + + def get_results(self, correlation_id): + """Get all results for a correlation and remove the tracking entry.""" + entry = self.correlations.pop(correlation_id, None) + if entry is None: + return None, None, None + return ( + entry["results"], + entry["parent_session_id"], + entry["request_template"], + ) + + def build_synthesis_request(self, correlation_id, original_question, + user, collection): + """ + Build the AgentRequest that triggers the synthesis phase. 
+ """ + results, parent_session_id, template = self.get_results(correlation_id) + + if results is None: + raise RuntimeError( + f"No results for correlation_id {correlation_id}" + ) + + # Build history with decompose step + results + synthesis_step = AgentStep( + thought="All subagents completed", + action="aggregate", + arguments={}, + observation=json.dumps(results), + step_type="synthesise", + subagent_results=results, + ) + + history = [] + if template and template.history: + history = list(template.history) + history.append(synthesis_step) + + return AgentRequest( + question=original_question, + state="", + group=template.group if template else [], + history=history, + user=user, + collection=collection, + streaming=template.streaming if template else False, + session_id=parent_session_id, + conversation_id=template.conversation_id if template else "", + pattern="supervisor", + task_type=template.task_type if template else "", + framing=template.framing if template else "", + correlation_id=correlation_id, + parent_session_id="", + subagent_goal="", + expected_siblings=0, + ) + + def cleanup_stale(self): + """Remove correlations that have timed out.""" + now = time.time() + stale = [ + cid for cid, entry in self.correlations.items() + if now - entry["created_at"] > self.timeout + ] + for cid in stale: + logger.warning(f"Aggregator: timing out stale correlation {cid}") + self.correlations.pop(cid, None) + return stale diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py new file mode 100644 index 00000000..c3b1afa6 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py @@ -0,0 +1,168 @@ +""" +MetaRouter — selects the task type and execution pattern for a query. + +Uses the config API to look up available task types and patterns, then +asks the LLM to classify the query and select the best pattern. +Falls back to ("react", "general", "") if config is empty. +""" + +import json +import logging + +logger = logging.getLogger(__name__) + +DEFAULT_PATTERN = "react" +DEFAULT_TASK_TYPE = "general" +DEFAULT_FRAMING = "" + + +class MetaRouter: + + def __init__(self, config=None): + """ + Args: + config: The full config dict from the config service. + May contain "agent-pattern" and "agent-task-type" keys. + """ + self.patterns = {} + self.task_types = {} + + if config: + # Load from config API + if "agent-pattern" in config: + for pid, pval in config["agent-pattern"].items(): + try: + self.patterns[pid] = json.loads(pval) + except (json.JSONDecodeError, TypeError): + self.patterns[pid] = {"name": pid} + + if "agent-task-type" in config: + for tid, tval in config["agent-task-type"].items(): + try: + self.task_types[tid] = json.loads(tval) + except (json.JSONDecodeError, TypeError): + self.task_types[tid] = {"name": tid} + + # If config has no patterns/task-types, default to react/general + if not self.patterns: + self.patterns = { + "react": {"name": "react", "description": "Interleaved reasoning and action"}, + } + if not self.task_types: + self.task_types = { + "general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""}, + } + + async def identify_task_type(self, question, context): + """ + Use the LLM to classify the question into one of the known task types. + + Args: + question: The user's query. + context: UserAwareContext (flow wrapper). + + Returns: + (task_type_id, framing) tuple. 
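For illustration, the task-type entries this classification draws on are JSON values under the "agent-task-type" config key. An entry of roughly the shape below would be offered to the LLM as a candidate; the "research" task type and its wording are hypothetical, but "description", "valid_patterns" and "framing" are the keys this class reads.

    import json
    from trustgraph.agent.orchestrator.meta_router import MetaRouter

    config = {
        "agent-task-type": {
            "general": json.dumps({
                "name": "general",
                "description": "General queries",
                "valid_patterns": ["react"],
                "framing": "",
            }),
            "research": json.dumps({
                "name": "research",
                "description": "Multi-source research and comparison questions",
                "valid_patterns": ["react", "plan-then-execute", "supervisor"],
                "framing": "You are assisting with research tasks; cite sources.",
            }),
        },
        "agent-pattern": {
            "react": json.dumps({
                "name": "react",
                "description": "Interleaved reasoning and action",
            }),
        },
    }

    router = MetaRouter(config=config)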
+ """ + if len(self.task_types) <= 1: + tid = next(iter(self.task_types), DEFAULT_TASK_TYPE) + framing = self.task_types.get(tid, {}).get("framing", DEFAULT_FRAMING) + return tid, framing + + try: + client = context("prompt-request") + response = await client.prompt( + id="task-type-classify", + variables={ + "question": question, + "task_types": [ + {"name": tid, "description": tdata.get("description", tid)} + for tid, tdata in self.task_types.items() + ], + }, + ) + selected = response.strip().lower().replace('"', '').replace("'", "") + + if selected in self.task_types: + framing = self.task_types[selected].get("framing", DEFAULT_FRAMING) + logger.info(f"MetaRouter: identified task type '{selected}'") + return selected, framing + else: + logger.warning( + f"MetaRouter: LLM returned unknown task type '{selected}', " + f"falling back to '{DEFAULT_TASK_TYPE}'" + ) + except Exception as e: + logger.warning(f"MetaRouter: task type classification failed: {e}") + + framing = self.task_types.get(DEFAULT_TASK_TYPE, {}).get( + "framing", DEFAULT_FRAMING + ) + return DEFAULT_TASK_TYPE, framing + + async def select_pattern(self, question, task_type, context): + """ + Use the LLM to select the best execution pattern for this task type. + + Args: + question: The user's query. + task_type: The identified task type ID. + context: UserAwareContext (flow wrapper). + + Returns: + Pattern ID string. + """ + task_config = self.task_types.get(task_type, {}) + valid_patterns = task_config.get("valid_patterns", list(self.patterns.keys())) + + if len(valid_patterns) <= 1: + return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN + + try: + client = context("prompt-request") + response = await client.prompt( + id="pattern-select", + variables={ + "question": question, + "task_type": task_type, + "task_type_description": task_config.get("description", task_type), + "patterns": [ + {"name": pid, "description": self.patterns.get(pid, {}).get("description", pid)} + for pid in valid_patterns + if pid in self.patterns + ], + }, + ) + selected = response.strip().lower().replace('"', '').replace("'", "") + + if selected in valid_patterns: + logger.info(f"MetaRouter: selected pattern '{selected}'") + return selected + else: + logger.warning( + f"MetaRouter: LLM returned invalid pattern '{selected}', " + f"falling back to '{valid_patterns[0]}'" + ) + return valid_patterns[0] + except Exception as e: + logger.warning(f"MetaRouter: pattern selection failed: {e}") + return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN + + async def route(self, question, context): + """ + Full routing pipeline: identify task type, then select pattern. + + Args: + question: The user's query. + context: UserAwareContext (flow wrapper). + + Returns: + (pattern, task_type, framing) tuple. + """ + task_type, framing = await self.identify_task_type(question, context) + pattern = await self.select_pattern(question, task_type, context) + logger.info( + f"MetaRouter: route result — " + f"pattern={pattern}, task_type={task_type}, framing={framing!r}" + ) + return pattern, task_type, framing diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py new file mode 100644 index 00000000..fc07e745 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -0,0 +1,428 @@ +""" +Base class for agent patterns. 
+ +Provides shared infrastructure used by all patterns: tool filtering, +provenance emission, streaming callbacks, history management, and +librarian integration. +""" + +import json +import logging +import uuid +from datetime import datetime + +from ... schema import AgentRequest, AgentResponse, AgentStep, Error +from ... schema import Triples, Metadata + +from trustgraph.provenance import ( + agent_session_uri, + agent_iteration_uri, + agent_thought_uri, + agent_observation_uri, + agent_final_uri, + agent_session_triples, + agent_iteration_triples, + agent_final_triples, + set_graph, + GRAPH_RETRIEVAL, +) + +from ..react.types import Action, Final +from ..tool_filter import filter_tools_by_group_and_state, get_next_state + +logger = logging.getLogger(__name__) + + +class UserAwareContext: + """Wraps flow interface to inject user context for tools that need it.""" + + def __init__(self, flow, user): + self._flow = flow + self._user = user + + def __call__(self, service_name): + client = self._flow(service_name) + if service_name in ( + "structured-query-request", + "row-embeddings-query-request", + ): + client._current_user = self._user + return client + + +class PatternBase: + """ + Shared infrastructure for all agent patterns. + + Subclasses implement iterate() to perform one iteration of their + pattern-specific logic. + """ + + def __init__(self, processor): + self.processor = processor + + def filter_tools(self, tools, request): + """Apply group/state filtering to the tool set.""" + return filter_tools_by_group_and_state( + tools=tools, + requested_groups=getattr(request, 'group', None), + current_state=getattr(request, 'state', None), + ) + + def make_context(self, flow, user): + """Create a user-aware context wrapper.""" + return UserAwareContext(flow, user) + + def build_history(self, request): + """Convert AgentStep history into Action objects.""" + if not request.history: + return [] + return [ + Action( + thought=h.thought, + name=h.action, + arguments=h.arguments, + observation=h.observation, + ) + for h in request.history + ] + + # ---- Streaming callbacks ------------------------------------------------ + + def make_think_callback(self, respond, streaming): + """Create the think callback for streaming/non-streaming.""" + async def think(x, is_final=False): + logger.debug(f"Think: {x} (is_final={is_final})") + if streaming: + r = AgentResponse( + chunk_type="thought", + content=x, + end_of_message=is_final, + end_of_dialog=False, + ) + else: + r = AgentResponse( + chunk_type="thought", + content=x, + end_of_message=True, + end_of_dialog=False, + ) + await respond(r) + return think + + def make_observe_callback(self, respond, streaming): + """Create the observe callback for streaming/non-streaming.""" + async def observe(x, is_final=False): + logger.debug(f"Observe: {x} (is_final={is_final})") + if streaming: + r = AgentResponse( + chunk_type="observation", + content=x, + end_of_message=is_final, + end_of_dialog=False, + ) + else: + r = AgentResponse( + chunk_type="observation", + content=x, + end_of_message=True, + end_of_dialog=False, + ) + await respond(r) + return observe + + def make_answer_callback(self, respond, streaming): + """Create the answer callback for streaming/non-streaming.""" + async def answer(x): + logger.debug(f"Answer: {x}") + if streaming: + r = AgentResponse( + chunk_type="answer", + content=x, + end_of_message=False, + end_of_dialog=False, + ) + else: + r = AgentResponse( + chunk_type="answer", + content=x, + end_of_message=True, + end_of_dialog=False, 
+ ) + await respond(r) + return answer + + # ---- Provenance emission ------------------------------------------------ + + async def emit_session_triples(self, flow, session_uri, question, user, + collection, respond, streaming): + """Emit provenance triples for a new session.""" + timestamp = datetime.utcnow().isoformat() + "Z" + triples = set_graph( + agent_session_triples(session_uri, question, timestamp), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=session_uri, + user=user, + collection=collection, + ), + triples=triples, + )) + logger.debug(f"Emitted session triples for {session_uri}") + + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=session_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_iteration_triples(self, flow, session_id, iteration_num, + session_uri, act, request, respond, + streaming): + """Emit provenance triples for an iteration and save to librarian.""" + iteration_uri = agent_iteration_uri(session_id, iteration_num) + + if iteration_num > 1: + iter_question_uri = None + iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + else: + iter_question_uri = session_uri + iter_previous_uri = None + + # Save thought to librarian + thought_doc_id = None + if act.thought: + thought_doc_id = ( + f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" + ) + try: + await self.processor.save_answer_content( + doc_id=thought_doc_id, + user=request.user, + content=act.thought, + title=f"Agent Thought: {act.name}", + ) + except Exception as e: + logger.warning(f"Failed to save thought to librarian: {e}") + thought_doc_id = None + + # Save observation to librarian + observation_doc_id = None + if act.observation: + observation_doc_id = ( + f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" + ) + try: + await self.processor.save_answer_content( + doc_id=observation_doc_id, + user=request.user, + content=act.observation, + title=f"Agent Observation: {act.name}", + ) + except Exception as e: + logger.warning(f"Failed to save observation to librarian: {e}") + observation_doc_id = None + + thought_entity_uri = agent_thought_uri(session_id, iteration_num) + observation_entity_uri = agent_observation_uri(session_id, iteration_num) + + iter_triples = set_graph( + agent_iteration_triples( + iteration_uri, + question_uri=iter_question_uri, + previous_uri=iter_previous_uri, + action=act.name, + arguments=act.arguments, + thought_uri=thought_entity_uri if thought_doc_id else None, + thought_document_id=thought_doc_id, + observation_uri=observation_entity_uri if observation_doc_id else None, + observation_document_id=observation_doc_id, + ), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=iteration_uri, + user=request.user, + collection=getattr(request, 'collection', 'default'), + ), + triples=iter_triples, + )) + logger.debug(f"Emitted iteration triples for {iteration_uri}") + + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=iteration_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_final_triples(self, flow, session_id, iteration_num, + session_uri, answer_text, request, respond, + streaming): + """Emit provenance triples for the final answer and save to librarian.""" + final_uri = agent_final_uri(session_id) + + if iteration_num > 1: + final_question_uri = None + final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + else: 
+ final_question_uri = session_uri + final_previous_uri = None + + # Save answer to librarian + answer_doc_id = None + if answer_text: + answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer" + try: + await self.processor.save_answer_content( + doc_id=answer_doc_id, + user=request.user, + content=answer_text, + title=f"Agent Answer: {request.question[:50]}...", + ) + logger.debug(f"Saved answer to librarian: {answer_doc_id}") + except Exception as e: + logger.warning(f"Failed to save answer to librarian: {e}") + answer_doc_id = None + + final_triples = set_graph( + agent_final_triples( + final_uri, + question_uri=final_question_uri, + previous_uri=final_previous_uri, + document_id=answer_doc_id, + ), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=final_uri, + user=request.user, + collection=getattr(request, 'collection', 'default'), + ), + triples=final_triples, + )) + logger.debug(f"Emitted final triples for {final_uri}") + + if streaming: + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=final_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + + # ---- Response helpers --------------------------------------------------- + + async def prompt_as_answer(self, client, prompt_id, variables, + respond, streaming): + """Call a prompt template, forwarding chunks as answer + AgentResponse messages when streaming is enabled. + + Returns the full accumulated answer text (needed for provenance). + """ + if streaming: + accumulated = [] + + async def on_chunk(text, end_of_stream): + if text: + accumulated.append(text) + await respond(AgentResponse( + chunk_type="answer", + content=text, + end_of_message=False, + end_of_dialog=False, + )) + + await client.prompt( + id=prompt_id, + variables=variables, + streaming=True, + chunk_callback=on_chunk, + ) + + return "".join(accumulated) + else: + return await client.prompt( + id=prompt_id, + variables=variables, + ) + + async def send_final_response(self, respond, streaming, answer_text, + already_streamed=False): + """Send the answer content and end-of-dialog marker. + + Args: + already_streamed: If True, answer chunks were already sent + via streaming callbacks (e.g. ReactPattern). Only the + end-of-dialog marker is emitted. 
+ """ + if streaming and not already_streamed: + # Answer wasn't streamed yet — send it as a chunk first + if answer_text: + await respond(AgentResponse( + chunk_type="answer", + content=answer_text, + end_of_message=False, + end_of_dialog=False, + )) + if streaming: + # End-of-dialog marker + await respond(AgentResponse( + chunk_type="answer", + content="", + end_of_message=True, + end_of_dialog=True, + )) + else: + await respond(AgentResponse( + chunk_type="answer", + content=answer_text, + end_of_message=True, + end_of_dialog=True, + )) + + def build_next_request(self, request, history, session_id, collection, + streaming, next_state): + """Build the AgentRequest for the next iteration.""" + return AgentRequest( + question=request.question, + state=next_state, + group=getattr(request, 'group', []), + history=[ + AgentStep( + thought=h.thought, + action=h.name, + arguments={k: str(v) for k, v in h.arguments.items()}, + observation=h.observation, + ) + for h in history + ], + user=request.user, + collection=collection, + streaming=streaming, + session_id=session_id, + # Preserve orchestration fields + conversation_id=getattr(request, 'conversation_id', ''), + pattern=getattr(request, 'pattern', ''), + task_type=getattr(request, 'task_type', ''), + framing=getattr(request, 'framing', ''), + correlation_id=getattr(request, 'correlation_id', ''), + parent_session_id=getattr(request, 'parent_session_id', ''), + subagent_goal=getattr(request, 'subagent_goal', ''), + expected_siblings=getattr(request, 'expected_siblings', 0), + ) + + async def iterate(self, request, respond, next, flow): + """ + Perform one iteration of this pattern. + + Must be implemented by subclasses. + """ + raise NotImplementedError diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py new file mode 100644 index 00000000..d5f667c8 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -0,0 +1,349 @@ +""" +PlanThenExecutePattern — structured planning followed by step execution. + +Phase 1 (planning): LLM produces a structured plan of steps. +Phase 2 (execution): Each step is executed via single-shot tool call. +""" + +import json +import logging +import uuid + +from ... schema import AgentRequest, AgentResponse, AgentStep, PlanStep + +from ..react.types import Action + +from . pattern_base import PatternBase + +logger = logging.getLogger(__name__) + + +class PlanThenExecutePattern(PatternBase): + """ + Plan-then-Execute pattern. + + History tracks the current phase via AgentStep.step_type: + - "plan" step: contains the plan in step.plan + - "execute" step: a normal react iteration executing a plan step + + On the first call (empty history), a planning iteration is run. + Subsequent calls execute the next pending plan step via ReACT. 
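To make the plan representation concrete, the sketch below shows the shape of result the planning phase expects back from the plan-create prompt and how it is carried on the history as PlanStep records; the two goals and tool hints are purely illustrative.

    from trustgraph.schema import PlanStep

    plan_steps = [
        {"goal": "Find the latest revenue figures",
         "tool_hint": "knowledge-query", "depends_on": []},
        {"goal": "Summarise the figures for the user",
         "tool_hint": "text-completion", "depends_on": [0]},
    ]

    # Mirrors the conversion performed in the planning iteration
    plan = [
        PlanStep(
            goal=ps.get("goal", ""),
            tool_hint=ps.get("tool_hint", ""),
            depends_on=ps.get("depends_on", []),
            status="pending",
            result="",
        )
        for ps in plan_steps
    ]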
+ """ + + async def iterate(self, request, respond, next, flow): + + streaming = getattr(request, 'streaming', False) + session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) + collection = getattr(request, 'collection', 'default') + + history = self.build_history(request) + iteration_num = len(request.history) + 1 + session_uri = self.processor.provenance_session_uri(session_id) + + # Emit session provenance on first iteration + if iteration_num == 1: + await self.emit_session_triples( + flow, session_uri, request.question, + request.user, collection, respond, streaming, + ) + + logger.info( + f"PlanThenExecutePattern iteration {iteration_num}: " + f"{request.question}" + ) + + if iteration_num >= self.processor.max_iterations: + raise RuntimeError("Too many agent iterations") + + # Determine current phase by checking history for a plan step + plan = self._extract_plan(request.history) + + if plan is None: + await self._planning_iteration( + request, respond, next, flow, + session_id, collection, streaming, session_uri, + iteration_num, + ) + else: + await self._execution_iteration( + request, respond, next, flow, + session_id, collection, streaming, session_uri, + iteration_num, plan, + ) + + def _extract_plan(self, history): + """Find the most recent plan from history. + + Checks execute steps first (they carry the updated plan with + completion statuses), then falls back to the original plan step. + """ + if not history: + return None + for step in reversed(history): + if step.plan: + return list(step.plan) + return None + + def _find_next_pending_step(self, plan): + """Return index of the next pending step, or None if all done.""" + for i, step in enumerate(plan): + if getattr(step, 'status', 'pending') == 'pending': + return i + return None + + async def _planning_iteration(self, request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num): + """Ask the LLM to produce a structured plan.""" + + think = self.make_think_callback(respond, streaming) + + tools = self.filter_tools(self.processor.agent.tools, request) + framing = getattr(request, 'framing', '') + + context = self.make_context(flow, request.user) + client = context("prompt-request") + + # Use the plan-create prompt template + plan_steps = await client.prompt( + id="plan-create", + variables={ + "question": request.question, + "framing": framing, + "tools": [ + {"name": t.name, "description": t.description} + for t in tools.values() + ], + }, + ) + + # Validate we got a list + if not isinstance(plan_steps, list) or not plan_steps: + logger.warning("plan-create returned invalid result, falling back to single step") + plan_steps = [{"goal": "Answer the question directly", "tool_hint": "", "depends_on": []}] + + # Emit thought about the plan + thought_text = f"Created plan with {len(plan_steps)} steps" + await think(thought_text, is_final=True) + + # Build PlanStep objects + plan_agent_steps = [ + PlanStep( + goal=ps.get("goal", ""), + tool_hint=ps.get("tool_hint", ""), + depends_on=ps.get("depends_on", []), + status="pending", + result="", + ) + for ps in plan_steps + ] + + # Create a plan step in history + plan_history_step = AgentStep( + thought=thought_text, + action="plan", + arguments={}, + observation=json.dumps(plan_steps), + step_type="plan", + plan=plan_agent_steps, + ) + + # Build next request with plan in history + new_history = list(request.history) + [plan_history_step] + r = AgentRequest( + question=request.question, + state=request.state, + group=getattr(request, 
'group', []), + history=new_history, + user=request.user, + collection=collection, + streaming=streaming, + session_id=session_id, + conversation_id=getattr(request, 'conversation_id', ''), + pattern=getattr(request, 'pattern', ''), + task_type=getattr(request, 'task_type', ''), + framing=getattr(request, 'framing', ''), + correlation_id=getattr(request, 'correlation_id', ''), + parent_session_id=getattr(request, 'parent_session_id', ''), + subagent_goal=getattr(request, 'subagent_goal', ''), + expected_siblings=getattr(request, 'expected_siblings', 0), + ) + await next(r) + + async def _execution_iteration(self, request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num, plan): + """Execute the next pending plan step via single-shot tool call.""" + + pending_idx = self._find_next_pending_step(plan) + + if pending_idx is None: + # All steps done — synthesise final answer + await self._synthesise( + request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num, plan, + ) + return + + current_step = plan[pending_idx] + goal = getattr(current_step, 'goal', '') or str(current_step) + + logger.info(f"Executing plan step {pending_idx}: {goal}") + + think = self.make_think_callback(respond, streaming) + observe = self.make_observe_callback(respond, streaming) + + # Gather results from dependencies + previous_results = [] + depends_on = getattr(current_step, 'depends_on', []) + if depends_on: + for dep_idx in depends_on: + if 0 <= dep_idx < len(plan): + dep_step = plan[dep_idx] + dep_result = getattr(dep_step, 'result', '') + if dep_result: + previous_results.append({ + "index": dep_idx, + "result": dep_result, + }) + + tools = self.filter_tools(self.processor.agent.tools, request) + context = self.make_context(flow, request.user) + client = context("prompt-request") + + # Single-shot: ask LLM which tool + arguments to use for this goal + tool_call = await client.prompt( + id="plan-step-execute", + variables={ + "goal": goal, + "previous_results": previous_results, + "tools": [ + { + "name": t.name, + "description": t.description, + "arguments": [ + {"name": a.name, "type": a.type, "description": a.description} + for a in t.arguments + ], + } + for t in tools.values() + ], + }, + ) + + tool_name = tool_call.get("tool", "") + tool_arguments = tool_call.get("arguments", {}) + + await think( + f"Step {pending_idx}: {goal} → calling {tool_name}", + is_final=True, + ) + + # Invoke the tool directly + if tool_name in tools: + tool = tools[tool_name] + resp = await tool.implementation(context).invoke(**tool_arguments) + step_result = resp.strip() if isinstance(resp, str) else str(resp).strip() + else: + logger.warning( + f"Plan step {pending_idx}: LLM selected unknown tool " + f"'{tool_name}', available: {list(tools.keys())}" + ) + step_result = f"Error: tool '{tool_name}' not found" + + await observe(step_result, is_final=True) + + # Update plan step status + plan[pending_idx] = PlanStep( + goal=goal, + tool_hint=getattr(current_step, 'tool_hint', ''), + depends_on=getattr(current_step, 'depends_on', []), + status="completed", + result=step_result, + ) + + # Emit iteration provenance + prov_act = Action( + thought=f"Plan step {pending_idx}: {goal}", + name=tool_name, + arguments=tool_arguments, + observation=step_result, + ) + await self.emit_iteration_triples( + flow, session_id, iteration_num, session_uri, + prov_act, request, respond, streaming, + ) + + # Build execution step for history + exec_step = AgentStep( + 
thought=f"Executing plan step {pending_idx}: {goal}", + action=tool_name, + arguments={k: str(v) for k, v in tool_arguments.items()}, + observation=step_result, + step_type="execute", + plan=plan, + ) + + new_history = list(request.history) + [exec_step] + + r = AgentRequest( + question=request.question, + state=request.state, + group=getattr(request, 'group', []), + history=new_history, + user=request.user, + collection=collection, + streaming=streaming, + session_id=session_id, + conversation_id=getattr(request, 'conversation_id', ''), + pattern=getattr(request, 'pattern', ''), + task_type=getattr(request, 'task_type', ''), + framing=getattr(request, 'framing', ''), + correlation_id=getattr(request, 'correlation_id', ''), + parent_session_id=getattr(request, 'parent_session_id', ''), + subagent_goal=getattr(request, 'subagent_goal', ''), + expected_siblings=getattr(request, 'expected_siblings', 0), + ) + await next(r) + + async def _synthesise(self, request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num, plan): + """Synthesise a final answer from all completed plan step results.""" + + think = self.make_think_callback(respond, streaming) + framing = getattr(request, 'framing', '') + + context = self.make_context(flow, request.user) + client = context("prompt-request") + + # Use the plan-synthesise prompt template + steps_data = [] + for i, step in enumerate(plan): + steps_data.append({ + "index": i, + "goal": getattr(step, 'goal', f'Step {i}'), + "result": getattr(step, 'result', ''), + }) + + await think("Synthesising final answer from plan results", is_final=True) + + response_text = await self.prompt_as_answer( + client, "plan-synthesise", + variables={ + "question": request.question, + "framing": framing, + "steps": steps_data, + }, + respond=respond, + streaming=streaming, + ) + + await self.emit_final_triples( + flow, session_id, iteration_num, session_uri, + response_text, request, respond, streaming, + ) + await self.send_final_response( + respond, streaming, response_text, already_streamed=streaming, + ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py new file mode 100644 index 00000000..c0e481f7 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -0,0 +1,134 @@ +""" +ReactPattern — extracted from the existing agent_manager.py. + +Implements the ReACT (Reasoning + Acting) loop: think, select a tool, +observe the result, repeat until a final answer is produced. +""" + +import json +import logging +import uuid + +from ... schema import AgentRequest, AgentResponse, AgentStep + +from ..react.agent_manager import AgentManager +from ..react.types import Action, Final +from ..tool_filter import get_next_state + +from . pattern_base import PatternBase + +logger = logging.getLogger(__name__) + + +class ReactPattern(PatternBase): + """ + ReACT pattern: interleaved reasoning and action. + + Each iterate() call performs one reason/act cycle. If the LLM + produces a Final answer the dialog completes; otherwise the action + result is appended to history and a next-request is emitted. 
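For orientation, a streaming client consuming this pattern sees a sequence of AgentResponse chunks along these lines; the content values are invented, and the empty final chunk is the end-of-dialog marker emitted by send_final_response.

    from trustgraph.schema import AgentResponse

    stream = [
        AgentResponse(chunk_type="thought", content="Look up the figure",
                      end_of_message=True, end_of_dialog=False),
        AgentResponse(chunk_type="observation", content="Found: 42",
                      end_of_message=True, end_of_dialog=False),
        AgentResponse(chunk_type="answer", content="The figure is 42.",
                      end_of_message=False, end_of_dialog=False),
        AgentResponse(chunk_type="answer", content="",
                      end_of_message=True, end_of_dialog=True),
    ]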
+ """ + + async def iterate(self, request, respond, next, flow): + + streaming = getattr(request, 'streaming', False) + session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) + collection = getattr(request, 'collection', 'default') + + history = self.build_history(request) + iteration_num = len(history) + 1 + session_uri = self.processor.provenance_session_uri(session_id) + + # Emit session provenance on first iteration + if iteration_num == 1: + await self.emit_session_triples( + flow, session_uri, request.question, + request.user, collection, respond, streaming, + ) + + logger.info(f"ReactPattern iteration {iteration_num}: {request.question}") + + if len(history) >= self.processor.max_iterations: + raise RuntimeError("Too many agent iterations") + + # Build callbacks + think = self.make_think_callback(respond, streaming) + observe = self.make_observe_callback(respond, streaming) + answer_cb = self.make_answer_callback(respond, streaming) + + # Filter tools + filtered_tools = self.filter_tools( + self.processor.agent.tools, request, + ) + logger.info( + f"Filtered from {len(self.processor.agent.tools)} " + f"to {len(filtered_tools)} available tools" + ) + + # Create temporary agent with filtered tools and optional framing + additional_context = self.processor.agent.additional_context + framing = getattr(request, 'framing', '') + if framing: + if additional_context: + additional_context = f"{additional_context}\n\n{framing}" + else: + additional_context = framing + + temp_agent = AgentManager( + tools=filtered_tools, + additional_context=additional_context, + ) + + context = self.make_context(flow, request.user) + + act = await temp_agent.react( + question=request.question, + history=history, + think=think, + observe=observe, + answer=answer_cb, + context=context, + streaming=streaming, + ) + + logger.debug(f"Action: {act}") + + if isinstance(act, Final): + + if isinstance(act.final, str): + f = act.final + else: + f = json.dumps(act.final) + + # Emit final provenance + await self.emit_final_triples( + flow, session_id, iteration_num, session_uri, + f, request, respond, streaming, + ) + + await self.send_final_response( + respond, streaming, f, already_streamed=streaming, + ) + return + + # Not final — emit iteration provenance and send next request + await self.emit_iteration_triples( + flow, session_id, iteration_num, session_uri, + act, request, respond, streaming, + ) + + history.append(act) + + # Handle state transitions + next_state = request.state + if act.name in filtered_tools: + executed_tool = filtered_tools[act.name] + next_state = get_next_state(executed_tool, request.state or "undefined") + + r = self.build_next_request( + request, history, session_id, collection, + streaming, next_state, + ) + await next(r) + + logger.debug("ReactPattern iteration complete") diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py new file mode 100644 index 00000000..f7418e60 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -0,0 +1,511 @@ +""" +Agent orchestrator service — multi-pattern drop-in replacement for +agent-manager-react. + +Uses the same service identity and Pulsar queues. Adds meta-routing +to select between ReactPattern, PlanThenExecutePattern, and +SupervisorPattern at runtime. +""" + +import asyncio +import base64 +import json +import functools +import logging +import uuid +from datetime import datetime + +from ... 
base import AgentService, TextCompletionClientSpec, PromptClientSpec +from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec +from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec +from ... base import ProducerSpec +from ... base import Consumer, Producer +from ... base import ConsumerMetrics, ProducerMetrics + +from ... schema import AgentRequest, AgentResponse, AgentStep, Error +from ... schema import Triples, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue + +from trustgraph.provenance import ( + agent_session_uri, + GRAPH_RETRIEVAL, +) + +from ..react.tools import ( + KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, + StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl, +) +from ..react.agent_manager import AgentManager +from ..tool_filter import validate_tool_config +from ..react.types import Final, Action, Tool, Argument + +from . meta_router import MetaRouter +from . pattern_base import PatternBase, UserAwareContext +from . react_pattern import ReactPattern +from . plan_pattern import PlanThenExecutePattern +from . supervisor_pattern import SupervisorPattern +from . aggregator import Aggregator + +logger = logging.getLogger(__name__) + +default_ident = "agent-manager" +default_max_iterations = 10 +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue + + +class Processor(AgentService): + + def __init__(self, **params): + + id = params.get("id") + + self.max_iterations = int( + params.get("max_iterations", default_max_iterations) + ) + + self.config_key = params.get("config_type", "agent") + + super(Processor, self).__init__( + **params | { + "id": id, + "max_iterations": self.max_iterations, + "config_type": self.config_key, + } + ) + + self.agent = AgentManager( + tools={}, + additional_context="", + ) + + self.tool_service_clients = {} + + # Patterns + self.react_pattern = ReactPattern(self) + self.plan_pattern = PlanThenExecutePattern(self) + self.supervisor_pattern = SupervisorPattern(self) + + # Aggregator for supervisor fan-in + self.aggregator = Aggregator() + + # Meta-router (initialised on first config load) + self.meta_router = None + + self.config_handlers.append(self.on_tools_config) + + self.register_specification( + TextCompletionClientSpec( + request_name="text-completion-request", + response_name="text-completion-response", + ) + ) + + self.register_specification( + GraphRagClientSpec( + request_name="graph-rag-request", + response_name="graph-rag-response", + ) + ) + + self.register_specification( + PromptClientSpec( + request_name="prompt-request", + response_name="prompt-response", + ) + ) + + self.register_specification( + ToolClientSpec( + request_name="mcp-tool-request", + response_name="mcp-tool-response", + ) + ) + + self.register_specification( + StructuredQueryClientSpec( + request_name="structured-query-request", + response_name="structured-query-response", + ) + ) + + self.register_specification( + EmbeddingsClientSpec( + request_name="embeddings-request", + response_name="embeddings-response", + ) + ) + + self.register_specification( + RowEmbeddingsQueryClientSpec( + request_name="row-embeddings-query-request", + response_name="row-embeddings-query-response", + ) + ) + + # Explainability producer + self.register_specification( + ProducerSpec( + name="explainability", + schema=Triples, + ) + ) + + # Librarian 
client + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request" + ) + + self.librarian_request_producer = Producer( + backend=self.pubsub, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self.on_librarian_response, + metrics=librarian_response_metrics, + ) + + self.pending_librarian_requests = {} + + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + response = msg.value() + request_id = msg.properties().get("id") + + if request_id in self.pending_librarian_requests: + future = self.pending_librarian_requests.pop(request_id) + future.set_result(response) + + async def save_answer_content(self, doc_id, user, content, title=None, + timeout=120): + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or "Agent Answer", + document_type="answer", + ) + + request = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content.encode("utf-8")).decode("utf-8"), + user=user, + ) + + future = asyncio.get_event_loop().create_future() + self.pending_librarian_requests[request_id] = future + + try: + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving answer: " + f"{response.error.type}: {response.error.message}" + ) + return doc_id + + except asyncio.TimeoutError: + self.pending_librarian_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving answer document {doc_id}") + + def provenance_session_uri(self, session_id): + return agent_session_uri(session_id) + + async def on_tools_config(self, config, version): + + logger.info(f"Loading configuration version {version}") + + try: + tools = {} + + # Load tool-service configurations + tool_services = {} + if "tool-service" in config: + for service_id, service_value in config["tool-service"].items(): + service_data = json.loads(service_value) + tool_services[service_id] = service_data + logger.debug(f"Loaded tool-service config: {service_id}") + + logger.info( + f"Loaded {len(tool_services)} tool-service configurations" + ) + + # Load tool configurations + if "tool" in config: + for tool_id, tool_value in config["tool"].items(): + data = json.loads(tool_value) + impl_id = data.get("type") + name = data.get("name") + + if impl_id == "knowledge-query": + impl = functools.partial( + KnowledgeQueryImpl, + collection=data.get("collection"), + ) + arguments = KnowledgeQueryImpl.get_arguments() + elif impl_id == "text-completion": + impl = TextCompletionImpl + arguments = TextCompletionImpl.get_arguments() + elif impl_id == "mcp-tool": + config_args = 
data.get("arguments", []) + arguments = [ + Argument( + name=arg.get("name"), + type=arg.get("type"), + description=arg.get("description"), + ) + for arg in config_args + ] + impl = functools.partial( + McpToolImpl, + mcp_tool_id=data.get("mcp-tool"), + arguments=arguments, + ) + elif impl_id == "prompt": + config_args = data.get("arguments", []) + arguments = [ + Argument( + name=arg.get("name"), + type=arg.get("type"), + description=arg.get("description"), + ) + for arg in config_args + ] + impl = functools.partial( + PromptImpl, + template_id=data.get("template"), + arguments=arguments, + ) + elif impl_id == "structured-query": + impl = functools.partial( + StructuredQueryImpl, + collection=data.get("collection"), + user=None, + ) + arguments = StructuredQueryImpl.get_arguments() + elif impl_id == "row-embeddings-query": + impl = functools.partial( + RowEmbeddingsQueryImpl, + schema_name=data.get("schema-name"), + collection=data.get("collection"), + user=None, + index_name=data.get("index-name"), + limit=int(data.get("limit", 10)), + ) + arguments = RowEmbeddingsQueryImpl.get_arguments() + elif impl_id == "tool-service": + service_ref = data.get("service") + if not service_ref: + raise RuntimeError( + f"Tool {name} has type 'tool-service' " + f"but no 'service' reference" + ) + if service_ref not in tool_services: + raise RuntimeError( + f"Tool {name} references unknown " + f"tool-service '{service_ref}'" + ) + + service_config = tool_services[service_ref] + request_queue = service_config.get("request-queue") + response_queue = service_config.get("response-queue") + if not request_queue or not response_queue: + raise RuntimeError( + f"Tool-service '{service_ref}' must define " + f"'request-queue' and 'response-queue'" + ) + + config_params = service_config.get("config-params", []) + config_values = {} + for param in config_params: + param_name = ( + param.get("name") + if isinstance(param, dict) else param + ) + if param_name in data: + config_values[param_name] = data[param_name] + elif ( + isinstance(param, dict) + and param.get("required", False) + ): + raise RuntimeError( + f"Tool {name} missing required config " + f"param '{param_name}'" + ) + + config_args = data.get("arguments", []) + arguments = [ + Argument( + name=arg.get("name"), + type=arg.get("type"), + description=arg.get("description"), + ) + for arg in config_args + ] + + impl = functools.partial( + ToolServiceImpl, + request_queue=request_queue, + response_queue=response_queue, + config_values=config_values, + arguments=arguments, + processor=self, + ) + else: + raise RuntimeError( + f"Tool type {impl_id} not known" + ) + + validate_tool_config(data) + + tools[name] = Tool( + name=name, + description=data.get("description"), + implementation=impl, + config=data, + arguments=arguments, + ) + + # Load additional context from agent config + additional = None + if self.config_key in config: + agent_config = config[self.config_key] + additional = agent_config.get("additional-context", None) + + self.agent = AgentManager( + tools=tools, + additional_context=additional, + ) + + # Re-initialise meta-router with config + self.meta_router = MetaRouter(config=config) + + logger.info(f"Loaded {len(tools)} tools") + logger.info("Tool configuration reloaded.") + + except Exception as e: + logger.error( + f"on_tools_config Exception: {e}", exc_info=True + ) + logger.error("Configuration reload failed") + + async def agent_request(self, request, respond, next, flow): + + try: + pattern = getattr(request, 'pattern', '') or '' + + # If 
no pattern set and this is the first iteration, route + if not pattern and not request.history: + context = UserAwareContext(flow, request.user) + + if self.meta_router: + pattern, task_type, framing = await self.meta_router.route( + request.question, context, + ) + else: + pattern = "react" + task_type = "general" + framing = "" + + # Update request with routing decision + request.pattern = pattern + request.task_type = task_type + request.framing = framing + + logger.info( + f"Routed to pattern={pattern}, " + f"task_type={task_type}" + ) + + # Dispatch to the selected pattern + if pattern == "plan-then-execute": + await self.plan_pattern.iterate( + request, respond, next, flow, + ) + elif pattern == "supervisor": + await self.supervisor_pattern.iterate( + request, respond, next, flow, + ) + else: + # Default to react + await self.react_pattern.iterate( + request, respond, next, flow, + ) + + except Exception as e: + + logger.error( + f"agent_request Exception: {e}", exc_info=True + ) + + logger.debug("Send error response...") + + error_obj = Error( + type="agent-error", + message=str(e), + ) + + r = AgentResponse( + chunk_type="error", + content=str(e), + end_of_message=True, + end_of_dialog=True, + error=error_obj, + ) + + await respond(r) + + @staticmethod + def add_args(parser): + + AgentService.add_args(parser) + + parser.add_argument( + '--max-iterations', + default=default_max_iterations, + help=f'Maximum number of react iterations ' + f'(default: {default_max_iterations})', + ) + + parser.add_argument( + '--config-type', + default="agent", + help='Configuration key for prompts (default: agent)', + ) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py new file mode 100644 index 00000000..9070a393 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -0,0 +1,214 @@ +""" +SupervisorPattern — decomposes a query into subagent goals, fans out, +then synthesises results when all subagents complete. + +Phase 1 (decompose): LLM breaks the query into independent sub-goals. +Fan-out: Each sub-goal is emitted as a new AgentRequest on the agent + request topic, carrying a correlation_id and parent_session_id. +Phase 2 (synthesise): Triggered when the aggregator detects all + subagents have completed. The supervisor fetches results and + produces the final answer. +""" + +import json +import logging +import uuid + +from ... schema import AgentRequest, AgentResponse, AgentStep + +from ..react.types import Action, Final + +from . pattern_base import PatternBase + +logger = logging.getLogger(__name__) + +MAX_SUBAGENTS = 5 + + +class SupervisorPattern(PatternBase): + """ + Supervisor pattern: decompose, fan-out, synthesise. 
+ + History tracks phase via AgentStep.step_type: + - "decompose": the decomposition step (subagent goals in arguments) + - "synthesise": triggered by aggregator with results in subagent_results + """ + + async def iterate(self, request, respond, next, flow): + + streaming = getattr(request, 'streaming', False) + session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) + collection = getattr(request, 'collection', 'default') + iteration_num = len(request.history) + 1 + session_uri = self.processor.provenance_session_uri(session_id) + + # Emit session provenance on first iteration + if iteration_num == 1: + await self.emit_session_triples( + flow, session_uri, request.question, + request.user, collection, respond, streaming, + ) + + logger.info( + f"SupervisorPattern iteration {iteration_num}: {request.question}" + ) + + # Check if this is a synthesis request (has subagent_results) + has_results = bool( + request.history + and any( + getattr(h, 'step_type', '') == 'decompose' + for h in request.history + ) + and any( + getattr(h, 'subagent_results', None) + for h in request.history + ) + ) + + if has_results: + await self._synthesise( + request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num, + ) + else: + await self._decompose_and_fanout( + request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num, + ) + + async def _decompose_and_fanout(self, request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num): + """Decompose the question into sub-goals and fan out subagents.""" + + think = self.make_think_callback(respond, streaming) + framing = getattr(request, 'framing', '') + + tools = self.filter_tools(self.processor.agent.tools, request) + + context = self.make_context(flow, request.user) + client = context("prompt-request") + + # Use the supervisor-decompose prompt template + goals = await client.prompt( + id="supervisor-decompose", + variables={ + "question": request.question, + "framing": framing, + "max_subagents": MAX_SUBAGENTS, + "tools": [ + {"name": t.name, "description": t.description} + for t in tools.values() + ], + }, + ) + + # Validate result + if not isinstance(goals, list): + goals = [] + goals = [g for g in goals if isinstance(g, str)] + goals = goals[:MAX_SUBAGENTS] + + if not goals: + goals = [request.question] + + await think( + f"Decomposed into {len(goals)} sub-goals: {goals}", + is_final=True, + ) + + # Generate correlation ID for this fan-out + correlation_id = str(uuid.uuid4()) + + # Emit decomposition provenance + decompose_act = Action( + thought=f"Decomposed into {len(goals)} sub-goals", + name="decompose", + arguments={"goals": json.dumps(goals), "correlation_id": correlation_id}, + observation=f"Fanning out {len(goals)} subagents", + ) + await self.emit_iteration_triples( + flow, session_id, iteration_num, session_uri, + decompose_act, request, respond, streaming, + ) + + # Fan out: emit a subagent request for each goal + for i, goal in enumerate(goals): + subagent_session = str(uuid.uuid4()) + sub_request = AgentRequest( + question=goal, + state="", + group=getattr(request, 'group', []), + history=[], + user=request.user, + collection=collection, + streaming=False, # Subagents don't stream + session_id=subagent_session, + conversation_id=getattr(request, 'conversation_id', ''), + pattern="react", # Subagents use react by default + task_type=getattr(request, 'task_type', ''), + framing=getattr(request, 'framing', ''), + 
correlation_id=correlation_id, + parent_session_id=session_id, + subagent_goal=goal, + expected_siblings=len(goals), + ) + await next(sub_request) + logger.info(f"Fan-out: emitted subagent {i} for goal: {goal}") + + # NOTE: The supervisor stops here. The aggregator will detect + # when all subagents complete and emit a synthesis request + # with the results populated. + logger.info( + f"Supervisor fan-out complete: {len(goals)} subagents, " + f"correlation_id={correlation_id}" + ) + + async def _synthesise(self, request, respond, next, flow, + session_id, collection, streaming, + session_uri, iteration_num): + """Synthesise final answer from subagent results.""" + + think = self.make_think_callback(respond, streaming) + framing = getattr(request, 'framing', '') + + # Collect subagent results from history + subagent_results = {} + for step in request.history: + results = getattr(step, 'subagent_results', None) + if results: + subagent_results.update(results) + + if not subagent_results: + logger.warning("Synthesis called with no subagent results") + subagent_results = {"(no results)": "No subagent results available"} + + context = self.make_context(flow, request.user) + client = context("prompt-request") + + await think("Synthesising final answer from sub-agent results", is_final=True) + + response_text = await self.prompt_as_answer( + client, "supervisor-synthesise", + variables={ + "question": request.question, + "framing": framing, + "results": [ + {"goal": goal, "result": result} + for goal, result in subagent_results.items() + ], + }, + respond=respond, + streaming=streaming, + ) + + await self.emit_final_triples( + flow, session_id, iteration_num, session_uri, + response_text, request, respond, streaming, + ) + await self.send_final_response( + respond, streaming, response_text, already_streamed=streaming, + ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index c037c937..1bca9627 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -485,25 +485,16 @@ class Processor(AgentService): logger.debug(f"Think: {x} (is_final={is_final})") if streaming: - # Streaming format r = AgentResponse( chunk_type="thought", content=x, end_of_message=is_final, end_of_dialog=False, - # Legacy fields for backward compatibility - answer=None, - error=None, - thought=x, - observation=None, ) else: - # Non-streaming format r = AgentResponse( - answer=None, - error=None, - thought=x, - observation=None, + chunk_type="thought", + content=x, end_of_message=True, end_of_dialog=False, ) @@ -515,25 +506,16 @@ class Processor(AgentService): logger.debug(f"Observe: {x} (is_final={is_final})") if streaming: - # Streaming format r = AgentResponse( chunk_type="observation", content=x, end_of_message=is_final, end_of_dialog=False, - # Legacy fields for backward compatibility - answer=None, - error=None, - thought=None, - observation=x, ) else: - # Non-streaming format r = AgentResponse( - answer=None, - error=None, - thought=None, - observation=x, + chunk_type="observation", + content=x, end_of_message=True, end_of_dialog=False, ) @@ -545,25 +527,16 @@ class Processor(AgentService): logger.debug(f"Answer: {x}") if streaming: - # Streaming format r = AgentResponse( chunk_type="answer", content=x, - end_of_message=False, # More chunks may follow + end_of_message=False, end_of_dialog=False, - # Legacy fields for backward compatibility - answer=None, - error=None, - thought=None, - 
observation=None, ) else: - # Non-streaming format - shouldn't normally be called r = AgentResponse( - answer=x, - error=None, - thought=None, - observation=None, + chunk_type="answer", + content=x, end_of_message=True, end_of_dialog=False, ) @@ -677,25 +650,17 @@ class Processor(AgentService): )) if streaming: - # Streaming format - send end-of-dialog marker - # Answer chunks were already sent via answer() callback during parsing + # End-of-dialog marker — answer chunks already sent via callback r = AgentResponse( chunk_type="answer", - content="", # Empty content, just marking end of dialog + content="", end_of_message=True, end_of_dialog=True, - # Legacy fields set to None - answer already sent via streaming chunks - answer=None, - error=None, - thought=None, ) else: - # Non-streaming format - send complete answer r = AgentResponse( - answer=act.final, - error=None, - thought=None, - observation=None, + chunk_type="answer", + content=f, end_of_message=True, end_of_dialog=True, ) @@ -833,21 +798,13 @@ class Processor(AgentService): # Check if streaming was enabled (may not be set if error occurred early) streaming = getattr(request, 'streaming', False) if 'request' in locals() else False - if streaming: - # Streaming format - r = AgentResponse( - chunk_type="error", - content=str(e), - end_of_message=True, - end_of_dialog=True, - # Legacy fields for backward compatibility - error=error_obj, - ) - else: - # Legacy format - r = AgentResponse( - error=error_obj, - ) + r = AgentResponse( + chunk_type="error", + content=str(e), + end_of_message=True, + end_of_dialog=True, + error=error_obj, + ) await respond(r) From 0781d3e6a7cdea204ca62813676bd9ef6d6d0f5a Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 09:12:33 +0100 Subject: [PATCH 15/37] Remove unnecessary prompt-client logging (#740) --- .../trustgraph/base/prompt_client.py | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 74b25132..6859a9f0 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -1,21 +1,16 @@ import json import asyncio -import logging from . request_response_spec import RequestResponse, RequestResponseSpec from .. 
schema import PromptRequest, PromptResponse -logger = logging.getLogger(__name__) - class PromptClient(RequestResponse): async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None): - logger.info(f"DEBUG prompt_client: prompt called, id={id}, streaming={streaming}, chunk_callback={chunk_callback is not None}") if not streaming: - logger.info("DEBUG prompt_client: Non-streaming path") - # Non-streaming path + resp = await self.request( PromptRequest( id = id, @@ -36,39 +31,30 @@ class PromptClient(RequestResponse): return json.loads(resp.object) else: - logger.info("DEBUG prompt_client: Streaming path") - # Streaming path - just forward chunks, don't accumulate + last_text = "" last_object = None async def forward_chunks(resp): nonlocal last_text, last_object - logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}") if resp.error: - logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}") raise RuntimeError(resp.error.message) end_stream = getattr(resp, 'end_of_stream', False) - # Always call callback if there's text OR if it's the final message if resp.text is not None: last_text = resp.text - # Call chunk callback if provided with both chunk and end_of_stream flag if chunk_callback: - logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}") if asyncio.iscoroutinefunction(chunk_callback): await chunk_callback(resp.text, end_stream) else: chunk_callback(resp.text, end_stream) elif resp.object: - logger.info(f"DEBUG prompt_client: Got object response") last_object = resp.object - logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}") return end_stream - logger.info("DEBUG prompt_client: Creating PromptRequest") req = PromptRequest( id = id, terms = { @@ -77,19 +63,16 @@ class PromptClient(RequestResponse): }, streaming = True ) - logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}") + await self.request( req, recipient=forward_chunks, timeout=timeout ) - logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}") if last_text: - logger.info("DEBUG prompt_client: Returning last_text") return last_text - logger.info("DEBUG prompt_client: Returning parsed last_object") return json.loads(last_object) if last_object else None async def extract_definitions(self, text, timeout=600): From 81ca7bbc11955542abe4833fbf87d9f9a063d479 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 09:35:58 +0100 Subject: [PATCH 16/37] Change monitor default to prompts-rag (#742) --- trustgraph-cli/trustgraph/cli/monitor_prompts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trustgraph-cli/trustgraph/cli/monitor_prompts.py b/trustgraph-cli/trustgraph/cli/monitor_prompts.py index c412b643..974cfbcd 100644 --- a/trustgraph-cli/trustgraph/cli/monitor_prompts.py +++ b/trustgraph-cli/trustgraph/cli/monitor_prompts.py @@ -24,7 +24,7 @@ from pulsar.schema import BytesSchema default_flow = "default" -default_queue_type = "prompt" +default_queue_type = "prompt-rag" default_max_lines = 3 default_max_width = 80 From e65ea217a207cce476c77edfdd85078dcea13a53 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 11:24:30 +0100 Subject: [PATCH 17/37] agent-orchestrator improvements (#743) agent-orchestrator improvements: - Improve agent trace - Improve queue dumping - Fixing 
supervisor pattern - Fix synthesis step to remove loop Minor dev environment improvements: - Improve queue dump output for JSON - Reduce dev container rebuild --- Makefile | 4 +- trustgraph-cli/trustgraph/cli/dump_queues.py | 84 ++++++++++------ .../agent/orchestrator/aggregator.py | 21 ++-- .../agent/orchestrator/pattern_base.py | 41 ++++++++ .../agent/orchestrator/plan_pattern.py | 10 +- .../agent/orchestrator/react_pattern.py | 9 +- .../trustgraph/agent/orchestrator/service.py | 65 ++++++++++++ .../agent/orchestrator/supervisor_pattern.py | 18 ++-- .../trustgraph/agent/react/agent_manager.py | 98 +++++-------------- 9 files changed, 225 insertions(+), 125 deletions(-) diff --git a/Makefile b/Makefile index 4d79f554..197a6c63 100644 --- a/Makefile +++ b/Makefile @@ -77,8 +77,8 @@ some-containers: -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . - ${DOCKER} build -f containers/Containerfile.unstructured \ - -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . +# ${DOCKER} build -f containers/Containerfile.unstructured \ +# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . # ${DOCKER} build -f containers/Containerfile.vertexai \ # -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . # ${DOCKER} build -f containers/Containerfile.mcp \ diff --git a/trustgraph-cli/trustgraph/cli/dump_queues.py b/trustgraph-cli/trustgraph/cli/dump_queues.py index 0a298450..4df61cc3 100644 --- a/trustgraph-cli/trustgraph/cli/dump_queues.py +++ b/trustgraph-cli/trustgraph/cli/dump_queues.py @@ -19,43 +19,67 @@ import argparse from trustgraph.base.subscriber import Subscriber from trustgraph.base.pubsub import get_pubsub +def decode_json_strings(obj): + """Recursively decode JSON-encoded string values within a dict/list.""" + if isinstance(obj, dict): + return {k: decode_json_strings(v) for k, v in obj.items()} + if isinstance(obj, list): + return [decode_json_strings(v) for v in obj] + if isinstance(obj, str): + try: + parsed = json.loads(obj) + if isinstance(parsed, (dict, list)): + return decode_json_strings(parsed) + except (json.JSONDecodeError, TypeError): + pass + return obj + + +def to_dict(value): + """Recursively convert a value to a JSON-serialisable structure.""" + + if value is None or isinstance(value, (bool, int, float)): + return value + + if isinstance(value, bytes): + value = value.decode('utf-8') + + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + + if isinstance(value, dict): + return {k: to_dict(v) for k, v in value.items()} + + if isinstance(value, (list, tuple)): + return [to_dict(v) for v in value] + + # Pulsar schema objects expose fields via __dict__ + if hasattr(value, '__dict__'): + return { + k: to_dict(v) for k, v in value.__dict__.items() + if not k.startswith('_') + } + + return str(value) + + def format_message(queue_name, msg): """Format a message with timestamp and queue name.""" timestamp = datetime.now().isoformat() - # Try to parse as JSON and pretty-print try: - # Handle both Message objects and raw bytes - if hasattr(msg, 'value'): - # Message object with .value() method - value = msg.value() - else: - # Raw bytes from schema-less subscription - value = msg + value = msg.value() if hasattr(msg, 'value') else msg + parsed = to_dict(value) - # If it's bytes, decode it - if isinstance(value, bytes): - value = value.decode('utf-8') - - # If it's a string, try to parse as JSON - if 
isinstance(value, str): - try: - parsed = json.loads(value) - body = json.dumps(parsed, indent=2) - except (json.JSONDecodeError, TypeError): - body = value + # Unwrap nested JSON strings (e.g. terms values) + if isinstance(parsed, (dict, list)): + parsed = decode_json_strings(parsed) + body = json.dumps(parsed, indent=2, default=str) else: - # Try to convert to dict for pretty printing - try: - # Pulsar schema objects have __dict__ or similar - if hasattr(value, '__dict__'): - parsed = {k: v for k, v in value.__dict__.items() - if not k.startswith('_')} - else: - parsed = str(value) - body = json.dumps(parsed, indent=2, default=str) - except (TypeError, AttributeError): - body = str(value) + body = str(parsed) except Exception as e: body = f"\n{str(msg)}" diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py index bff8822c..9187f21e 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py @@ -1,10 +1,12 @@ """ -Aggregator — monitors the explainability topic for subagent completions -and triggers synthesis when all siblings in a fan-out have completed. +Aggregator — tracks in-flight fan-out correlations and triggers +synthesis when all subagents have completed. -The aggregator watches for tg:Conclusion triples that carry a -correlation_id. When it detects that all expected siblings have -completed, it emits a synthesis AgentRequest on the agent request topic. +Subagent completions arrive as AgentRequest messages on the agent +request queue with step_type="subagent-completion". The orchestrator +intercepts these and feeds them to the aggregator. When all expected +siblings for a correlation ID have reported, the aggregator builds +a synthesis request for the supervisor pattern. """ import asyncio @@ -87,6 +89,13 @@ class Aggregator: return completed >= expected + def get_original_request(self, correlation_id): + """Peek at the stored request template without consuming it.""" + entry = self.correlations.get(correlation_id) + if entry is None: + return None + return entry["request_template"] + def get_results(self, correlation_id): """Get all results for a correlation and remove the tracking entry.""" entry = self.correlations.pop(correlation_id, None) @@ -138,7 +147,7 @@ class Aggregator: pattern="supervisor", task_type=template.task_type if template else "", framing=template.framing if template else "", - correlation_id=correlation_id, + correlation_id="", parent_session_id="", subagent_goal="", expected_siblings=0, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index fc07e745..b66bc4f5 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -61,6 +61,47 @@ class PatternBase: def __init__(self, processor): self.processor = processor + def is_subagent(self, request): + """Check if this request is running as a subagent of a supervisor.""" + return bool(getattr(request, 'correlation_id', '')) + + async def emit_subagent_completion(self, request, next, answer_text): + """Signal completion back to the orchestrator via the agent request + queue. 
Instead of sending the final answer to the client, send a + completion message so the aggregator can collect it.""" + + completion_step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation=answer_text, + step_type="subagent-completion", + ) + + completion_request = AgentRequest( + question=request.question, + state="", + group=getattr(request, 'group', []), + history=[completion_step], + user=request.user, + collection=getattr(request, 'collection', 'default'), + streaming=False, + session_id=getattr(request, 'session_id', ''), + conversation_id=getattr(request, 'conversation_id', ''), + pattern="", + correlation_id=request.correlation_id, + parent_session_id=getattr(request, 'parent_session_id', ''), + subagent_goal=getattr(request, 'subagent_goal', ''), + expected_siblings=getattr(request, 'expected_siblings', 0), + ) + + await next(completion_request) + logger.info( + f"Subagent completion emitted for " + f"correlation={request.correlation_id}, " + f"goal={getattr(request, 'subagent_goal', '')}" + ) + def filter_tools(self, tools, request): """Apply group/state filtering to the tool set.""" return filter_tools_by_group_and_state( diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index d5f667c8..4c61039f 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -344,6 +344,10 @@ class PlanThenExecutePattern(PatternBase): flow, session_id, iteration_num, session_uri, response_text, request, respond, streaming, ) - await self.send_final_response( - respond, streaming, response_text, already_streamed=streaming, - ) + + if self.is_subagent(request): + await self.emit_subagent_completion(request, next, response_text) + else: + await self.send_final_response( + respond, streaming, response_text, already_streamed=streaming, + ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index c0e481f7..a03dc194 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -106,9 +106,12 @@ class ReactPattern(PatternBase): f, request, respond, streaming, ) - await self.send_final_response( - respond, streaming, f, already_streamed=streaming, - ) + if self.is_subagent(request): + await self.emit_subagent_completion(request, next, f) + else: + await self.send_final_response( + respond, streaming, f, already_streamed=streaming, + ) return # Not final — emit iteration provenance and send next request diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index f7418e60..9c9980d4 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -422,9 +422,74 @@ class Processor(AgentService): ) logger.error("Configuration reload failed") + async def _handle_subagent_completion(self, request, respond, next, flow): + """Handle a subagent completion by feeding it to the aggregator.""" + + correlation_id = request.correlation_id + subagent_goal = getattr(request, 'subagent_goal', '') + + # Extract the answer from the completion step + answer_text = "" + for step in request.history: + if getattr(step, 'step_type', '') == 'subagent-completion': + answer_text = step.observation + break 
+ + logger.info( + f"Received subagent completion: " + f"correlation={correlation_id}, goal={subagent_goal}" + ) + + all_done = self.aggregator.record_completion( + correlation_id, subagent_goal, answer_text + ) + + if all_done is None: + logger.warning( + f"Unknown correlation_id {correlation_id} — " + f"possibly timed out or duplicate" + ) + return + + if all_done: + logger.info( + f"All subagents complete for {correlation_id}, " + f"dispatching synthesis" + ) + + template = self.aggregator.get_original_request(correlation_id) + if template is None: + logger.error( + f"No template for correlation {correlation_id}" + ) + return + + synthesis_request = self.aggregator.build_synthesis_request( + correlation_id, + original_question=template.question, + user=template.user, + collection=getattr(template, 'collection', 'default'), + ) + + await next(synthesis_request) + async def agent_request(self, request, respond, next, flow): try: + + # Intercept subagent completion messages + correlation_id = getattr(request, 'correlation_id', '') + if correlation_id and request.history: + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in request.history + ) + if is_completion: + await self._handle_subagent_completion( + request, respond, next, flow + ) + return + pattern = getattr(request, 'pattern', '') or '' # If no pattern set and this is the first iteration, route diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 9070a393..51c2d500 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -57,11 +57,8 @@ class SupervisorPattern(PatternBase): has_results = bool( request.history and any( - getattr(h, 'step_type', '') == 'decompose' - for h in request.history - ) - and any( - getattr(h, 'subagent_results', None) + getattr(h, 'step_type', '') == 'synthesise' + and getattr(h, 'subagent_results', None) for h in request.history ) ) @@ -159,9 +156,14 @@ class SupervisorPattern(PatternBase): await next(sub_request) logger.info(f"Fan-out: emitted subagent {i} for goal: {goal}") - # NOTE: The supervisor stops here. The aggregator will detect - # when all subagents complete and emit a synthesis request - # with the results populated. + # Register with aggregator for fan-in tracking + self.processor.aggregator.register_fanout( + correlation_id=correlation_id, + parent_session_id=session_id, + expected_siblings=len(goals), + request_template=request, + ) + logger.info( f"Supervisor fan-out complete: {len(goals)} subagents, " f"correlation_id={correlation_id}" diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index 87cee33d..18598b38 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -16,37 +16,37 @@ class AgentManager: def parse_react_response(self, text): """Parse text-based ReAct response format. 
- + Expected format: Thought: [reasoning about what to do next] Action: [tool_name] Args: { "param": "value" } - + OR - + Thought: [reasoning about the final answer] Final Answer: [the answer] """ if not isinstance(text, str): raise ValueError(f"Expected string response, got {type(text)}") - + # Remove any markdown code blocks that might wrap the response text = re.sub(r'^```[^\n]*\n', '', text.strip()) text = re.sub(r'\n```$', '', text.strip()) - + lines = text.strip().split('\n') - + thought = None action = None args = None final_answer = None - + i = 0 while i < len(lines): line = lines[i].strip() - + # Parse Thought if line.startswith("Thought:"): thought = line[8:].strip() @@ -59,19 +59,19 @@ class AgentManager: thought += " " + next_line i += 1 continue - + # Parse Final Answer if line.startswith("Final Answer:"): final_answer = line[13:].strip() # Handle multi-line final answers (including JSON) i += 1 - + # Check if the answer might be JSON if final_answer.startswith('{') or (i < len(lines) and lines[i].strip().startswith('{')): # Collect potential JSON answer json_text = final_answer if final_answer.startswith('{') else "" brace_count = json_text.count('{') - json_text.count('}') - + while i < len(lines) and (brace_count > 0 or not json_text): current_line = lines[i].strip() if current_line.startswith(("Thought:", "Action:")) and brace_count == 0: @@ -79,7 +79,7 @@ class AgentManager: json_text += ("\n" if json_text else "") + current_line brace_count += current_line.count('{') - current_line.count('}') i += 1 - + # Try to parse as JSON # try: # final_answer = json.loads(json_text) @@ -95,13 +95,13 @@ class AgentManager: break final_answer += " " + next_line i += 1 - + # If we have a final answer, return Final object return Final( thought=thought or "", final=final_answer ) - + # Parse Action if line.startswith("Action:"): action = line[7:].strip() @@ -112,7 +112,7 @@ class AgentManager: while action and action[-1] == '"': action = action[:-1] - + # Parse Args if line.startswith("Args:"): # Check if JSON starts on the same line @@ -123,15 +123,15 @@ class AgentManager: else: args_text = "" brace_count = 0 - + # Collect all lines that form the JSON arguments i += 1 started = bool(args_on_same_line and '{' in args_on_same_line) - + while i < len(lines) and (not started or brace_count > 0): current_line = lines[i] args_text += ("\n" if args_text else "") + current_line - + # Count braces to determine when JSON is complete for char in current_line: if char == '{': @@ -139,22 +139,22 @@ class AgentManager: started = True elif char == '}': brace_count -= 1 - + # If we've started and braces are balanced, we're done if started and brace_count == 0: break - + i += 1 - + # Parse the JSON arguments try: args = json.loads(args_text.strip()) except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON arguments: {args_text}") raise ValueError(f"Invalid JSON in Args: {e}") - + i += 1 - + # If we have an action, return Action object if action: return Action( @@ -163,11 +163,11 @@ class AgentManager: arguments=args or {}, observation="" ) - + # If we only have a thought but no action or final answer if thought and not action and not final_answer: raise ValueError(f"Response has thought but no action or final answer: {text}") - + raise ValueError(f"Could not parse response: {text}") async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None): @@ -176,15 +176,10 @@ class AgentManager: tools = self.tools - logger.debug("in reason") - 
logger.debug(f"tools: {tools}") - tool_names = ",".join([ t for t in self.tools.keys() ]) - logger.debug(f"Tool names: {tool_names}") - variables = { "question": question, "tools": [ @@ -218,17 +213,10 @@ class AgentManager: logger.debug(f"Variables: {json.dumps(variables, indent=4)}") - logger.info(f"prompt: {variables}") - - logger.info(f"DEBUG: streaming={streaming}, think={think is not None}") - # Streaming path - use StreamingReActParser if streaming and think: - logger.info("DEBUG: Entering streaming path") from .streaming_parser import StreamingReActParser - logger.info("DEBUG: Creating StreamingReActParser") - # Collect chunks to send via async callbacks thought_chunks = [] answer_chunks = [] @@ -238,24 +226,19 @@ class AgentManager: on_thought_chunk=lambda chunk: thought_chunks.append(chunk), on_answer_chunk=lambda chunk: answer_chunks.append(chunk), ) - logger.info("DEBUG: StreamingReActParser created") # Create async chunk callback that feeds parser and sends collected chunks async def on_chunk(text, end_of_stream): - logger.info(f"DEBUG: on_chunk called with {len(text)} chars, end_of_stream={end_of_stream}") # Track what we had before prev_thought_count = len(thought_chunks) prev_answer_count = len(answer_chunks) # Feed the parser (synchronous) - logger.info(f"DEBUG: About to call parser.feed") parser.feed(text) - logger.info(f"DEBUG: parser.feed returned") # Send any new thought chunks for i in range(prev_thought_count, len(thought_chunks)): - logger.info(f"DEBUG: Sending thought chunk {i}") # Mark last chunk as final if parser has moved out of THOUGHT state is_last = (i == len(thought_chunks) - 1) is_thought_complete = parser.state.value != "thought" @@ -264,72 +247,52 @@ class AgentManager: # Send any new answer chunks for i in range(prev_answer_count, len(answer_chunks)): - logger.info(f"DEBUG: Sending answer chunk {i}") if answer: await answer(answer_chunks[i]) else: await think(answer_chunks[i]) - logger.info("DEBUG: Getting prompt-request client from context") client = context("prompt-request") - logger.info(f"DEBUG: Got client: {client}") - logger.info("DEBUG: About to call agent_react with streaming=True") # Get streaming response response_text = await client.agent_react( variables=variables, streaming=True, chunk_callback=on_chunk ) - logger.info(f"DEBUG: agent_react returned, got {len(response_text) if response_text else 0} chars") # Finalize parser - logger.info("DEBUG: Finalizing parser") parser.finalize() - logger.info("DEBUG: Parser finalized") # Get result - logger.info("DEBUG: Getting result from parser") result = parser.get_result() if result is None: raise RuntimeError("Parser failed to produce a result") - logger.info(f"Parsed result: {result}") return result else: - logger.info("DEBUG: Entering NON-streaming path") # Non-streaming path - get complete text and parse - logger.info("DEBUG: Getting prompt-request client from context") client = context("prompt-request") - logger.info(f"DEBUG: Got client: {client}") - logger.info("DEBUG: About to call agent_react with streaming=False") response_text = await client.agent_react( variables=variables, streaming=False ) - logger.info(f"DEBUG: agent_react returned, got response") logger.debug(f"Response text:\n{response_text}") - logger.info(f"response: {response_text}") - # Parse the text response try: result = self.parse_react_response(response_text) - logger.info(f"Parsed result: {result}") return result except ValueError as e: logger.error(f"Failed to parse response: {e}") - # Try to provide a helpful error message 
logger.error(f"Response was: {response_text}") raise RuntimeError(f"Failed to parse agent response: {e}") async def react(self, question, history, think, observe, context, streaming=False, answer=None): - logger.info(f"question: {question}") - act = await self.reason( question = question, history = history, @@ -339,7 +302,6 @@ class AgentManager: observe = observe, answer = answer, ) - logger.info(f"act: {act}") if isinstance(act, Final): @@ -358,16 +320,11 @@ class AgentManager: logger.debug(f"ACTION: {act.name}") - logger.debug(f"Tools: {self.tools.keys()}") - if act.name in self.tools: action = self.tools[act.name] else: - logger.debug(f"Tools: {self.tools}") raise RuntimeError(f"No action for {act.name}!") - logger.debug(f"TOOL>>> {act}") - resp = await action.implementation(context).invoke( **act.arguments ) @@ -378,13 +335,8 @@ class AgentManager: resp = str(resp) resp = resp.strip() - logger.info(f"resp: {resp}") - await observe(resp, is_final=True) act.observation = resp - logger.info(f"iter: {act}") - return act - From 7b734148b33b5d5a84d67b1494e7c0edf2592674 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 12:54:51 +0100 Subject: [PATCH 18/37] agent-orchestrator: add explainability provenance for all patterns (#744) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit agent-orchestrator: add explainability provenance for all agent patterns Extend the provenance/explainability system to provide human-readable reasoning traces for the orchestrator's three agent patterns. Previously only ReAct emitted provenance (session, iteration, conclusion). Now each pattern records its cognitive steps as typed RDF entities in the knowledge graph, using composable mixin types (e.g. Finding + Answer). New provenance chains: - Supervisor: Question → Decomposition → Finding ×N → Synthesis - Plan-then-Execute: Question → Plan → StepResult ×N → Synthesis - ReAct: Question → Analysis ×N → Conclusion (unchanged) New RDF types: Decomposition, Finding, Plan, StepResult. New predicates: tg:subagentGoal, tg:planStep. Reuses existing Synthesis + Answer mixin for final answers. 
Provenance library (trustgraph-base): - Triple builders, URI generators, vocabulary labels for new types - Client dataclasses with from_triples() dispatch - fetch_agent_trace() follows branching provenance chains - API exports updated Orchestrator (trustgraph-flow): - PatternBase emit methods for decomposition, finding, plan, step result, and synthesis - SupervisorPattern emits decomposition during fan-out - PlanThenExecutePattern emits plan and step results - Service emits finding triples on subagent completion - Synthesis provenance replaces generic final triples CLI (trustgraph-cli): - invoke_agent -x displays new entity types inline --- trustgraph-base/trustgraph/api/__init__.py | 4 + .../trustgraph/api/explainability.py | 201 +++++++++++++----- .../trustgraph/provenance/__init__.py | 30 +++ .../trustgraph/provenance/agent.py | 107 +++++++++- .../trustgraph/provenance/namespaces.py | 10 +- trustgraph-base/trustgraph/provenance/uris.py | 25 +++ .../trustgraph/provenance/vocabulary.py | 8 + trustgraph-cli/trustgraph/cli/invoke_agent.py | 34 +++ .../agent/orchestrator/pattern_base.py | 150 +++++++++++++ .../agent/orchestrator/plan_pattern.py | 32 +-- .../trustgraph/agent/orchestrator/service.py | 16 +- .../agent/orchestrator/supervisor_pattern.py | 25 ++- 12 files changed, 560 insertions(+), 82 deletions(-) diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index dc1405ac..e956db65 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -82,6 +82,10 @@ from .explainability import ( Reflection, Analysis, Conclusion, + Decomposition, + Finding, + Plan, + StepResult, EdgeSelection, wire_triples_to_tuples, extract_term_value, diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 1c986efb..7b406a59 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -44,6 +44,16 @@ TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" TG_AGENT_QUESTION = TG + "AgentQuestion" +# Orchestrator entity types +TG_DECOMPOSITION = TG + "Decomposition" +TG_FINDING = TG + "Finding" +TG_PLAN_TYPE = TG + "Plan" +TG_STEP_RESULT = TG + "StepResult" + +# Orchestrator predicates +TG_SUBAGENT_GOAL = TG + "subagentGoal" +TG_PLAN_STEP = TG + "planStep" + # PROV-O predicates PROV = "http://www.w3.org/ns/prov#" PROV_STARTED_AT_TIME = PROV + "startedAtTime" @@ -82,6 +92,14 @@ class ExplainEntity: return Exploration.from_triples(uri, triples) elif TG_FOCUS in types: return Focus.from_triples(uri, triples) + elif TG_DECOMPOSITION in types: + return Decomposition.from_triples(uri, triples) + elif TG_FINDING in types: + return Finding.from_triples(uri, triples) + elif TG_PLAN_TYPE in types: + return Plan.from_triples(uri, triples) + elif TG_STEP_RESULT in types: + return StepResult.from_triples(uri, triples) elif TG_SYNTHESIS in types: return Synthesis.from_triples(uri, triples) elif TG_REFLECTION_TYPE in types: @@ -314,6 +332,70 @@ class Conclusion(ExplainEntity): ) +@dataclass +class Decomposition(ExplainEntity): + """Decomposition entity - supervisor broke question into sub-goals.""" + goals: List[str] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Decomposition": + goals = [] + for s, p, o in triples: + if p == TG_SUBAGENT_GOAL: + goals.append(o) + return cls(uri=uri, 
entity_type="decomposition", goals=goals) + + +@dataclass +class Finding(ExplainEntity): + """Finding entity - a subagent's result.""" + goal: str = "" + document: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Finding": + goal = "" + document = "" + for s, p, o in triples: + if p == TG_SUBAGENT_GOAL: + goal = o + elif p == TG_DOCUMENT: + document = o + return cls(uri=uri, entity_type="finding", goal=goal, document=document) + + +@dataclass +class Plan(ExplainEntity): + """Plan entity - a structured plan of steps.""" + steps: List[str] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Plan": + steps = [] + for s, p, o in triples: + if p == TG_PLAN_STEP: + steps.append(o) + return cls(uri=uri, entity_type="plan", steps=steps) + + +@dataclass +class StepResult(ExplainEntity): + """StepResult entity - a plan step's result.""" + step: str = "" + document: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "StepResult": + step = "" + document = "" + for s, p, o in triples: + if p == TG_PLAN_STEP: + step = o + elif p == TG_DOCUMENT: + document = o + return cls(uri=uri, entity_type="step-result", step=step, document=document) + + def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSelection: """Parse triples for an edge selection entity.""" uri = triples[0][0] if triples else "" @@ -895,7 +977,10 @@ class ExplainabilityClient: """ Fetch the complete Agent trace starting from a session URI. - Follows the provenance chain: Question -> Analysis(s) -> Conclusion + Follows the provenance chain for all patterns: + - ReAct: Question -> Analysis(s) -> Conclusion + - Supervisor: Question -> Decomposition -> Finding(s) -> Synthesis + - Plan-then-Execute: Question -> Plan -> StepResult(s) -> Synthesis Args: session_uri: The agent session/question URI @@ -906,14 +991,15 @@ class ExplainabilityClient: max_content: Maximum content length for conclusion Returns: - Dict with question, iterations (Analysis list), conclusion entities + Dict with question, steps (mixed entity list), conclusion/synthesis """ if graph is None: graph = "urn:graph:retrieval" trace = { "question": None, - "iterations": [], + "steps": [], + "iterations": [], # Backwards compatibility for ReAct "conclusion": None, } @@ -923,64 +1009,79 @@ class ExplainabilityClient: return trace trace["question"] = question - # Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after - current_uri = session_uri - is_first = True - max_iterations = 50 # Safety limit + # Follow the provenance chain from the question + self._follow_provenance_chain( + session_uri, trace, graph, user, collection, + is_first=True, max_depth=50, + ) - for _ in range(max_iterations): - # First hop uses wasGeneratedBy (entity←activity), - # subsequent hops use wasDerivedFrom (entity←entity) - if is_first: - derived_triples = self.flow.triples_query( - p=PROV_WAS_GENERATED_BY, - o=current_uri, - g=graph, - user=user, - collection=collection, - limit=10 - ) - # Fall back to wasDerivedFrom for backwards compatibility - if not derived_triples: - derived_triples = self.flow.triples_query( - p=PROV_WAS_DERIVED_FROM, - o=current_uri, - g=graph, - user=user, - collection=collection, - limit=10 - ) - is_first = False - else: + # Backwards compat: populate iterations from steps + trace["iterations"] = [ + s for s in trace["steps"] if isinstance(s, Analysis) + ] + + return trace + + 
def _follow_provenance_chain( + self, current_uri, trace, graph, user, collection, + is_first=False, max_depth=50, + ): + """Recursively follow the provenance chain, handling branches.""" + if max_depth <= 0: + return + + # Find entities derived from current_uri + if is_first: + derived_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=current_uri, + g=graph, user=user, collection=collection, + limit=20 + ) + if not derived_triples: derived_triples = self.flow.triples_query( p=PROV_WAS_DERIVED_FROM, o=current_uri, - g=graph, - user=user, - collection=collection, - limit=10 + g=graph, user=user, collection=collection, + limit=20 ) + else: + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, user=user, collection=collection, + limit=20 + ) - if not derived_triples: - break + if not derived_triples: + return - derived_uri = extract_term_value(derived_triples[0].get("s", {})) + derived_uris = [ + extract_term_value(t.get("s", {})) + for t in derived_triples + ] + + for derived_uri in derived_uris: if not derived_uri: - break + continue entity = self.fetch_entity(derived_uri, graph, user, collection) + if entity is None: + continue - if isinstance(entity, Analysis): - trace["iterations"].append(entity) - current_uri = derived_uri - elif isinstance(entity, Conclusion): + if isinstance(entity, (Analysis, Decomposition, Finding, + Plan, StepResult)): + trace["steps"].append(entity) + + # Continue following from this entity + self._follow_provenance_chain( + derived_uri, trace, graph, user, collection, + max_depth=max_depth - 1, + ) + + elif isinstance(entity, (Conclusion, Synthesis)): + trace["steps"].append(entity) trace["conclusion"] = entity - break - else: - # Unknown entity type, stop - break - - return trace def list_sessions( self, @@ -1082,7 +1183,7 @@ class ExplainabilityClient: for child_uri in all_child_uris: entity = self.fetch_entity(child_uri, graph, user, collection) - if isinstance(entity, Analysis): + if isinstance(entity, (Analysis, Decomposition, Plan)): return "agent" if isinstance(entity, Exploration): return "graphrag" diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index ac52c5e5..304f17a7 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -53,6 +53,12 @@ from . uris import ( agent_thought_uri, agent_observation_uri, agent_final_uri, + # Orchestrator provenance URIs + agent_decomposition_uri, + agent_finding_uri, + agent_plan_uri, + agent_step_result_uri, + agent_synthesis_uri, # Document RAG provenance URIs docrag_question_uri, docrag_grounding_uri, @@ -94,6 +100,9 @@ from . namespaces import ( TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, # Agent provenance predicates TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + TG_SUBAGENT_GOAL, TG_PLAN_STEP, + # Orchestrator entity types + TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, # Document reference predicate TG_DOCUMENT, # Named graphs @@ -124,6 +133,12 @@ from . 
agent import ( agent_session_triples, agent_iteration_triples, agent_final_triples, + # Orchestrator provenance triple builders + agent_decomposition_triples, + agent_finding_triples, + agent_plan_triples, + agent_step_result_triples, + agent_synthesis_triples, ) # Vocabulary bootstrap @@ -159,6 +174,12 @@ __all__ = [ "agent_thought_uri", "agent_observation_uri", "agent_final_uri", + # Orchestrator provenance URIs + "agent_decomposition_uri", + "agent_finding_uri", + "agent_plan_uri", + "agent_step_result_uri", + "agent_synthesis_uri", # Document RAG provenance URIs "docrag_question_uri", "docrag_grounding_uri", @@ -193,6 +214,9 @@ __all__ = [ "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", # Agent provenance predicates "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", + "TG_SUBAGENT_GOAL", "TG_PLAN_STEP", + # Orchestrator entity types + "TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT", # Document reference predicate "TG_DOCUMENT", # Named graphs @@ -215,6 +239,12 @@ __all__ = [ "agent_session_triples", "agent_iteration_triples", "agent_final_triples", + # Orchestrator provenance triple builders + "agent_decomposition_triples", + "agent_finding_triples", + "agent_plan_triples", + "agent_step_result_triples", + "agent_synthesis_triples", # Utility "set_graph", # Vocabulary diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index f1aeab0d..d25109a7 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -1,10 +1,15 @@ """ Helper functions to build PROV-O triples for agent provenance. -Agent provenance tracks the reasoning trace of ReAct agent sessions: +Agent provenance tracks the reasoning trace of agent sessions: - Question: The root activity with query and timestamp -- Analysis: Each think/act/observe cycle -- Conclusion: The final answer +- Analysis: Each think/act/observe cycle (ReAct) +- Conclusion: The final answer (ReAct) +- Decomposition: Supervisor broke question into sub-goals +- Finding: A subagent's result (Supervisor) +- Plan: Structured plan of steps (Plan-then-Execute) +- StepResult: A plan step's result (Plan-then-Execute) +- Synthesis: Final synthesised answer (Supervisor, Plan-then-Execute) """ import json @@ -21,6 +26,8 @@ from . 
namespaces import ( TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_AGENT_QUESTION, + TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, + TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP, ) @@ -203,3 +210,97 @@ def agent_final_triples( triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) return triples + + +def agent_decomposition_triples( + uri: str, + session_uri: str, + goals: List[str], +) -> List[Triple]: + """Build triples for a supervisor decomposition step.""" + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_DECOMPOSITION)), + _triple(uri, RDFS_LABEL, + _literal(f"Decomposed into {len(goals)} research threads")), + _triple(uri, PROV_WAS_GENERATED_BY, _iri(session_uri)), + ] + for goal in goals: + triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal))) + return triples + + +def agent_finding_triples( + uri: str, + decomposition_uri: str, + goal: str, + document_id: Optional[str] = None, +) -> List[Triple]: + """Build triples for a subagent finding.""" + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_FINDING)), + _triple(uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), + _triple(uri, RDFS_LABEL, _literal(f"Finding: {goal[:60]}")), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(decomposition_uri)), + _triple(uri, TG_SUBAGENT_GOAL, _literal(goal)), + ] + if document_id: + triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + return triples + + +def agent_plan_triples( + uri: str, + session_uri: str, + steps: List[str], +) -> List[Triple]: + """Build triples for a plan-then-execute plan.""" + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_PLAN_TYPE)), + _triple(uri, RDFS_LABEL, + _literal(f"Plan with {len(steps)} steps")), + _triple(uri, PROV_WAS_GENERATED_BY, _iri(session_uri)), + ] + for step in steps: + triples.append(_triple(uri, TG_PLAN_STEP, _literal(step))) + return triples + + +def agent_step_result_triples( + uri: str, + plan_uri: str, + goal: str, + document_id: Optional[str] = None, +) -> List[Triple]: + """Build triples for a plan step result.""" + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_STEP_RESULT)), + _triple(uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), + _triple(uri, RDFS_LABEL, _literal(f"Step result: {goal[:60]}")), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(plan_uri)), + _triple(uri, TG_PLAN_STEP, _literal(goal)), + ] + if document_id: + triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + return triples + + +def agent_synthesis_triples( + uri: str, + previous_uri: str, + document_id: Optional[str] = None, +) -> List[Triple]: + """Build triples for a synthesis answer.""" + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_SYNTHESIS)), + _triple(uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), + _triple(uri, RDFS_LABEL, _literal("Synthesis")), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri)), + ] + if document_id: + triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + return triples diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 3dc16fa2..69134dfb 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -94,8 +94,14 @@ TG_SYNTHESIS = TG + "Synthesis" TG_ANALYSIS = TG + "Analysis" TG_CONCLUSION = 
TG + "Conclusion" +# Orchestrator entity types +TG_DECOMPOSITION = TG + "Decomposition" # Supervisor decomposed into sub-goals +TG_FINDING = TG + "Finding" # Subagent result +TG_PLAN_TYPE = TG + "Plan" # Plan-then-execute plan +TG_STEP_RESULT = TG + "StepResult" # Plan step result + # Unifying types for answer and intermediate commentary -TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion) +TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion, Finding, StepResult) TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation) TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result @@ -110,6 +116,8 @@ TG_THOUGHT = TG + "thought" # Links iteration to thought sub-entity TG_ACTION = TG + "action" TG_ARGUMENTS = TG + "arguments" TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity +TG_SUBAGENT_GOAL = TG + "subagentGoal" # Goal string on Decomposition/Finding +TG_PLAN_STEP = TG + "planStep" # Step goal string on Plan/StepResult # Named graph URIs for RDF datasets # These separate different types of data while keeping them in the same collection diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index ac221515..a3aadef6 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -234,6 +234,31 @@ def agent_final_uri(session_id: str) -> str: return f"urn:trustgraph:agent:{session_id}/final" +def agent_decomposition_uri(session_id: str) -> str: + """Generate URI for a supervisor decomposition step.""" + return f"urn:trustgraph:agent:{session_id}/decompose" + + +def agent_finding_uri(session_id: str, index: int) -> str: + """Generate URI for a subagent finding.""" + return f"urn:trustgraph:agent:{session_id}/finding/{index}" + + +def agent_plan_uri(session_id: str) -> str: + """Generate URI for a plan-then-execute plan.""" + return f"urn:trustgraph:agent:{session_id}/plan" + + +def agent_step_result_uri(session_id: str, index: int) -> str: + """Generate URI for a plan step result.""" + return f"urn:trustgraph:agent:{session_id}/step/{index}" + + +def agent_synthesis_uri(session_id: str) -> str: + """Generate URI for a synthesis answer.""" + return f"urn:trustgraph:agent:{session_id}/synthesis" + + # Document RAG provenance URIs # These URIs use the urn:trustgraph:docrag: namespace to distinguish # document RAG provenance from graph RAG provenance diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py index 018e2bfe..afb5c30f 100644 --- a/trustgraph-base/trustgraph/provenance/vocabulary.py +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -27,6 +27,8 @@ from . 
namespaces import ( TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, TG_CONCEPT, TG_ENTITY, TG_GROUNDING, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, + TG_SUBAGENT_GOAL, TG_PLAN_STEP, ) @@ -87,6 +89,10 @@ TG_CLASS_LABELS = [ _label_triple(TG_REFLECTION_TYPE, "Reflection"), _label_triple(TG_THOUGHT_TYPE, "Thought"), _label_triple(TG_OBSERVATION_TYPE, "Observation"), + _label_triple(TG_DECOMPOSITION, "Decomposition"), + _label_triple(TG_FINDING, "Finding"), + _label_triple(TG_PLAN_TYPE, "Plan"), + _label_triple(TG_STEP_RESULT, "Step Result"), ] # TrustGraph predicate labels @@ -109,6 +115,8 @@ TG_PREDICATE_LABELS = [ _label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"), _label_triple(TG_CONCEPT, "concept"), _label_triple(TG_ENTITY, "entity"), + _label_triple(TG_SUBAGENT_GOAL, "subagent goal"), + _label_triple(TG_PLAN_STEP, "plan step"), ] diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 9879025f..2a1ba7c2 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -13,6 +13,11 @@ from trustgraph.api import ( Question, Analysis, Conclusion, + Decomposition, + Finding, + Plan, + StepResult, + Synthesis, AgentThought, AgentObservation, AgentAnswer, @@ -209,6 +214,35 @@ def question_explainable( if entity.observation: print(f" Observation: {entity.observation}", file=sys.stderr) + elif isinstance(entity, Decomposition): + print(f"\n [decompose] {prov_id}", file=sys.stderr) + for i, goal in enumerate(entity.goals): + print(f" Thread {i}: {goal}", file=sys.stderr) + + elif isinstance(entity, Finding): + print(f"\n [finding] {prov_id}", file=sys.stderr) + if entity.goal: + print(f" Goal: {entity.goal}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) + + elif isinstance(entity, Plan): + print(f"\n [plan] {prov_id}", file=sys.stderr) + for i, step in enumerate(entity.steps): + print(f" Step {i}: {step}", file=sys.stderr) + + elif isinstance(entity, StepResult): + print(f"\n [step-result] {prov_id}", file=sys.stderr) + if entity.step: + print(f" Step: {entity.step}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) + + elif isinstance(entity, Synthesis): + print(f"\n [synthesis] {prov_id}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) + elif isinstance(entity, Conclusion): print(f"\n [conclusion] {prov_id}", file=sys.stderr) if entity.document: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index b66bc4f5..ddc4aed9 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -20,9 +20,19 @@ from trustgraph.provenance import ( agent_thought_uri, agent_observation_uri, agent_final_uri, + agent_decomposition_uri, + agent_finding_uri, + agent_plan_uri, + agent_step_result_uri, + agent_synthesis_uri, agent_session_triples, agent_iteration_triples, agent_final_triples, + agent_decomposition_triples, + agent_finding_triples, + agent_plan_triples, + agent_step_result_triples, + agent_synthesis_triples, set_graph, GRAPH_RETRIEVAL, ) @@ -359,6 +369,146 @@ class PatternBase: explain_graph=GRAPH_RETRIEVAL, )) + # ---- Orchestrator provenance helpers 
------------------------------------ + + async def emit_decomposition_triples( + self, flow, session_id, session_uri, goals, user, collection, + respond, streaming, + ): + """Emit provenance for a supervisor decomposition step.""" + uri = agent_decomposition_uri(session_id) + triples = set_graph( + agent_decomposition_triples(uri, session_uri, goals), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + if streaming: + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_finding_triples( + self, flow, session_id, index, goal, answer_text, user, collection, + respond, streaming, + ): + """Emit provenance for a subagent finding.""" + uri = agent_finding_uri(session_id, index) + decomposition_uri = agent_decomposition_uri(session_id) + + doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc" + try: + await self.processor.save_answer_content( + doc_id=doc_id, user=user, + content=answer_text, + title=f"Finding: {goal[:60]}", + ) + except Exception as e: + logger.warning(f"Failed to save finding to librarian: {e}") + doc_id = None + + triples = set_graph( + agent_finding_triples(uri, decomposition_uri, goal, doc_id), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + if streaming: + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_plan_triples( + self, flow, session_id, session_uri, steps, user, collection, + respond, streaming, + ): + """Emit provenance for a plan creation.""" + uri = agent_plan_uri(session_id) + triples = set_graph( + agent_plan_triples(uri, session_uri, steps), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + if streaming: + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_step_result_triples( + self, flow, session_id, index, goal, answer_text, user, collection, + respond, streaming, + ): + """Emit provenance for a plan step result.""" + uri = agent_step_result_uri(session_id, index) + plan_uri = agent_plan_uri(session_id) + + doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc" + try: + await self.processor.save_answer_content( + doc_id=doc_id, user=user, + content=answer_text, + title=f"Step result: {goal[:60]}", + ) + except Exception as e: + logger.warning(f"Failed to save step result to librarian: {e}") + doc_id = None + + triples = set_graph( + agent_step_result_triples(uri, plan_uri, goal, doc_id), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + if streaming: + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_synthesis_triples( + self, flow, session_id, previous_uri, answer_text, user, collection, + respond, streaming, + ): + """Emit provenance for a synthesis answer.""" + uri = agent_synthesis_uri(session_id) + + doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc" + try: + await self.processor.save_answer_content( + doc_id=doc_id, user=user, + content=answer_text, + 
title="Synthesis", + ) + except Exception as e: + logger.warning(f"Failed to save synthesis to librarian: {e}") + doc_id = None + + triples = set_graph( + agent_synthesis_triples(uri, previous_uri, doc_id), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + if streaming: + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) + # ---- Response helpers --------------------------------------------------- async def prompt_as_answer(self, client, prompt_id, variables, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 4c61039f..d6abb058 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -11,7 +11,7 @@ import uuid from ... schema import AgentRequest, AgentResponse, AgentStep, PlanStep -from ..react.types import Action + from . pattern_base import PatternBase @@ -126,6 +126,13 @@ class PlanThenExecutePattern(PatternBase): thought_text = f"Created plan with {len(plan_steps)} steps" await think(thought_text, is_final=True) + # Emit plan provenance + step_goals = [ps.get("goal", "") for ps in plan_steps] + await self.emit_plan_triples( + flow, session_id, session_uri, step_goals, + request.user, collection, respond, streaming, + ) + # Build PlanStep objects plan_agent_steps = [ PlanStep( @@ -263,16 +270,10 @@ class PlanThenExecutePattern(PatternBase): result=step_result, ) - # Emit iteration provenance - prov_act = Action( - thought=f"Plan step {pending_idx}: {goal}", - name=tool_name, - arguments=tool_arguments, - observation=step_result, - ) - await self.emit_iteration_triples( - flow, session_id, iteration_num, session_uri, - prov_act, request, respond, streaming, + # Emit step result provenance + await self.emit_step_result_triples( + flow, session_id, pending_idx, goal, step_result, + request.user, collection, respond, streaming, ) # Build execution step for history @@ -340,9 +341,12 @@ class PlanThenExecutePattern(PatternBase): streaming=streaming, ) - await self.emit_final_triples( - flow, session_id, iteration_num, session_uri, - response_text, request, respond, streaming, + # Emit synthesis provenance (links back to last step result) + from trustgraph.provenance import agent_step_result_uri + last_step_uri = agent_step_result_uri(session_id, len(plan) - 1) + await self.emit_synthesis_triples( + flow, session_id, last_step_uri, + response_text, request.user, collection, respond, streaming, ) if self.is_subagent(request): diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 9c9980d4..9ca3fe59 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -427,6 +427,7 @@ class Processor(AgentService): correlation_id = request.correlation_id subagent_goal = getattr(request, 'subagent_goal', '') + parent_session_id = getattr(request, 'parent_session_id', '') # Extract the answer from the completion step answer_text = "" @@ -451,13 +452,26 @@ class Processor(AgentService): ) return + # Emit finding provenance for this subagent + template = self.aggregator.get_original_request(correlation_id) + if template and parent_session_id: + entry = self.aggregator.correlations.get(correlation_id) + 
finding_index = len(entry["results"]) - 1 if entry else 0 + collection = getattr(template, 'collection', 'default') + + await self.supervisor_pattern.emit_finding_triples( + flow, parent_session_id, finding_index, + subagent_goal, answer_text, + template.user, collection, + respond, template.streaming, + ) + if all_done: logger.info( f"All subagents complete for {correlation_id}, " f"dispatching synthesis" ) - template = self.aggregator.get_original_request(correlation_id) if template is None: logger.error( f"No template for correlation {correlation_id}" diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 51c2d500..8588e400 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -16,7 +16,7 @@ import uuid from ... schema import AgentRequest, AgentResponse, AgentStep -from ..react.types import Action, Final +from trustgraph.provenance import agent_finding_uri from . pattern_base import PatternBase @@ -121,15 +121,9 @@ class SupervisorPattern(PatternBase): correlation_id = str(uuid.uuid4()) # Emit decomposition provenance - decompose_act = Action( - thought=f"Decomposed into {len(goals)} sub-goals", - name="decompose", - arguments={"goals": json.dumps(goals), "correlation_id": correlation_id}, - observation=f"Fanning out {len(goals)} subagents", - ) - await self.emit_iteration_triples( - flow, session_id, iteration_num, session_uri, - decompose_act, request, respond, streaming, + await self.emit_decomposition_triples( + flow, session_id, session_uri, goals, + request.user, collection, respond, streaming, ) # Fan out: emit a subagent request for each goal @@ -207,10 +201,15 @@ class SupervisorPattern(PatternBase): streaming=streaming, ) - await self.emit_final_triples( - flow, session_id, iteration_num, session_uri, - response_text, request, respond, streaming, + # Emit synthesis provenance (links back to last finding) + last_finding_uri = agent_finding_uri( + session_id, len(subagent_results) - 1 ) + await self.emit_synthesis_triples( + flow, session_id, last_finding_uri, + response_text, request.user, collection, respond, streaming, + ) + await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, ) From 816a8cfcf68fd0b0f88cf6f3201ed3c53da797d6 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 13:12:26 +0100 Subject: [PATCH 19/37] Update tests for agent-orchestrator (#745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 96 tests covering the orchestrator's aggregation, provenance, routing, and explainability parsing. These verify the supervisor fan-out/fan-in lifecycle, the new RDF provenance types (Decomposition, Finding, Plan, StepResult, Synthesis), and their round-trip through the wire format. 
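For orientation, the fan-out/fan-in lifecycle these tests pin down can be
sketched directly against the Aggregator API. This is a minimal sketch: the
identifiers mirror the fixtures in tests/unit/test_agent/test_aggregator.py
below rather than any particular production call site.

    from trustgraph.schema import AgentRequest
    from trustgraph.agent.orchestrator.aggregator import Aggregator

    # Template request whose fields (streaming, task_type, framing, ...)
    # are copied onto the eventual synthesis request
    template = AgentRequest(question="Original question", user="testuser",
                            collection="default", session_id="parent-sess")

    agg = Aggregator()
    agg.register_fanout("corr-1", "parent-sess", 2, request_template=template)

    # Each subagent completion is recorded against its goal; the call
    # returns True only once the expected number of siblings has reported in
    agg.record_completion("corr-1", "goal-a", "answer-a")   # False
    agg.record_completion("corr-1", "goal-b", "answer-b")   # True

    # Consumes the correlation entry and builds a supervisor-pattern request
    # whose final history step is a 'synthesise' step carrying the results
    synth = agg.build_synthesis_request(
        "corr-1", "Original question", "testuser", "default",
    )
    assert synth.history[-1].subagent_results == {
        "goal-a": "answer-a", "goal-b": "answer-b",
    }
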
Unit tests (84): - Aggregator: register, record completion, peek, build synthesis, cleanup - Provenance triple builders: types, provenance links, goals/steps, labels - Explainability parsing: from_triples dispatch, field extraction for all new entity types, precedence over existing types - PatternBase: is_subagent detection, emit_subagent_completion message shape - Completion dispatch: detection logic, full aggregator integration flow, synthesis request not re-intercepted as completion - MetaRouter: task type identification, pattern selection, valid_patterns constraints, fallback on LLM error or unknown response Contract tests (12): - Orchestration fields on AgentRequest round-trip correctly - subagent-completion and synthesise step types in request history - Plan steps with status and dependencies - Provenance triple builder → wire format → from_triples round-trip for all five new entity types --- tests/contract/test_orchestrator_contracts.py | 177 +++++++++++ tests/contract/test_provenance_wire_format.py | 129 ++++++++ tests/unit/test_agent/test_aggregator.py | 216 +++++++++++++ .../test_agent/test_completion_dispatch.py | 174 +++++++++++ .../test_agent/test_explainability_parsing.py | 162 ++++++++++ tests/unit/test_agent/test_meta_router.py | 289 ++++++++++++++++++ .../test_agent/test_pattern_base_subagent.py | 144 +++++++++ .../test_agent/test_provenance_triples.py | 226 ++++++++++++++ 8 files changed, 1517 insertions(+) create mode 100644 tests/contract/test_orchestrator_contracts.py create mode 100644 tests/contract/test_provenance_wire_format.py create mode 100644 tests/unit/test_agent/test_aggregator.py create mode 100644 tests/unit/test_agent/test_completion_dispatch.py create mode 100644 tests/unit/test_agent/test_explainability_parsing.py create mode 100644 tests/unit/test_agent/test_meta_router.py create mode 100644 tests/unit/test_agent/test_pattern_base_subagent.py create mode 100644 tests/unit/test_agent/test_provenance_triples.py diff --git a/tests/contract/test_orchestrator_contracts.py b/tests/contract/test_orchestrator_contracts.py new file mode 100644 index 00000000..ab168ece --- /dev/null +++ b/tests/contract/test_orchestrator_contracts.py @@ -0,0 +1,177 @@ +""" +Contract tests for orchestrator message schemas. + +Verifies that AgentRequest/AgentStep with orchestration fields +serialise and deserialise correctly through the Pulsar schema layer. +""" + +import pytest +import json + +from trustgraph.schema import AgentRequest, AgentStep, PlanStep + + +@pytest.mark.contract +class TestOrchestrationFieldContracts: + """Contract tests for orchestration fields on AgentRequest.""" + + def test_agent_request_orchestration_fields_roundtrip(self): + req = AgentRequest( + question="Test question", + user="testuser", + collection="default", + correlation_id="corr-123", + parent_session_id="parent-sess", + subagent_goal="What is X?", + expected_siblings=4, + pattern="react", + task_type="research", + framing="Focus on accuracy", + conversation_id="conv-456", + ) + + assert req.correlation_id == "corr-123" + assert req.parent_session_id == "parent-sess" + assert req.subagent_goal == "What is X?" 
+ assert req.expected_siblings == 4 + assert req.pattern == "react" + assert req.task_type == "research" + assert req.framing == "Focus on accuracy" + assert req.conversation_id == "conv-456" + + def test_agent_request_orchestration_fields_default_empty(self): + req = AgentRequest( + question="Test question", + user="testuser", + ) + + assert req.correlation_id == "" + assert req.parent_session_id == "" + assert req.subagent_goal == "" + assert req.expected_siblings == 0 + assert req.pattern == "" + assert req.task_type == "" + assert req.framing == "" + + +@pytest.mark.contract +class TestSubagentCompletionStepContract: + """Contract tests for subagent-completion step type.""" + + def test_subagent_completion_step_fields(self): + step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation="The answer text", + step_type="subagent-completion", + ) + + assert step.step_type == "subagent-completion" + assert step.observation == "The answer text" + assert step.thought == "Subagent completed" + assert step.action == "complete" + + def test_subagent_completion_in_request_history(self): + step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation="answer", + step_type="subagent-completion", + ) + req = AgentRequest( + question="goal", + user="testuser", + correlation_id="corr-123", + history=[step], + ) + + assert len(req.history) == 1 + assert req.history[0].step_type == "subagent-completion" + assert req.history[0].observation == "answer" + + +@pytest.mark.contract +class TestSynthesisStepContract: + """Contract tests for synthesis step type with subagent_results.""" + + def test_synthesis_step_with_results(self): + results = {"goal-a": "answer-a", "goal-b": "answer-b"} + step = AgentStep( + thought="All subagents completed", + action="aggregate", + arguments={}, + observation=json.dumps(results), + step_type="synthesise", + subagent_results=results, + ) + + assert step.step_type == "synthesise" + assert step.subagent_results == results + assert json.loads(step.observation) == results + + def test_synthesis_request_matches_supervisor_expectations(self): + """The synthesis request built by the aggregator must be + recognisable by SupervisorPattern._synthesise().""" + results = {"goal-a": "answer-a", "goal-b": "answer-b"} + step = AgentStep( + thought="All subagents completed", + action="aggregate", + arguments={}, + observation=json.dumps(results), + step_type="synthesise", + subagent_results=results, + ) + + req = AgentRequest( + question="Original question", + user="testuser", + pattern="supervisor", + correlation_id="", + session_id="parent-sess", + history=[step], + ) + + # SupervisorPattern checks for step_type='synthesise' with + # subagent_results + has_results = bool( + req.history + and any( + getattr(h, 'step_type', '') == 'synthesise' + and getattr(h, 'subagent_results', None) + for h in req.history + ) + ) + assert has_results + + # Pattern must be supervisor + assert req.pattern == "supervisor" + + # Correlation ID must be empty (not re-intercepted) + assert req.correlation_id == "" + + +@pytest.mark.contract +class TestPlanStepContract: + """Contract tests for plan steps in history.""" + + def test_plan_step_in_history(self): + plan = [ + PlanStep(goal="Step 1", tool_hint="knowledge-query", + depends_on=[], status="completed", result="done"), + PlanStep(goal="Step 2", tool_hint="", + depends_on=[0], status="pending", result=""), + ] + step = AgentStep( + thought="Created plan", + action="plan", + 
step_type="plan", + plan=plan, + ) + + assert step.step_type == "plan" + assert len(step.plan) == 2 + assert step.plan[0].goal == "Step 1" + assert step.plan[0].status == "completed" + assert step.plan[1].depends_on == [0] diff --git a/tests/contract/test_provenance_wire_format.py b/tests/contract/test_provenance_wire_format.py new file mode 100644 index 00000000..de195f68 --- /dev/null +++ b/tests/contract/test_provenance_wire_format.py @@ -0,0 +1,129 @@ +""" +Contract tests for provenance triple wire format — verifies that triples +built by the provenance library can be parsed by the explainability API +through the wire format conversion. +""" + +import pytest + +from trustgraph.schema import IRI, LITERAL + +from trustgraph.provenance import ( + agent_decomposition_triples, + agent_finding_triples, + agent_plan_triples, + agent_step_result_triples, + agent_synthesis_triples, +) + +from trustgraph.api.explainability import ( + ExplainEntity, + Decomposition, + Finding, + Plan, + StepResult, + Synthesis, + wire_triples_to_tuples, +) + + +def _triples_to_wire(triples): + """Convert provenance Triple objects to the wire format dicts + that the gateway/socket client would produce.""" + wire = [] + for t in triples: + entry = { + "s": _term_to_wire(t.s), + "p": _term_to_wire(t.p), + "o": _term_to_wire(t.o), + } + wire.append(entry) + return wire + + +def _term_to_wire(term): + """Convert a Term to wire format dict.""" + if term.type == IRI: + return {"t": "i", "i": term.iri} + elif term.type == LITERAL: + return {"t": "l", "v": term.value} + return {"t": "l", "v": str(term)} + + +def _roundtrip(triples, uri): + """Convert triples through wire format and parse via from_triples.""" + wire = _triples_to_wire(triples) + tuples = wire_triples_to_tuples(wire) + return ExplainEntity.from_triples(uri, tuples) + + +@pytest.mark.contract +class TestDecompositionWireFormat: + + def test_roundtrip(self): + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", + ["What is X?", "What is Y?"], + ) + entity = _roundtrip(triples, "urn:decompose") + + assert isinstance(entity, Decomposition) + assert set(entity.goals) == {"What is X?", "What is Y?"} + + +@pytest.mark.contract +class TestFindingWireFormat: + + def test_roundtrip(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "What is X?", + document_id="urn:doc/finding", + ) + entity = _roundtrip(triples, "urn:finding") + + assert isinstance(entity, Finding) + assert entity.goal == "What is X?" 
+ assert entity.document == "urn:doc/finding" + + +@pytest.mark.contract +class TestPlanWireFormat: + + def test_roundtrip(self): + triples = agent_plan_triples( + "urn:plan", "urn:session", + ["Step 1", "Step 2", "Step 3"], + ) + entity = _roundtrip(triples, "urn:plan") + + assert isinstance(entity, Plan) + assert set(entity.steps) == {"Step 1", "Step 2", "Step 3"} + + +@pytest.mark.contract +class TestStepResultWireFormat: + + def test_roundtrip(self): + triples = agent_step_result_triples( + "urn:step", "urn:plan", "Define X", + document_id="urn:doc/step", + ) + entity = _roundtrip(triples, "urn:step") + + assert isinstance(entity, StepResult) + assert entity.step == "Define X" + assert entity.document == "urn:doc/step" + + +@pytest.mark.contract +class TestSynthesisWireFormat: + + def test_roundtrip(self): + triples = agent_synthesis_triples( + "urn:synthesis", "urn:previous", + document_id="urn:doc/synthesis", + ) + entity = _roundtrip(triples, "urn:synthesis") + + assert isinstance(entity, Synthesis) + assert entity.document == "urn:doc/synthesis" diff --git a/tests/unit/test_agent/test_aggregator.py b/tests/unit/test_agent/test_aggregator.py new file mode 100644 index 00000000..afb19499 --- /dev/null +++ b/tests/unit/test_agent/test_aggregator.py @@ -0,0 +1,216 @@ +""" +Unit tests for the Aggregator — tracks fan-out correlations and triggers +synthesis when all subagents complete. +""" + +import time +import pytest + +from trustgraph.schema import AgentRequest, AgentStep + +from trustgraph.agent.orchestrator.aggregator import Aggregator + + +def _make_request(question="Test question", user="testuser", + collection="default", streaming=False, + session_id="parent-session", task_type="research", + framing="test framing", conversation_id="conv-1"): + return AgentRequest( + question=question, + user=user, + collection=collection, + streaming=streaming, + session_id=session_id, + task_type=task_type, + framing=framing, + conversation_id=conversation_id, + ) + + +class TestRegisterFanout: + + def test_stores_correlation_entry(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 3) + + assert "corr-1" in agg.correlations + entry = agg.correlations["corr-1"] + assert entry["parent_session_id"] == "parent-1" + assert entry["expected"] == 3 + assert entry["results"] == {} + + def test_stores_request_template(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + + entry = agg.correlations["corr-1"] + assert entry["request_template"] is template + + def test_records_creation_time(self): + agg = Aggregator() + before = time.time() + agg.register_fanout("corr-1", "parent-1", 2) + after = time.time() + + created = agg.correlations["corr-1"]["created_at"] + assert before <= created <= after + + +class TestRecordCompletion: + + def test_returns_false_until_all_done(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 3) + + assert agg.record_completion("corr-1", "goal-a", "answer-a") is False + assert agg.record_completion("corr-1", "goal-b", "answer-b") is False + assert agg.record_completion("corr-1", "goal-c", "answer-c") is True + + def test_returns_none_for_unknown_correlation(self): + agg = Aggregator() + result = agg.record_completion("unknown", "goal", "answer") + assert result is None + + def test_stores_results_by_goal(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 2) + + agg.record_completion("corr-1", "goal-a", "answer-a") + 
agg.record_completion("corr-1", "goal-b", "answer-b") + + results = agg.correlations["corr-1"]["results"] + assert results["goal-a"] == "answer-a" + assert results["goal-b"] == "answer-b" + + def test_single_subagent(self): + agg = Aggregator() + agg.register_fanout("corr-1", "parent-1", 1) + + assert agg.record_completion("corr-1", "goal-a", "answer") is True + + +class TestGetOriginalRequest: + + def test_peeks_without_consuming(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + + result = agg.get_original_request("corr-1") + assert result is template + # Entry still exists + assert "corr-1" in agg.correlations + + def test_returns_none_for_unknown(self): + agg = Aggregator() + assert agg.get_original_request("unknown") is None + + +class TestBuildSynthesisRequest: + + def test_builds_correct_request(self): + agg = Aggregator() + template = _make_request( + question="Original question", + streaming=True, + task_type="risk-assessment", + framing="Assess risks", + ) + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + agg.record_completion("corr-1", "goal-a", "answer-a") + agg.record_completion("corr-1", "goal-b", "answer-b") + + req = agg.build_synthesis_request( + "corr-1", + original_question="Original question", + user="testuser", + collection="default", + ) + + assert req.question == "Original question" + assert req.pattern == "supervisor" + assert req.session_id == "parent-1" + assert req.correlation_id == "" # Must be empty + assert req.streaming == True + assert req.task_type == "risk-assessment" + assert req.framing == "Assess risks" + + def test_synthesis_step_in_history(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 2, + request_template=template) + agg.record_completion("corr-1", "goal-a", "answer-a") + agg.record_completion("corr-1", "goal-b", "answer-b") + + req = agg.build_synthesis_request( + "corr-1", "question", "user", "default", + ) + + # Last history step should be the synthesis step + assert len(req.history) >= 1 + synth_step = req.history[-1] + assert synth_step.step_type == "synthesise" + assert synth_step.subagent_results == { + "goal-a": "answer-a", + "goal-b": "answer-b", + } + + def test_consumes_correlation_entry(self): + agg = Aggregator() + template = _make_request() + agg.register_fanout("corr-1", "parent-1", 1, + request_template=template) + agg.record_completion("corr-1", "goal-a", "answer-a") + + agg.build_synthesis_request( + "corr-1", "question", "user", "default", + ) + + # Entry should be removed + assert "corr-1" not in agg.correlations + + def test_raises_for_unknown_correlation(self): + agg = Aggregator() + with pytest.raises(RuntimeError, match="No results"): + agg.build_synthesis_request( + "unknown", "question", "user", "default", + ) + + +class TestCleanupStale: + + def test_removes_entries_older_than_timeout(self): + agg = Aggregator(timeout=1) + agg.register_fanout("corr-1", "parent-1", 2) + + # Backdate the creation time + agg.correlations["corr-1"]["created_at"] = time.time() - 2 + + stale = agg.cleanup_stale() + assert "corr-1" in stale + assert "corr-1" not in agg.correlations + + def test_keeps_recent_entries(self): + agg = Aggregator(timeout=300) + agg.register_fanout("corr-1", "parent-1", 2) + + stale = agg.cleanup_stale() + assert stale == [] + assert "corr-1" in agg.correlations + + def test_mixed_stale_and_fresh(self): + agg = Aggregator(timeout=1) + 
agg.register_fanout("stale", "parent-1", 2) + agg.register_fanout("fresh", "parent-2", 2) + + agg.correlations["stale"]["created_at"] = time.time() - 2 + + stale = agg.cleanup_stale() + assert "stale" in stale + assert "stale" not in agg.correlations + assert "fresh" in agg.correlations diff --git a/tests/unit/test_agent/test_completion_dispatch.py b/tests/unit/test_agent/test_completion_dispatch.py new file mode 100644 index 00000000..8c01f126 --- /dev/null +++ b/tests/unit/test_agent/test_completion_dispatch.py @@ -0,0 +1,174 @@ +""" +Unit tests for completion dispatch — verifies that agent_request() in the +orchestrator service correctly intercepts subagent completion messages and +routes them to _handle_subagent_completion. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.schema import AgentRequest, AgentStep + +from trustgraph.agent.orchestrator.aggregator import Aggregator + + +def _make_request(**kwargs): + defaults = dict( + question="Test question", + user="testuser", + collection="default", + ) + defaults.update(kwargs) + return AgentRequest(**defaults) + + +def _make_completion_request(correlation_id, goal, answer): + """Build a completion request as emit_subagent_completion would.""" + step = AgentStep( + thought="Subagent completed", + action="complete", + arguments={}, + observation=answer, + step_type="subagent-completion", + ) + return _make_request( + correlation_id=correlation_id, + parent_session_id="parent-sess", + subagent_goal=goal, + expected_siblings=2, + history=[step], + ) + + +class TestCompletionDetection: + """Test that completion messages are correctly identified.""" + + def test_is_completion_when_correlation_id_and_step_type(self): + req = _make_completion_request("corr-1", "goal-a", "answer-a") + + has_correlation = bool(getattr(req, 'correlation_id', '')) + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in req.history + ) + + assert has_correlation + assert is_completion + + def test_not_completion_without_correlation_id(self): + step = AgentStep( + step_type="subagent-completion", + observation="answer", + ) + req = _make_request( + correlation_id="", + history=[step], + ) + + has_correlation = bool(getattr(req, 'correlation_id', '')) + assert not has_correlation + + def test_not_completion_without_step_type(self): + step = AgentStep( + step_type="react", + observation="answer", + ) + req = _make_request( + correlation_id="corr-1", + history=[step], + ) + + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in req.history + ) + assert not is_completion + + def test_not_completion_with_empty_history(self): + req = _make_request( + correlation_id="corr-1", + history=[], + ) + assert not req.history + + +class TestAggregatorIntegration: + """Test the aggregator flow as used by _handle_subagent_completion.""" + + def test_full_completion_flow(self): + """Simulates the flow: register, record completions, build synthesis.""" + agg = Aggregator() + template = _make_request( + question="Original question", + streaming=True, + task_type="risk-assessment", + framing="Assess risks", + session_id="parent-sess", + ) + + # Register fan-out + agg.register_fanout("corr-1", "parent-sess", 2, + request_template=template) + + # First completion — not all done + all_done = agg.record_completion( + "corr-1", "goal-a", "answer-a", + ) + assert all_done is False + + # Second completion — all done + all_done = agg.record_completion( + "corr-1", "goal-b", 
"answer-b", + ) + assert all_done is True + + # Peek at template + peeked = agg.get_original_request("corr-1") + assert peeked.question == "Original question" + + # Build synthesis request + synth = agg.build_synthesis_request( + "corr-1", + original_question="Original question", + user="testuser", + collection="default", + ) + + # Verify synthesis request + assert synth.pattern == "supervisor" + assert synth.correlation_id == "" + assert synth.session_id == "parent-sess" + assert synth.streaming is True + + # Verify synthesis history has results + synth_steps = [ + s for s in synth.history + if getattr(s, 'step_type', '') == 'synthesise' + ] + assert len(synth_steps) == 1 + assert synth_steps[0].subagent_results == { + "goal-a": "answer-a", + "goal-b": "answer-b", + } + + def test_synthesis_request_not_detected_as_completion(self): + """The synthesis request must not be intercepted as a completion.""" + agg = Aggregator() + template = _make_request(session_id="parent-sess") + agg.register_fanout("corr-1", "parent-sess", 1, + request_template=template) + agg.record_completion("corr-1", "goal", "answer") + + synth = agg.build_synthesis_request( + "corr-1", "question", "user", "default", + ) + + # correlation_id must be empty so it's not intercepted + assert synth.correlation_id == "" + + # Even if we check for completion step, shouldn't match + is_completion = any( + getattr(h, 'step_type', '') == 'subagent-completion' + for h in synth.history + ) + assert not is_completion diff --git a/tests/unit/test_agent/test_explainability_parsing.py b/tests/unit/test_agent/test_explainability_parsing.py new file mode 100644 index 00000000..e09a7f1f --- /dev/null +++ b/tests/unit/test_agent/test_explainability_parsing.py @@ -0,0 +1,162 @@ +""" +Unit tests for explainability API parsing — verifies that from_triples() +correctly dispatches and parses the new orchestrator entity types. 
+""" + +import pytest + +from trustgraph.api.explainability import ( + ExplainEntity, + Decomposition, + Finding, + Plan, + StepResult, + Synthesis, + Analysis, + Conclusion, + TG_DECOMPOSITION, + TG_FINDING, + TG_PLAN_TYPE, + TG_STEP_RESULT, + TG_SYNTHESIS, + TG_ANSWER_TYPE, + TG_ANALYSIS, + TG_CONCLUSION, + TG_DOCUMENT, + TG_SUBAGENT_GOAL, + TG_PLAN_STEP, + RDF_TYPE, +) + +PROV_ENTITY = "http://www.w3.org/ns/prov#Entity" + + +def _make_triples(uri, types, extras=None): + """Build a list of (s, p, o) tuples for testing.""" + triples = [(uri, RDF_TYPE, t) for t in types] + if extras: + triples.extend((uri, p, o) for p, o in extras) + return triples + + +class TestFromTriplesDispatch: + + def test_dispatches_decomposition(self): + triples = _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION]) + entity = ExplainEntity.from_triples("urn:d", triples) + assert isinstance(entity, Decomposition) + + def test_dispatches_finding(self): + triples = _make_triples("urn:f", + [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:f", triples) + assert isinstance(entity, Finding) + + def test_dispatches_plan(self): + triples = _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE]) + entity = ExplainEntity.from_triples("urn:p", triples) + assert isinstance(entity, Plan) + + def test_dispatches_step_result(self): + triples = _make_triples("urn:sr", + [PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:sr", triples) + assert isinstance(entity, StepResult) + + def test_dispatches_synthesis(self): + triples = _make_triples("urn:s", + [PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:s", triples) + assert isinstance(entity, Synthesis) + + def test_dispatches_analysis_unchanged(self): + triples = _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS]) + entity = ExplainEntity.from_triples("urn:a", triples) + assert isinstance(entity, Analysis) + + def test_dispatches_conclusion_unchanged(self): + triples = _make_triples("urn:c", + [PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:c", triples) + assert isinstance(entity, Conclusion) + + def test_finding_takes_precedence_over_synthesis(self): + """Finding has Answer mixin but should dispatch to Finding, not + Synthesis, because Finding is checked first.""" + triples = _make_triples("urn:f", + [PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE]) + entity = ExplainEntity.from_triples("urn:f", triples) + assert isinstance(entity, Finding) + assert not isinstance(entity, Synthesis) + + +class TestDecompositionParsing: + + def test_parses_goals(self): + triples = _make_triples("urn:d", [TG_DECOMPOSITION], [ + (TG_SUBAGENT_GOAL, "What is X?"), + (TG_SUBAGENT_GOAL, "What is Y?"), + ]) + entity = Decomposition.from_triples("urn:d", triples) + assert set(entity.goals) == {"What is X?", "What is Y?"} + + def test_entity_type_field(self): + triples = _make_triples("urn:d", [TG_DECOMPOSITION]) + entity = Decomposition.from_triples("urn:d", triples) + assert entity.entity_type == "decomposition" + + def test_empty_goals(self): + triples = _make_triples("urn:d", [TG_DECOMPOSITION]) + entity = Decomposition.from_triples("urn:d", triples) + assert entity.goals == [] + + +class TestFindingParsing: + + def test_parses_goal_and_document(self): + triples = _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], [ + (TG_SUBAGENT_GOAL, "What is X?"), + (TG_DOCUMENT, "urn:doc/finding"), + ]) + entity = Finding.from_triples("urn:f", triples) + assert 
entity.goal == "What is X?" + assert entity.document == "urn:doc/finding" + + def test_entity_type_field(self): + triples = _make_triples("urn:f", [TG_FINDING]) + entity = Finding.from_triples("urn:f", triples) + assert entity.entity_type == "finding" + + +class TestPlanParsing: + + def test_parses_steps(self): + triples = _make_triples("urn:p", [TG_PLAN_TYPE], [ + (TG_PLAN_STEP, "Define X"), + (TG_PLAN_STEP, "Research Y"), + (TG_PLAN_STEP, "Analyse Z"), + ]) + entity = Plan.from_triples("urn:p", triples) + assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"} + + def test_entity_type_field(self): + triples = _make_triples("urn:p", [TG_PLAN_TYPE]) + entity = Plan.from_triples("urn:p", triples) + assert entity.entity_type == "plan" + + +class TestStepResultParsing: + + def test_parses_step_and_document(self): + triples = _make_triples("urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], [ + (TG_PLAN_STEP, "Define X"), + (TG_DOCUMENT, "urn:doc/step"), + ]) + entity = StepResult.from_triples("urn:sr", triples) + assert entity.step == "Define X" + assert entity.document == "urn:doc/step" + + def test_entity_type_field(self): + triples = _make_triples("urn:sr", [TG_STEP_RESULT]) + entity = StepResult.from_triples("urn:sr", triples) + assert entity.entity_type == "step-result" diff --git a/tests/unit/test_agent/test_meta_router.py b/tests/unit/test_agent/test_meta_router.py new file mode 100644 index 00000000..da0c634c --- /dev/null +++ b/tests/unit/test_agent/test_meta_router.py @@ -0,0 +1,289 @@ +""" +Unit tests for the MetaRouter — task type identification and pattern selection. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.agent.orchestrator.meta_router import ( + MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE, +) + + +def _make_config(patterns=None, task_types=None): + """Build a config dict as the config service would provide.""" + config = {} + if patterns: + config["agent-pattern"] = { + pid: json.dumps(pdata) for pid, pdata in patterns.items() + } + if task_types: + config["agent-task-type"] = { + tid: json.dumps(tdata) for tid, tdata in task_types.items() + } + return config + + +def _make_context(prompt_response): + """Build a mock context that returns a mock prompt client.""" + client = AsyncMock() + client.prompt = AsyncMock(return_value=prompt_response) + + def context(service_name): + return client + + return context + + +SAMPLE_PATTERNS = { + "react": {"name": "react", "description": "ReAct pattern"}, + "plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"}, + "supervisor": {"name": "supervisor", "description": "Supervisor pattern"}, +} + +SAMPLE_TASK_TYPES = { + "general": { + "name": "general", + "description": "General queries", + "valid_patterns": ["react", "plan-then-execute", "supervisor"], + "framing": "", + }, + "research": { + "name": "research", + "description": "Research queries", + "valid_patterns": ["react", "plan-then-execute"], + "framing": "Focus on gathering information.", + }, + "summarisation": { + "name": "summarisation", + "description": "Summarisation queries", + "valid_patterns": ["react"], + "framing": "Focus on concise synthesis.", + }, +} + + +class TestMetaRouterInit: + + def test_defaults_when_no_config(self): + router = MetaRouter() + assert "react" in router.patterns + assert "general" in router.task_types + + def test_loads_patterns_from_config(self): + config = _make_config(patterns=SAMPLE_PATTERNS) + router = MetaRouter(config=config) + assert 
set(router.patterns.keys()) == {"react", "plan-then-execute", "supervisor"} + + def test_loads_task_types_from_config(self): + config = _make_config(task_types=SAMPLE_TASK_TYPES) + router = MetaRouter(config=config) + assert set(router.task_types.keys()) == {"general", "research", "summarisation"} + + def test_handles_invalid_json_in_config(self): + config = { + "agent-pattern": {"react": "not valid json"}, + } + router = MetaRouter(config=config) + assert "react" in router.patterns + assert router.patterns["react"]["name"] == "react" + + +class TestIdentifyTaskType: + + @pytest.mark.asyncio + async def test_skips_llm_when_single_task_type(self): + router = MetaRouter() # Only "general" + context = _make_context("should not be called") + + task_type, framing = await router.identify_task_type( + "test question", context, + ) + + assert task_type == "general" + + @pytest.mark.asyncio + async def test_uses_llm_when_multiple_task_types(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context("research") + + task_type, framing = await router.identify_task_type( + "Research the topic", context, + ) + + assert task_type == "research" + assert framing == "Focus on gathering information." + + @pytest.mark.asyncio + async def test_handles_llm_returning_quoted_type(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context('"summarisation"') + + task_type, _ = await router.identify_task_type( + "Summarise this", context, + ) + + assert task_type == "summarisation" + + @pytest.mark.asyncio + async def test_falls_back_on_unknown_type(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context("nonexistent-type") + + task_type, _ = await router.identify_task_type( + "test question", context, + ) + + assert task_type == DEFAULT_TASK_TYPE + + @pytest.mark.asyncio + async def test_falls_back_on_llm_error(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + + client = AsyncMock() + client.prompt = AsyncMock(side_effect=RuntimeError("LLM down")) + context = lambda name: client + + task_type, _ = await router.identify_task_type( + "test question", context, + ) + + assert task_type == DEFAULT_TASK_TYPE + + +class TestSelectPattern: + + @pytest.mark.asyncio + async def test_skips_llm_when_single_valid_pattern(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context("should not be called") + + # summarisation only has ["react"] + pattern = await router.select_pattern( + "Summarise this", "summarisation", context, + ) + + assert pattern == "react" + + @pytest.mark.asyncio + async def test_uses_llm_when_multiple_valid_patterns(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context("plan-then-execute") + + # research has ["react", "plan-then-execute"] + pattern = await router.select_pattern( + "Research this", "research", context, + ) + + assert pattern == "plan-then-execute" + + @pytest.mark.asyncio + async def test_respects_valid_patterns_constraint(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + 
task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + # LLM returns supervisor, but research doesn't allow it + context = _make_context("supervisor") + + pattern = await router.select_pattern( + "Research this", "research", context, + ) + + # Should fall back to first valid pattern + assert pattern == "react" + + @pytest.mark.asyncio + async def test_falls_back_on_llm_error(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + + client = AsyncMock() + client.prompt = AsyncMock(side_effect=RuntimeError("LLM down")) + context = lambda name: client + + # general has ["react", "plan-then-execute", "supervisor"] + pattern = await router.select_pattern( + "test", "general", context, + ) + + # Falls back to first valid pattern + assert pattern == "react" + + @pytest.mark.asyncio + async def test_falls_back_to_default_for_unknown_task_type(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + context = _make_context("react") + + # Unknown task type — valid_patterns falls back to all patterns + pattern = await router.select_pattern( + "test", "unknown-type", context, + ) + + assert pattern == "react" + + +class TestRoute: + + @pytest.mark.asyncio + async def test_full_routing_pipeline(self): + config = _make_config( + patterns=SAMPLE_PATTERNS, + task_types=SAMPLE_TASK_TYPES, + ) + router = MetaRouter(config=config) + + # Mock context where prompt returns different values per call + client = AsyncMock() + call_count = 0 + + async def mock_prompt(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return "research" # task type + return "plan-then-execute" # pattern + + client.prompt = mock_prompt + context = lambda name: client + + pattern, task_type, framing = await router.route( + "Research the relationships", context, + ) + + assert task_type == "research" + assert pattern == "plan-then-execute" + assert framing == "Focus on gathering information." diff --git a/tests/unit/test_agent/test_pattern_base_subagent.py b/tests/unit/test_agent/test_pattern_base_subagent.py new file mode 100644 index 00000000..1523b592 --- /dev/null +++ b/tests/unit/test_agent/test_pattern_base_subagent.py @@ -0,0 +1,144 @@ +""" +Unit tests for PatternBase subagent helpers — is_subagent() and +emit_subagent_completion(). 
+""" + +import pytest +from unittest.mock import MagicMock, AsyncMock +from dataclasses import dataclass + +from trustgraph.schema import AgentRequest + +from trustgraph.agent.orchestrator.pattern_base import PatternBase + + +@dataclass +class MockProcessor: + """Minimal processor mock for PatternBase.""" + pass + + +def _make_request(**kwargs): + defaults = dict( + question="Test question", + user="testuser", + collection="default", + ) + defaults.update(kwargs) + return AgentRequest(**defaults) + + +def _make_pattern(): + return PatternBase(MockProcessor()) + + +class TestIsSubagent: + + def test_returns_true_when_correlation_id_set(self): + pattern = _make_pattern() + request = _make_request(correlation_id="corr-123") + assert pattern.is_subagent(request) is True + + def test_returns_false_when_correlation_id_empty(self): + pattern = _make_pattern() + request = _make_request(correlation_id="") + assert pattern.is_subagent(request) is False + + def test_returns_false_when_correlation_id_missing(self): + pattern = _make_pattern() + request = _make_request() + assert pattern.is_subagent(request) is False + + +class TestEmitSubagentCompletion: + + @pytest.mark.asyncio + async def test_calls_next_with_completion_request(self): + pattern = _make_pattern() + request = _make_request( + correlation_id="corr-123", + parent_session_id="parent-sess", + subagent_goal="What is X?", + expected_siblings=4, + ) + next_fn = AsyncMock() + + await pattern.emit_subagent_completion( + request, next_fn, "The answer is Y", + ) + + next_fn.assert_called_once() + completion_req = next_fn.call_args[0][0] + assert isinstance(completion_req, AgentRequest) + + @pytest.mark.asyncio + async def test_completion_has_correct_step_type(self): + pattern = _make_pattern() + request = _make_request( + correlation_id="corr-123", + subagent_goal="What is X?", + ) + next_fn = AsyncMock() + + await pattern.emit_subagent_completion( + request, next_fn, "answer text", + ) + + completion_req = next_fn.call_args[0][0] + assert len(completion_req.history) == 1 + step = completion_req.history[0] + assert step.step_type == "subagent-completion" + + @pytest.mark.asyncio + async def test_completion_carries_answer_in_observation(self): + pattern = _make_pattern() + request = _make_request( + correlation_id="corr-123", + subagent_goal="What is X?", + ) + next_fn = AsyncMock() + + await pattern.emit_subagent_completion( + request, next_fn, "The answer is Y", + ) + + completion_req = next_fn.call_args[0][0] + step = completion_req.history[0] + assert step.observation == "The answer is Y" + + @pytest.mark.asyncio + async def test_completion_preserves_correlation_fields(self): + pattern = _make_pattern() + request = _make_request( + correlation_id="corr-123", + parent_session_id="parent-sess", + subagent_goal="What is X?", + expected_siblings=4, + ) + next_fn = AsyncMock() + + await pattern.emit_subagent_completion( + request, next_fn, "answer", + ) + + completion_req = next_fn.call_args[0][0] + assert completion_req.correlation_id == "corr-123" + assert completion_req.parent_session_id == "parent-sess" + assert completion_req.subagent_goal == "What is X?" 
+ assert completion_req.expected_siblings == 4 + + @pytest.mark.asyncio + async def test_completion_has_empty_pattern(self): + pattern = _make_pattern() + request = _make_request( + correlation_id="corr-123", + subagent_goal="goal", + ) + next_fn = AsyncMock() + + await pattern.emit_subagent_completion( + request, next_fn, "answer", + ) + + completion_req = next_fn.call_args[0][0] + assert completion_req.pattern == "" diff --git a/tests/unit/test_agent/test_provenance_triples.py b/tests/unit/test_agent/test_provenance_triples.py new file mode 100644 index 00000000..ed14d6ae --- /dev/null +++ b/tests/unit/test_agent/test_provenance_triples.py @@ -0,0 +1,226 @@ +""" +Unit tests for orchestrator provenance triple builders. +""" + +import pytest + +from trustgraph.provenance import ( + agent_decomposition_triples, + agent_finding_triples, + agent_plan_triples, + agent_step_result_triples, + agent_synthesis_triples, +) + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, RDFS_LABEL, + PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, + TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT, + TG_SUBAGENT_GOAL, TG_PLAN_STEP, +) + + +def _triple_set(triples): + """Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion.""" + result = set() + for t in triples: + s = t.s.iri + p = t.p.iri + o = t.o.iri if t.o.iri else t.o.value + result.add((s, p, o)) + return result + + +def _has_type(triples, uri, rdf_type): + """Check if a URI has a given rdf:type in the triples.""" + return (uri, RDF_TYPE, rdf_type) in _triple_set(triples) + + +def _get_values(triples, uri, predicate): + """Get all object values for a given subject + predicate.""" + ts = _triple_set(triples) + return [o for s, p, o in ts if s == uri and p == predicate] + + +class TestDecompositionTriples: + + def test_has_correct_types(self): + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", ["goal-a", "goal-b"], + ) + assert _has_type(triples, "urn:decompose", PROV_ENTITY) + assert _has_type(triples, "urn:decompose", TG_DECOMPOSITION) + + def test_not_answer_type(self): + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", ["goal-a"], + ) + assert not _has_type(triples, "urn:decompose", TG_ANSWER_TYPE) + + def test_links_to_session(self): + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", ["goal-a"], + ) + ts = _triple_set(triples) + assert ("urn:decompose", PROV_WAS_GENERATED_BY, "urn:session") in ts + + def test_includes_goals(self): + goals = ["What is X?", "What is Y?", "What is Z?"] + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", goals, + ) + values = _get_values(triples, "urn:decompose", TG_SUBAGENT_GOAL) + assert set(values) == set(goals) + + def test_label_includes_count(self): + triples = agent_decomposition_triples( + "urn:decompose", "urn:session", ["a", "b", "c"], + ) + labels = _get_values(triples, "urn:decompose", RDFS_LABEL) + assert any("3" in label for label in labels) + + +class TestFindingTriples: + + def test_has_correct_types(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "What is X?", + ) + assert _has_type(triples, "urn:finding", PROV_ENTITY) + assert _has_type(triples, "urn:finding", TG_FINDING) + assert _has_type(triples, "urn:finding", TG_ANSWER_TYPE) + + def test_links_to_decomposition(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "What is X?", + ) + ts = _triple_set(triples) + 
assert ("urn:finding", PROV_WAS_DERIVED_FROM, "urn:decompose") in ts + + def test_includes_goal(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "What is X?", + ) + values = _get_values(triples, "urn:finding", TG_SUBAGENT_GOAL) + assert "What is X?" in values + + def test_includes_document_when_provided(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "goal", + document_id="urn:doc/1", + ) + values = _get_values(triples, "urn:finding", TG_DOCUMENT) + assert "urn:doc/1" in values + + def test_no_document_when_none(self): + triples = agent_finding_triples( + "urn:finding", "urn:decompose", "goal", + ) + values = _get_values(triples, "urn:finding", TG_DOCUMENT) + assert values == [] + + +class TestPlanTriples: + + def test_has_correct_types(self): + triples = agent_plan_triples( + "urn:plan", "urn:session", ["step-a"], + ) + assert _has_type(triples, "urn:plan", PROV_ENTITY) + assert _has_type(triples, "urn:plan", TG_PLAN_TYPE) + + def test_not_answer_type(self): + triples = agent_plan_triples( + "urn:plan", "urn:session", ["step-a"], + ) + assert not _has_type(triples, "urn:plan", TG_ANSWER_TYPE) + + def test_links_to_session(self): + triples = agent_plan_triples( + "urn:plan", "urn:session", ["step-a"], + ) + ts = _triple_set(triples) + assert ("urn:plan", PROV_WAS_GENERATED_BY, "urn:session") in ts + + def test_includes_steps(self): + steps = ["Define X", "Research Y", "Analyse Z"] + triples = agent_plan_triples( + "urn:plan", "urn:session", steps, + ) + values = _get_values(triples, "urn:plan", TG_PLAN_STEP) + assert set(values) == set(steps) + + def test_label_includes_count(self): + triples = agent_plan_triples( + "urn:plan", "urn:session", ["a", "b"], + ) + labels = _get_values(triples, "urn:plan", RDFS_LABEL) + assert any("2" in label for label in labels) + + +class TestStepResultTriples: + + def test_has_correct_types(self): + triples = agent_step_result_triples( + "urn:step", "urn:plan", "Define X", + ) + assert _has_type(triples, "urn:step", PROV_ENTITY) + assert _has_type(triples, "urn:step", TG_STEP_RESULT) + assert _has_type(triples, "urn:step", TG_ANSWER_TYPE) + + def test_links_to_plan(self): + triples = agent_step_result_triples( + "urn:step", "urn:plan", "Define X", + ) + ts = _triple_set(triples) + assert ("urn:step", PROV_WAS_DERIVED_FROM, "urn:plan") in ts + + def test_includes_goal(self): + triples = agent_step_result_triples( + "urn:step", "urn:plan", "Define X", + ) + values = _get_values(triples, "urn:step", TG_PLAN_STEP) + assert "Define X" in values + + def test_includes_document_when_provided(self): + triples = agent_step_result_triples( + "urn:step", "urn:plan", "goal", + document_id="urn:doc/step", + ) + values = _get_values(triples, "urn:step", TG_DOCUMENT) + assert "urn:doc/step" in values + + +class TestSynthesisTriples: + + def test_has_correct_types(self): + triples = agent_synthesis_triples( + "urn:synthesis", "urn:previous", + ) + assert _has_type(triples, "urn:synthesis", PROV_ENTITY) + assert _has_type(triples, "urn:synthesis", TG_SYNTHESIS) + assert _has_type(triples, "urn:synthesis", TG_ANSWER_TYPE) + + def test_links_to_previous(self): + triples = agent_synthesis_triples( + "urn:synthesis", "urn:last-finding", + ) + ts = _triple_set(triples) + assert ("urn:synthesis", PROV_WAS_DERIVED_FROM, + "urn:last-finding") in ts + + def test_includes_document_when_provided(self): + triples = agent_synthesis_triples( + "urn:synthesis", "urn:previous", + document_id="urn:doc/synthesis", + ) + values = 
_get_values(triples, "urn:synthesis", TG_DOCUMENT) + assert "urn:doc/synthesis" in values + + def test_label_is_synthesis(self): + triples = agent_synthesis_triples( + "urn:synthesis", "urn:previous", + ) + labels = _get_values(triples, "urn:synthesis", RDFS_LABEL) + assert "Synthesis" in labels From 89e13a756a0a823de374d981914083dcfc7402a4 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 31 Mar 2026 13:29:04 +0100 Subject: [PATCH 20/37] Minor agent-orchestrator updates (#746) Tidy agent-orchestrator logs Added CLI support for selecting the pattern... tg-invoke-agent -q "What is the document about?" -p supervisor -v tg-invoke-agent -q "What is the document about?" -p plan-then-execute -v tg-invoke-agent -q "What is the document about?" -p react -v Added new event types to tg-show-explain-trace --- .../trustgraph/api/explainability.py | 8 - trustgraph-cli/trustgraph/cli/invoke_agent.py | 12 +- .../trustgraph/cli/show_explain_trace.py | 182 +++++++++++++----- .../agent/orchestrator/aggregator.py | 4 +- .../agent/orchestrator/pattern_base.py | 2 +- .../agent/orchestrator/react_pattern.py | 4 - .../trustgraph/agent/orchestrator/service.py | 3 +- .../trustgraph/agent/react/service.py | 2 - .../trustgraph/agent/tool_filter.py | 10 +- 9 files changed, 152 insertions(+), 75 deletions(-) diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 7b406a59..ee7fd05e 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -999,8 +999,6 @@ class ExplainabilityClient: trace = { "question": None, "steps": [], - "iterations": [], # Backwards compatibility for ReAct - "conclusion": None, } # Fetch question/session @@ -1015,11 +1013,6 @@ class ExplainabilityClient: is_first=True, max_depth=50, ) - # Backwards compat: populate iterations from steps - trace["iterations"] = [ - s for s in trace["steps"] if isinstance(s, Analysis) - ] - return trace def _follow_provenance_chain( @@ -1081,7 +1074,6 @@ class ExplainabilityClient: elif isinstance(entity, (Conclusion, Synthesis)): trace["steps"].append(entity) - trace["conclusion"] = entity def list_sessions( self, diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 2a1ba7c2..c82c78f6 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -267,7 +267,8 @@ def question_explainable( def question( url, question, flow_id, user, collection, - plan=None, state=None, group=None, verbose=False, streaming=True, + plan=None, state=None, group=None, pattern=None, + verbose=False, streaming=True, token=None, explainable=False, debug=False ): # Explainable mode uses the API to capture and process provenance events @@ -307,6 +308,8 @@ def question( request_params["state"] = state if group is not None: request_params["group"] = group + if pattern is not None: + request_params["pattern"] = pattern try: # Call agent @@ -430,6 +433,12 @@ def main(): help=f'Agent plan (default: unspecified)' ) + parser.add_argument( + '-p', '--pattern', + choices=['react', 'plan-then-execute', 'supervisor'], + help='Force execution pattern (default: auto-selected by meta-router)' + ) + parser.add_argument( '-s', '--state', help=f'Agent initial state (default: unspecified)' @@ -478,6 +487,7 @@ def main(): plan = args.plan, state = args.state, group = args.group, + pattern = args.pattern, verbose = args.verbose, streaming = not args.no_streaming, token = 
args.token, diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index a38476cb..c4da0d5a 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -27,6 +27,10 @@ from trustgraph.api import ( Synthesis, Analysis, Conclusion, + Decomposition, + Finding, + Plan, + StepResult, ) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') @@ -297,6 +301,23 @@ def print_docrag_text(trace, explain_client, api, user): print("No synthesis data found") +def _print_document_content(explain_client, api, user, document_uri, label="Answer"): + """Fetch and print document content, or fall back to URI.""" + if not document_uri: + return + content = "" + if api: + content = explain_client.fetch_document_content( + document_uri, api, user + ) + if content: + print(f"{label}:") + for line in content.split("\n"): + print(f" {line}") + else: + print(f"Document: {document_uri}") + + def print_agent_text(trace, explain_client, api, user): """Print Agent trace in text format.""" question = trace.get("question") @@ -310,82 +331,143 @@ def print_agent_text(trace, explain_client, api, user): print(f"Time: {question.timestamp}") print() - # Analysis steps - print("--- Analysis ---") - iterations = trace.get("iterations", []) - if iterations: - for i, analysis in enumerate(iterations, 1): - print(f"Analysis {i}:") - print(f" Thought: {analysis.thought or 'N/A'}") - print(f" Action: {analysis.action or 'N/A'}") + # Walk the steps list which contains all entity types + steps = trace.get("steps", []) + for step in steps: - if analysis.arguments: - # Try to pretty-print JSON arguments + if isinstance(step, Decomposition): + print("--- Decomposition ---") + print(f"Decomposed into {len(step.goals)} research threads:") + for i, goal in enumerate(step.goals): + print(f" {i}: {goal}") + print() + + elif isinstance(step, Finding): + print("--- Finding ---") + print(f"Goal: {step.goal}") + _print_document_content( + explain_client, api, user, step.document, "Result", + ) + print() + + elif isinstance(step, Plan): + print("--- Plan ---") + print(f"Plan with {len(step.steps)} steps:") + for i, s in enumerate(step.steps): + print(f" {i}: {s}") + print() + + elif isinstance(step, StepResult): + print("--- Step Result ---") + print(f"Step: {step.step}") + _print_document_content( + explain_client, api, user, step.document, "Result", + ) + print() + + elif isinstance(step, Analysis): + print("--- Analysis ---") + print(f" Action: {step.action or 'N/A'}") + + if step.arguments: try: - args_obj = json.loads(analysis.arguments) + args_obj = json.loads(step.arguments) args_str = json.dumps(args_obj, indent=4) print(f" Arguments:") for line in args_str.split('\n'): print(f" {line}") except Exception: - print(f" Arguments: {analysis.arguments}") - else: - print(f" Arguments: N/A") + print(f" Arguments: {step.arguments}") - obs = analysis.observation or 'N/A' + obs = step.observation or 'N/A' if obs and len(obs) > 200: obs = obs[:200] + "... 
[truncated]" print(f" Observation: {obs}") print() - else: - print("No analysis steps recorded") - print() - # Conclusion - print("--- Conclusion ---") - conclusion = trace.get("conclusion") - if conclusion: - content = "" - if conclusion.document and api: - content = explain_client.fetch_document_content( - conclusion.document, api, user + elif isinstance(step, Synthesis): + print("--- Synthesis ---") + _print_document_content( + explain_client, api, user, step.document, "Answer", ) - if content: - print("Answer:") - for line in content.split("\n"): - print(f" {line}") - elif conclusion.document: - print(f"Document: {conclusion.document}") - else: - print("No conclusion recorded") - else: - print("No conclusion recorded") + print() + + elif isinstance(step, Conclusion): + print("--- Conclusion ---") + _print_document_content( + explain_client, api, user, step.document, "Answer", + ) + print() + + if not steps: + print("No trace steps recorded") + print() def trace_to_dict(trace, trace_type): """Convert trace entities to JSON-serializable dict.""" if trace_type == "agent": question = trace.get("question") + + def _step_to_dict(step): + if isinstance(step, Decomposition): + return { + "type": "decomposition", + "id": step.uri, + "goals": step.goals, + } + elif isinstance(step, Finding): + return { + "type": "finding", + "id": step.uri, + "goal": step.goal, + "document": step.document, + } + elif isinstance(step, Plan): + return { + "type": "plan", + "id": step.uri, + "steps": step.steps, + } + elif isinstance(step, StepResult): + return { + "type": "step-result", + "id": step.uri, + "step": step.step, + "document": step.document, + } + elif isinstance(step, Analysis): + return { + "type": "analysis", + "id": step.uri, + "action": step.action, + "arguments": step.arguments, + "thought": step.thought, + "observation": step.observation, + } + elif isinstance(step, Synthesis): + return { + "type": "synthesis", + "id": step.uri, + "document": step.document, + } + elif isinstance(step, Conclusion): + return { + "type": "conclusion", + "id": step.uri, + "document": step.document, + } + return {"type": step.entity_type, "id": step.uri} + + steps = trace.get("steps", []) + return { "type": "agent", "session_id": question.uri if question else None, "question": question.query if question else None, "time": question.timestamp if question else None, - "iterations": [ - { - "id": a.uri, - "thought": a.thought, - "action": a.action, - "arguments": a.arguments, - "observation": a.observation, - } - for a in trace.get("iterations", []) - ], - "conclusion": { - "id": trace["conclusion"].uri, - "document": trace["conclusion"].document, - } if trace.get("conclusion") else None, + "steps": [_step_to_dict(s) for s in steps], } elif trace_type == "docrag": question = trace.get("question") diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py index 9187f21e..cc5eb85c 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/aggregator.py @@ -57,7 +57,7 @@ class Aggregator: "request_template": request_template, "created_at": time.time(), } - logger.info( + logger.debug( f"Aggregator: registered fan-out {correlation_id}, " f"expecting {expected_siblings} subagents" ) @@ -82,7 +82,7 @@ class Aggregator: completed = len(entry["results"]) expected = entry["expected"] - logger.info( + logger.debug( f"Aggregator: {correlation_id} — " f"{completed}/{expected} subagents 
complete" ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index ddc4aed9..4faa7ce6 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -106,7 +106,7 @@ class PatternBase: ) await next(completion_request) - logger.info( + logger.debug( f"Subagent completion emitted for " f"correlation={request.correlation_id}, " f"goal={getattr(request, 'subagent_goal', '')}" diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index a03dc194..32261809 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -60,10 +60,6 @@ class ReactPattern(PatternBase): filtered_tools = self.filter_tools( self.processor.agent.tools, request, ) - logger.info( - f"Filtered from {len(self.processor.agent.tools)} " - f"to {len(filtered_tools)} available tools" - ) # Create temporary agent with filtered tools and optional framing additional_context = self.processor.agent.additional_context diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 9ca3fe59..ed4c3983 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -414,7 +414,6 @@ class Processor(AgentService): self.meta_router = MetaRouter(config=config) logger.info(f"Loaded {len(tools)} tools") - logger.info("Tool configuration reloaded.") except Exception as e: logger.error( @@ -436,7 +435,7 @@ class Processor(AgentService): answer_text = step.observation break - logger.info( + logger.debug( f"Received subagent completion: " f"correlation={correlation_id}, goal={subagent_goal}" ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 1bca9627..6c06f71a 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -550,8 +550,6 @@ class Processor(AgentService): current_state=getattr(request, 'state', None) ) - logger.info(f"Filtered from {len(self.agent.tools)} to {len(filtered_tools)} available tools") - # Create temporary agent with filtered tools temp_agent = AgentManager( tools=filtered_tools, diff --git a/trustgraph-flow/trustgraph/agent/tool_filter.py b/trustgraph-flow/trustgraph/agent/tool_filter.py index d1bac3e4..8a50bd41 100644 --- a/trustgraph-flow/trustgraph/agent/tool_filter.py +++ b/trustgraph-flow/trustgraph/agent/tool_filter.py @@ -34,17 +34,17 @@ def filter_tools_by_group_and_state( if current_state is None or current_state == "": current_state = "undefined" - logger.info(f"Filtering tools with groups={requested_groups}, state={current_state}") - + logger.debug(f"Filtering tools with groups={requested_groups}, state={current_state}") + filtered_tools = {} - + for tool_name, tool in tools.items(): if _is_tool_available(tool, requested_groups, current_state): filtered_tools[tool_name] = tool else: logger.debug(f"Tool {tool_name} filtered out") - - logger.info(f"Filtered {len(tools)} tools to {len(filtered_tools)} available tools") + + logger.debug(f"Filtered {len(tools)} tools to {len(filtered_tools)} available tools") return filtered_tools From 153ae9ad3001c8bfe74528f48b4c6581023282c9 Mon Sep 17 00:00:00 2001 From: 
cybermaggedon Date: Tue, 31 Mar 2026 17:51:22 +0100 Subject: [PATCH 21/37] Split Analysis into Analysis+ToolUse and Observation, add message_id (#747) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor agent provenance so that the decision (thought + tool selection) and the result (observation) are separate DAG entities: Question ← Analysis+ToolUse ← Observation ← ... ← Conclusion Analysis gains tg:ToolUse as a mixin RDF type and is emitted before tool execution via an on_action callback in react(). This ensures sub-traces (e.g. GraphRAG) appear after their parent Analysis in the streaming event order. Observation becomes a standalone prov:Entity with tg:Observation type, emitted after tool execution. The linear DAG chain runs through Observation — subsequent iterations and the Conclusion derive from it, not from the Analysis. message_id is populated on streaming AgentResponse for thought and observation chunks, using the provenance URI of the entity being built. This lets clients group streamed chunks by entity. Wire changes: - provenance/agent.py: Add ToolUse type, new agent_observation_triples(), remove observation from iteration - agent_manager.py: Add on_action callback between reason() and tool execution - orchestrator/pattern_base.py: Split emit, wire message_id, chain through observation URIs - orchestrator/react_pattern.py: Emit Analysis via on_action before tool runs - agent/react/service.py: Same for non-orchestrator path - api/explainability.py: New Observation class, updated dispatch and chain walker - api/types.py: Add message_id to AgentThought/AgentObservation - cli: Render Observation separately, [analysis: tool] labels --- .../test_agent_manager_integration.py | 16 +- .../test_agent_service_non_streaming.py | 40 +++- .../test_agent/test_explainability_parsing.py | 7 + .../test_agent/test_provenance_triples.py | 6 +- .../test_provenance/test_agent_provenance.py | 132 +++++++------ .../test_provenance/test_explainability.py | 25 ++- tests/unit/test_provenance/test_triples.py | 12 +- trustgraph-base/trustgraph/api/__init__.py | 2 + .../trustgraph/api/explainability.py | 127 ++++++++----- .../trustgraph/api/socket_client.py | 6 +- trustgraph-base/trustgraph/api/types.py | 4 + .../trustgraph/base/graph_rag_client.py | 2 + .../trustgraph/provenance/__init__.py | 4 + .../trustgraph/provenance/agent.py | 80 +++++--- .../trustgraph/provenance/namespaces.py | 1 + .../trustgraph/provenance/triples.py | 38 +++- .../trustgraph/schema/services/retrieval.py | 1 + trustgraph-cli/trustgraph/cli/invoke_agent.py | 15 +- .../trustgraph/cli/show_explain_trace.py | 18 +- .../agent/orchestrator/pattern_base.py | 178 ++++++++++-------- .../agent/orchestrator/plan_pattern.py | 37 +++- .../agent/orchestrator/react_pattern.py | 40 +++- .../agent/orchestrator/supervisor_pattern.py | 10 +- .../trustgraph/agent/react/agent_manager.py | 7 +- .../trustgraph/agent/react/service.py | 163 +++++++++------- .../trustgraph/agent/react/tools.py | 32 +++- .../retrieval/graph_rag/graph_rag.py | 6 +- .../trustgraph/retrieval/graph_rag/rag.py | 2 + 28 files changed, 661 insertions(+), 350 deletions(-) diff --git a/tests/integration/test_agent_manager_integration.py b/tests/integration/test_agent_manager_integration.py index 5db95638..652894a2 100644 --- a/tests/integration/test_agent_manager_integration.py +++ b/tests/integration/test_agent_manager_integration.py @@ -9,7 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing. 
import pytest import json -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, ANY, patch from trustgraph.agent.react.agent_manager import AgentManager from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl @@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn # Verify tool was executed graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default") + graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default", explain_callback=ANY, parent_uri=ANY) @pytest.mark.asyncio async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context): @@ -272,7 +272,7 @@ Args: {{ # Verify correct service was called if tool_name == "knowledge_query": - mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default") + mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default", explain_callback=ANY, parent_uri=ANY) elif tool_name == "text_completion": mock_flow_context("prompt-request").question.assert_called() @@ -726,7 +726,7 @@ Final Answer: { # Assert graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default") + graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default", explain_callback=ANY, parent_uri=ANY) @pytest.mark.asyncio async def test_knowledge_query_with_custom_collection(self, mock_flow_context): @@ -739,7 +739,7 @@ Final Answer: { # Assert graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection") + graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection", explain_callback=ANY, parent_uri=ANY) @pytest.mark.asyncio async def test_knowledge_query_with_none_collection(self, mock_flow_context): @@ -752,7 +752,7 @@ Final Answer: { # Assert graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default") + graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default", explain_callback=ANY, parent_uri=ANY) @pytest.mark.asyncio async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context): @@ -810,7 +810,7 @@ Args: { # Verify the custom collection was used graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers") + graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers", explain_callback=ANY, parent_uri=ANY) @pytest.mark.asyncio async def test_knowledge_query_multiple_collections(self, mock_flow_context): @@ -840,4 +840,4 @@ Args: { # Verify correct collection was used graph_rag_client = mock_flow_context("graph-rag-request") - graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection) + graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection, explain_callback=ANY, parent_uri=ANY) diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index ff630325..0b9b283a 100644 --- 
a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -39,7 +39,7 @@ class TestAgentServiceNonStreaming: mock_agent_manager_class.return_value = mock_agent_instance # Mock react to call think and observe callbacks - async def mock_react(question, history, think, observe, answer, context, streaming): + async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): await think("I need to solve this.", is_final=True) await observe("The answer is 4.", is_final=True) return Final(thought="Final answer", final="4") @@ -76,11 +76,22 @@ class TestAgentServiceNonStreaming: # Execute await processor.on_request(msg, consumer, flow) - # Verify: should have 3 responses (thought, observation, answer) - assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}" + # Filter out explain events — those are always sent now + content_responses = [ + r for r in sent_responses if r.chunk_type != "explain" + ] + explain_responses = [ + r for r in sent_responses if r.chunk_type == "explain" + ] + + # Should have explain events for session, iteration, observation, and final + assert len(explain_responses) >= 1, "Expected at least 1 explain event" + + # Should have 3 content responses (thought, observation, answer) + assert len(content_responses) == 3, f"Expected 3 content responses, got {len(content_responses)}" # Check thought message - thought_response = sent_responses[0] + thought_response = content_responses[0] assert isinstance(thought_response, AgentResponse) assert thought_response.chunk_type == "thought" assert thought_response.content == "I need to solve this." @@ -88,7 +99,7 @@ class TestAgentServiceNonStreaming: assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False" # Check observation message - observation_response = sent_responses[1] + observation_response = content_responses[1] assert isinstance(observation_response, AgentResponse) assert observation_response.chunk_type == "observation" assert observation_response.content == "The answer is 4." 
@@ -120,7 +131,7 @@ class TestAgentServiceNonStreaming: mock_agent_manager_class.return_value = mock_agent_instance # Mock react to return Final directly - async def mock_react(question, history, think, observe, answer, context, streaming): + async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None): return Final(thought="Final answer", final="4") mock_agent_instance.react = mock_react @@ -155,11 +166,22 @@ class TestAgentServiceNonStreaming: # Execute await processor.on_request(msg, consumer, flow) - # Verify: should have 1 response (final answer) - assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}" + # Filter out explain events — those are always sent now + content_responses = [ + r for r in sent_responses if r.chunk_type != "explain" + ] + explain_responses = [ + r for r in sent_responses if r.chunk_type == "explain" + ] + + # Should have explain events for session and final + assert len(explain_responses) >= 1, "Expected at least 1 explain event" + + # Should have 1 content response (final answer) + assert len(content_responses) == 1, f"Expected 1 content response, got {len(content_responses)}" # Check final answer message - answer_response = sent_responses[0] + answer_response = content_responses[0] assert isinstance(answer_response, AgentResponse) assert answer_response.chunk_type == "answer" assert answer_response.content == "4" diff --git a/tests/unit/test_agent/test_explainability_parsing.py b/tests/unit/test_agent/test_explainability_parsing.py index e09a7f1f..7035318d 100644 --- a/tests/unit/test_agent/test_explainability_parsing.py +++ b/tests/unit/test_agent/test_explainability_parsing.py @@ -13,6 +13,7 @@ from trustgraph.api.explainability import ( StepResult, Synthesis, Analysis, + Observation, Conclusion, TG_DECOMPOSITION, TG_FINDING, @@ -20,6 +21,7 @@ from trustgraph.api.explainability import ( TG_STEP_RESULT, TG_SYNTHESIS, TG_ANSWER_TYPE, + TG_OBSERVATION_TYPE, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, @@ -74,6 +76,11 @@ class TestFromTriplesDispatch: entity = ExplainEntity.from_triples("urn:a", triples) assert isinstance(entity, Analysis) + def test_dispatches_observation(self): + triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE]) + entity = ExplainEntity.from_triples("urn:o", triples) + assert isinstance(entity, Observation) + def test_dispatches_conclusion_unchanged(self): triples = _make_triples("urn:c", [PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE]) diff --git a/tests/unit/test_agent/test_provenance_triples.py b/tests/unit/test_agent/test_provenance_triples.py index ed14d6ae..c83f4b08 100644 --- a/tests/unit/test_agent/test_provenance_triples.py +++ b/tests/unit/test_agent/test_provenance_triples.py @@ -14,7 +14,7 @@ from trustgraph.provenance import ( from trustgraph.provenance.namespaces import ( RDF_TYPE, RDFS_LABEL, - PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_ENTITY, PROV_WAS_DERIVED_FROM, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT, TG_SUBAGENT_GOAL, TG_PLAN_STEP, @@ -63,7 +63,7 @@ class TestDecompositionTriples: "urn:decompose", "urn:session", ["goal-a"], ) ts = _triple_set(triples) - assert ("urn:decompose", PROV_WAS_GENERATED_BY, "urn:session") in ts + assert ("urn:decompose", PROV_WAS_DERIVED_FROM, "urn:session") in ts def test_includes_goals(self): goals = ["What is X?", "What is Y?", "What is Z?"] @@ -141,7 +141,7 @@ class TestPlanTriples: "urn:plan", "urn:session", ["step-a"], ) ts = 
_triple_set(triples) - assert ("urn:plan", PROV_WAS_GENERATED_BY, "urn:session") in ts + assert ("urn:plan", PROV_WAS_DERIVED_FROM, "urn:session") in ts def test_includes_steps(self): steps = ["Define X", "Research Y", "Analyse Z"] diff --git a/tests/unit/test_provenance/test_agent_provenance.py b/tests/unit/test_provenance/test_agent_provenance.py index 4efe24c7..d3f0ef8c 100644 --- a/tests/unit/test_provenance/test_agent_provenance.py +++ b/tests/unit/test_provenance/test_agent_provenance.py @@ -10,16 +10,18 @@ from trustgraph.schema import Triple, Term, IRI, LITERAL from trustgraph.provenance.agent import ( agent_session_triples, agent_iteration_triples, + agent_observation_triples, agent_final_triples, ) from trustgraph.provenance.namespaces import ( RDF_TYPE, RDFS_LABEL, - PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, - PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, - TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + PROV_ENTITY, PROV_WAS_DERIVED_FROM, + PROV_STARTED_AT_TIME, + TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_TOOL_USE, TG_AGENT_QUESTION, ) @@ -63,7 +65,7 @@ class TestAgentSessionTriples: triples = agent_session_triples( self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z" ) - assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY) + assert has_type(triples, self.SESSION_URI, PROV_ENTITY) assert has_type(triples, self.SESSION_URI, TG_QUESTION) assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION) @@ -121,19 +123,17 @@ class TestAgentIterationTriples: ) assert has_type(triples, self.ITER_URI, PROV_ENTITY) assert has_type(triples, self.ITER_URI, TG_ANALYSIS) + assert has_type(triples, self.ITER_URI, TG_TOOL_USE) - def test_first_iteration_generated_by_question(self): - """First iteration uses wasGeneratedBy to link to question activity.""" + def test_first_iteration_derived_from_question(self): + """First iteration uses wasDerivedFrom to link to question entity.""" triples = agent_iteration_triples( self.ITER_URI, question_uri=self.SESSION_URI, action="search", ) - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI) - assert gen is not None - assert gen.o.iri == self.SESSION_URI - # Should NOT have wasDerivedFrom derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) - assert derived is None + assert derived is not None + assert derived.o.iri == self.SESSION_URI def test_subsequent_iteration_derived_from_previous(self): """Subsequent iterations use wasDerivedFrom to link to previous iteration.""" @@ -144,9 +144,6 @@ class TestAgentIterationTriples: derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) assert derived is not None assert derived.o.iri == self.PREV_URI - # Should NOT have wasGeneratedBy - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI) - assert gen is None def test_iteration_label_includes_action(self): triples = agent_iteration_triples( @@ -174,40 +171,24 @@ class TestAgentIterationTriples: # Thought has correct types assert has_type(triples, thought_uri, TG_REFLECTION_TYPE) assert has_type(triples, thought_uri, TG_THOUGHT_TYPE) - # Thought was generated by iteration - gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri) - assert gen is not None - assert gen.o.iri == self.ITER_URI + # Thought was derived from iteration + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, thought_uri) + assert derived is not None + assert 
derived.o.iri == self.ITER_URI # Thought has document reference doc = find_triple(triples, TG_DOCUMENT, thought_uri) assert doc is not None assert doc.o.iri == thought_doc - def test_iteration_observation_sub_entity(self): - """Observation is a sub-entity with Reflection and Observation types.""" - obs_uri = "urn:trustgraph:agent:test-session/i1/observation" - obs_doc = "urn:doc:obs-1" + def test_iteration_no_observation_sub_entity(self): + """Iteration no longer embeds observation — it's a separate entity.""" triples = agent_iteration_triples( self.ITER_URI, question_uri=self.SESSION_URI, action="search", - observation_uri=obs_uri, - observation_document_id=obs_doc, ) - # Iteration links to observation sub-entity - obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI) - assert obs_link is not None - assert obs_link.o.iri == obs_uri - # Observation has correct types - assert has_type(triples, obs_uri, TG_REFLECTION_TYPE) - assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE) - # Observation was generated by iteration - gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri) - assert gen is not None - assert gen.o.iri == self.ITER_URI - # Observation has document reference - doc = find_triple(triples, TG_DOCUMENT, obs_uri) - assert doc is not None - assert doc.o.iri == obs_doc + # No TG_OBSERVATION predicate on the iteration + for t in triples: + assert "observation" not in t.p.iri.lower() or "Observation" not in t.p.iri def test_iteration_action_recorded(self): triples = agent_iteration_triples( @@ -240,19 +221,17 @@ class TestAgentIterationTriples: parsed = json.loads(arguments.o.value) assert parsed == {} - def test_iteration_no_thought_or_observation(self): - """Minimal iteration with just action — no thought or observation triples.""" + def test_iteration_no_thought(self): + """Minimal iteration with just action — no thought triples.""" triples = agent_iteration_triples( self.ITER_URI, question_uri=self.SESSION_URI, action="noop", ) thought = find_triple(triples, TG_THOUGHT, self.ITER_URI) - obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI) assert thought is None - assert obs is None def test_iteration_chaining(self): - """First iteration uses wasGeneratedBy, second uses wasDerivedFrom.""" + """Both first and second iterations use wasDerivedFrom.""" iter1_uri = "urn:trustgraph:agent:sess/i1" iter2_uri = "urn:trustgraph:agent:sess/i2" @@ -263,13 +242,62 @@ class TestAgentIterationTriples: iter2_uri, previous_uri=iter1_uri, action="step2", ) - gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri) - assert gen1.o.iri == self.SESSION_URI + derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri) + assert derived1.o.iri == self.SESSION_URI derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri) assert derived2.o.iri == iter1_uri +# --------------------------------------------------------------------------- +# agent_observation_triples +# --------------------------------------------------------------------------- + +class TestAgentObservationTriples: + + OBS_URI = "urn:trustgraph:agent:test-session/i1/observation" + ITER_URI = "urn:trustgraph:agent:test-session/i1" + + def test_observation_types(self): + triples = agent_observation_triples( + self.OBS_URI, self.ITER_URI, + ) + assert has_type(triples, self.OBS_URI, PROV_ENTITY) + assert has_type(triples, self.OBS_URI, TG_OBSERVATION_TYPE) + + def test_observation_derived_from_iteration(self): + triples = agent_observation_triples( + self.OBS_URI, self.ITER_URI, + ) + derived = 
find_triple(triples, PROV_WAS_DERIVED_FROM, self.OBS_URI) + assert derived is not None + assert derived.o.iri == self.ITER_URI + + def test_observation_label(self): + triples = agent_observation_triples( + self.OBS_URI, self.ITER_URI, + ) + label = find_triple(triples, RDFS_LABEL, self.OBS_URI) + assert label is not None + assert label.o.value == "Observation" + + def test_observation_document(self): + doc_id = "urn:doc:obs-1" + triples = agent_observation_triples( + self.OBS_URI, self.ITER_URI, document_id=doc_id, + ) + doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI) + assert doc is not None + assert doc.o.iri == doc_id + + def test_observation_no_document(self): + triples = agent_observation_triples( + self.OBS_URI, self.ITER_URI, + ) + doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI) + assert doc is None + + # --------------------------------------------------------------------------- # agent_final_triples # --------------------------------------------------------------------------- @@ -296,19 +324,15 @@ class TestAgentFinalTriples: derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) assert derived is not None assert derived.o.iri == self.PREV_URI - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI) - assert gen is None - def test_final_generated_by_question_when_no_iterations(self): - """When agent answers immediately, final uses wasGeneratedBy.""" + def test_final_derived_from_question_when_no_iterations(self): + """When agent answers immediately, final uses wasDerivedFrom to question.""" triples = agent_final_triples( self.FINAL_URI, question_uri=self.SESSION_URI, ) - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI) - assert gen is not None - assert gen.o.iri == self.SESSION_URI derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) - assert derived is None + assert derived is not None + assert derived.o.iri == self.SESSION_URI def test_final_label(self): triples = agent_final_triples( diff --git a/tests/unit/test_provenance/test_explainability.py b/tests/unit/test_provenance/test_explainability.py index 62498c61..e2c7fcd1 100644 --- a/tests/unit/test_provenance/test_explainability.py +++ b/tests/unit/test_provenance/test_explainability.py @@ -16,6 +16,7 @@ from trustgraph.api.explainability import ( Synthesis, Reflection, Analysis, + Observation, Conclusion, parse_edge_selection_triples, extract_term_value, @@ -23,12 +24,12 @@ from trustgraph.api.explainability import ( ExplainabilityClient, TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY, - TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANALYSIS, TG_CONCLUSION, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, - PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, + PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, RDF_TYPE, RDFS_LABEL, ) @@ -180,14 +181,30 @@ class TestExplainEntityFromTriples: ("urn:ana:1", TG_ACTION, "graph-rag-query"), ("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'), ("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"), - ("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"), ] entity = ExplainEntity.from_triples("urn:ana:1", triples) assert isinstance(entity, Analysis) assert entity.action == "graph-rag-query" assert entity.arguments == '{"query": "test"}' assert entity.thought == 
"urn:ref:thought-1" - assert entity.observation == "urn:ref:obs-1" + + def test_observation(self): + triples = [ + ("urn:obs:1", RDF_TYPE, TG_OBSERVATION_TYPE), + ("urn:obs:1", TG_DOCUMENT, "urn:doc:obs-content"), + ] + entity = ExplainEntity.from_triples("urn:obs:1", triples) + assert isinstance(entity, Observation) + assert entity.document == "urn:doc:obs-content" + assert entity.entity_type == "observation" + + def test_observation_no_document(self): + triples = [ + ("urn:obs:2", RDF_TYPE, TG_OBSERVATION_TYPE), + ] + entity = ExplainEntity.from_triples("urn:obs:2", triples) + assert isinstance(entity, Observation) + assert entity.document == "" def test_conclusion_with_document(self): triples = [ diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py index 9aff7e4b..792db028 100644 --- a/tests/unit/test_provenance/test_triples.py +++ b/tests/unit/test_provenance/test_triples.py @@ -500,7 +500,7 @@ class TestQuestionTriples: def test_question_types(self): triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z") - assert has_type(triples, self.Q_URI, PROV_ACTIVITY) + assert has_type(triples, self.Q_URI, PROV_ENTITY) assert has_type(triples, self.Q_URI, TG_QUESTION) assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION) @@ -543,11 +543,11 @@ class TestGroundingTriples: assert has_type(triples, self.GND_URI, PROV_ENTITY) assert has_type(triples, self.GND_URI, TG_GROUNDING) - def test_grounding_generated_by_question(self): + def test_grounding_derived_from_question(self): triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"]) - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI) - assert gen is not None - assert gen.o.iri == self.Q_URI + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.GND_URI) + assert derived is not None + assert derived.o.iri == self.Q_URI def test_grounding_concepts(self): triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"]) @@ -730,7 +730,7 @@ class TestDocRagQuestionTriples: def test_docrag_question_types(self): triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z") - assert has_type(triples, self.Q_URI, PROV_ACTIVITY) + assert has_type(triples, self.Q_URI, PROV_ENTITY) assert has_type(triples, self.Q_URI, TG_QUESTION) assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION) diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index e956db65..8b703dc7 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -81,6 +81,7 @@ from .explainability import ( Synthesis, Reflection, Analysis, + Observation, Conclusion, Decomposition, Finding, @@ -164,6 +165,7 @@ __all__ = [ "Focus", "Synthesis", "Analysis", + "Observation", "Conclusion", "EdgeSelection", "wire_triples_to_tuples", diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index ee7fd05e..fa6c4a0c 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -40,6 +40,7 @@ TG_ANSWER_TYPE = TG + "Answer" TG_REFLECTION_TYPE = TG + "Reflection" TG_THOUGHT_TYPE = TG + "Thought" TG_OBSERVATION_TYPE = TG + "Observation" +TG_TOOL_USE = TG + "ToolUse" TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" TG_AGENT_QUESTION = TG + "AgentQuestion" @@ -58,7 +59,6 @@ TG_PLAN_STEP = TG + "planStep" PROV = 
"http://www.w3.org/ns/prov#" PROV_STARTED_AT_TIME = PROV + "startedAtTime" PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy" RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type" RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" @@ -102,6 +102,8 @@ class ExplainEntity: return StepResult.from_triples(uri, triples) elif TG_SYNTHESIS in types: return Synthesis.from_triples(uri, triples) + elif TG_OBSERVATION_TYPE in types and TG_REFLECTION_TYPE not in types: + return Observation.from_triples(uri, triples) elif TG_REFLECTION_TYPE in types: return Reflection.from_triples(uri, triples) elif TG_ANALYSIS in types: @@ -279,18 +281,16 @@ class Reflection(ExplainEntity): @dataclass class Analysis(ExplainEntity): - """Analysis entity - one think/act/observe cycle (Agent only).""" + """Analysis+ToolUse entity - decision + tool call (Agent only).""" action: str = "" arguments: str = "" # JSON string thought: str = "" - observation: str = "" @classmethod def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis": action = "" arguments = "" thought = "" - observation = "" for s, p, o in triples: if p == TG_ACTION: @@ -299,8 +299,6 @@ class Analysis(ExplainEntity): arguments = o elif p == TG_THOUGHT: thought = o - elif p == TG_OBSERVATION: - observation = o return cls( uri=uri, @@ -308,7 +306,26 @@ class Analysis(ExplainEntity): action=action, arguments=arguments, thought=thought, - observation=observation + ) + + +@dataclass +class Observation(ExplainEntity): + """Observation entity - standalone tool result (Agent only).""" + document: str = "" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Observation": + document = "" + + for s, p, o in triples: + if p == TG_DOCUMENT: + document = o + + return cls( + uri=uri, + entity_type="observation", + document=document, ) @@ -757,9 +774,9 @@ class ExplainabilityClient: return trace trace["question"] = question - # Find grounding: ?grounding prov:wasGeneratedBy question_uri + # Find grounding: ?grounding prov:wasDerivedFrom question_uri grounding_triples = self.flow.triples_query( - p=PROV_WAS_GENERATED_BY, + p=PROV_WAS_DERIVED_FROM, o=question_uri, g=graph, user=user, @@ -894,9 +911,9 @@ class ExplainabilityClient: return trace trace["question"] = question - # Find grounding: ?grounding prov:wasGeneratedBy question_uri + # Find grounding: ?grounding prov:wasDerivedFrom question_uri grounding_triples = self.flow.triples_query( - p=PROV_WAS_GENERATED_BY, + p=PROV_WAS_DERIVED_FROM, o=question_uri, g=graph, user=user, @@ -1010,41 +1027,26 @@ class ExplainabilityClient: # Follow the provenance chain from the question self._follow_provenance_chain( session_uri, trace, graph, user, collection, - is_first=True, max_depth=50, + max_depth=50, ) return trace def _follow_provenance_chain( self, current_uri, trace, graph, user, collection, - is_first=False, max_depth=50, + max_depth=50, ): """Recursively follow the provenance chain, handling branches.""" if max_depth <= 0: return # Find entities derived from current_uri - if is_first: - derived_triples = self.flow.triples_query( - p=PROV_WAS_GENERATED_BY, - o=current_uri, - g=graph, user=user, collection=collection, - limit=20 - ) - if not derived_triples: - derived_triples = self.flow.triples_query( - p=PROV_WAS_DERIVED_FROM, - o=current_uri, - g=graph, user=user, collection=collection, - limit=20 - ) - else: - derived_triples = self.flow.triples_query( - p=PROV_WAS_DERIVED_FROM, - o=current_uri, - 
g=graph, user=user, collection=collection, - limit=20 - ) + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, user=user, collection=collection, + limit=20 + ) if not derived_triples: return @@ -1062,8 +1064,8 @@ class ExplainabilityClient: if entity is None: continue - if isinstance(entity, (Analysis, Decomposition, Finding, - Plan, StepResult)): + if isinstance(entity, (Analysis, Observation, Decomposition, + Finding, Plan, StepResult)): trace["steps"].append(entity) # Continue following from this entity @@ -1072,6 +1074,27 @@ class ExplainabilityClient: max_depth=max_depth - 1, ) + elif isinstance(entity, Question): + # Sub-trace: a RAG session linked to this agent step. + # Fetch the full sub-trace and embed it. + if entity.question_type == "graph-rag": + sub_trace = self.fetch_graphrag_trace( + derived_uri, graph, user, collection, + ) + elif entity.question_type == "document-rag": + sub_trace = self.fetch_docrag_trace( + derived_uri, graph, user, collection, + ) + else: + sub_trace = None + + if sub_trace: + trace["steps"].append({ + "type": "sub-trace", + "question": entity, + "trace": sub_trace, + }) + elif isinstance(entity, (Conclusion, Synthesis)): trace["steps"].append(entity) @@ -1114,10 +1137,25 @@ class ExplainabilityClient: if isinstance(entity, Question): questions.append(entity) - # Sort by timestamp (newest first) - questions.sort(key=lambda q: q.timestamp or "", reverse=True) + # Filter out sub-traces: sessions that have a wasDerivedFrom link + # (they are child sessions linked to a parent agent iteration) + top_level = [] + for q in questions: + parent_triples = self.flow.triples_query( + s=q.uri, + p=PROV_WAS_DERIVED_FROM, + g=graph, + user=user, + collection=collection, + limit=1 + ) + if not parent_triples: + top_level.append(q) - return questions + # Sort by timestamp (newest first) + top_level.sort(key=lambda q: q.timestamp or "", reverse=True) + + return top_level def detect_session_type( self, @@ -1159,18 +1197,9 @@ class ExplainabilityClient: limit=5 ) - generated_triples = self.flow.triples_query( - p=PROV_WAS_GENERATED_BY, - o=session_uri, - g=graph, - user=user, - collection=collection, - limit=5 - ) - all_child_uris = [ extract_term_value(t.get("s", {})) - for t in (derived_triples + generated_triples) + for t in derived_triples ] for child_uri in all_child_uris: diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index e5f63c79..3b463762 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -384,12 +384,14 @@ class SocketClient: if chunk_type == "thought": return AgentThought( content=resp.get("content", ""), - end_of_message=resp.get("end_of_message", False) + end_of_message=resp.get("end_of_message", False), + message_id=resp.get("message_id", ""), ) elif chunk_type == "observation": return AgentObservation( content=resp.get("content", ""), - end_of_message=resp.get("end_of_message", False) + end_of_message=resp.get("end_of_message", False), + message_id=resp.get("message_id", ""), ) elif chunk_type == "answer" or chunk_type == "final-answer": return AgentAnswer( diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index d39310f2..3e3f1520 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -150,8 +150,10 @@ class AgentThought(StreamingChunk): content: Agent's thought text end_of_message: True if 
this completes the current thought chunk_type: Always "thought" + message_id: Provenance URI of the entity being built """ chunk_type: str = "thought" + message_id: str = "" @dataclasses.dataclass class AgentObservation(StreamingChunk): @@ -165,8 +167,10 @@ class AgentObservation(StreamingChunk): content: Observation text describing tool results end_of_message: True if this completes the current observation chunk_type: Always "observation" + message_id: Provenance URI of the entity being built """ chunk_type: str = "observation" + message_id: str = "" @dataclasses.dataclass class AgentAnswer(StreamingChunk): diff --git a/trustgraph-base/trustgraph/base/graph_rag_client.py b/trustgraph-base/trustgraph/base/graph_rag_client.py index 66dbad1e..32007943 100644 --- a/trustgraph-base/trustgraph/base/graph_rag_client.py +++ b/trustgraph-base/trustgraph/base/graph_rag_client.py @@ -5,6 +5,7 @@ from .. schema import GraphRagQuery, GraphRagResponse class GraphRagClient(RequestResponse): async def rag(self, query, user="trustgraph", collection="default", chunk_callback=None, explain_callback=None, + parent_uri="", timeout=600): """ Execute a graph RAG query with optional streaming callbacks. @@ -50,6 +51,7 @@ class GraphRagClient(RequestResponse): query = query, user = user, collection = collection, + parent_uri = parent_uri, ), timeout=timeout, recipient=recipient, diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index 304f17a7..e6ce0a9e 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -96,6 +96,7 @@ from . namespaces import ( TG_ANALYSIS, TG_CONCLUSION, # Unifying types TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_TOOL_USE, # Question subtypes (to distinguish retrieval mechanism) TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, # Agent provenance predicates @@ -132,6 +133,7 @@ from . triples import ( from . agent import ( agent_session_triples, agent_iteration_triples, + agent_observation_triples, agent_final_triples, # Orchestrator provenance triple builders agent_decomposition_triples, @@ -210,6 +212,7 @@ __all__ = [ "TG_ANALYSIS", "TG_CONCLUSION", # Unifying types "TG_ANSWER_TYPE", "TG_REFLECTION_TYPE", "TG_THOUGHT_TYPE", "TG_OBSERVATION_TYPE", + "TG_TOOL_USE", # Question subtypes "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", # Agent provenance predicates @@ -238,6 +241,7 @@ __all__ = [ # Agent provenance triple builders "agent_session_triples", "agent_iteration_triples", + "agent_observation_triples", "agent_final_triples", # Orchestrator provenance triple builders "agent_decomposition_triples", diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index d25109a7..4fc1f2b5 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -20,11 +20,12 @@ from .. schema import Triple, Term, IRI, LITERAL from . 
namespaces import ( RDF_TYPE, RDFS_LABEL, - PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, - PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, - TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + PROV_ENTITY, PROV_WAS_DERIVED_FROM, + PROV_STARTED_AT_TIME, + TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, + TG_TOOL_USE, TG_AGENT_QUESTION, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP, @@ -70,7 +71,7 @@ def agent_session_triples( timestamp = datetime.utcnow().isoformat() + "Z" return [ - _triple(session_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + _triple(session_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(session_uri, RDF_TYPE, _iri(TG_QUESTION)), _triple(session_uri, RDF_TYPE, _iri(TG_AGENT_QUESTION)), _triple(session_uri, RDFS_LABEL, _literal("Agent Question")), @@ -87,19 +88,15 @@ def agent_iteration_triples( arguments: Dict[str, Any] = None, thought_uri: Optional[str] = None, thought_document_id: Optional[str] = None, - observation_uri: Optional[str] = None, - observation_document_id: Optional[str] = None, ) -> List[Triple]: """ - Build triples for one agent iteration (Analysis - think/act/observe cycle). + Build triples for one agent iteration (Analysis+ToolUse). Creates: - - Entity declaration with tg:Analysis type - - wasGeneratedBy link to question (if first iteration) - - wasDerivedFrom link to previous iteration (if not first) + - Entity declaration with tg:Analysis and tg:ToolUse types + - wasDerivedFrom link to question (if first iteration) or previous - Action and arguments metadata - Thought sub-entity (tg:Reflection, tg:Thought) with librarian document - - Observation sub-entity (tg:Reflection, tg:Observation) with librarian document Args: iteration_uri: URI of this iteration (from agent_iteration_uri) @@ -109,8 +106,6 @@ def agent_iteration_triples( arguments: Arguments passed to the tool (will be JSON-encoded) thought_uri: URI for the thought sub-entity thought_document_id: Document URI for thought in librarian - observation_uri: URI for the observation sub-entity - observation_document_id: Document URI for observation in librarian Returns: List of Triple objects @@ -121,6 +116,7 @@ def agent_iteration_triples( triples = [ _triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)), + _triple(iteration_uri, RDF_TYPE, _iri(TG_TOOL_USE)), _triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")), _triple(iteration_uri, TG_ACTION, _literal(action)), _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), @@ -128,7 +124,7 @@ def agent_iteration_triples( if question_uri: triples.append( - _triple(iteration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)) + _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri)) ) elif previous_uri: triples.append( @@ -142,26 +138,48 @@ def agent_iteration_triples( _triple(thought_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)), _triple(thought_uri, RDF_TYPE, _iri(TG_THOUGHT_TYPE)), _triple(thought_uri, RDFS_LABEL, _literal("Thought")), - _triple(thought_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)), + _triple(thought_uri, PROV_WAS_DERIVED_FROM, _iri(iteration_uri)), ]) if thought_document_id: triples.append( _triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id)) ) - # Observation sub-entity - if observation_uri: - triples.extend([ - _triple(iteration_uri, TG_OBSERVATION, 
_iri(observation_uri)), - _triple(observation_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)), - _triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)), - _triple(observation_uri, RDFS_LABEL, _literal("Observation")), - _triple(observation_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)), - ]) - if observation_document_id: - triples.append( - _triple(observation_uri, TG_DOCUMENT, _iri(observation_document_id)) - ) + return triples + + +def agent_observation_triples( + observation_uri: str, + iteration_uri: str, + document_id: Optional[str] = None, +) -> List[Triple]: + """ + Build triples for an agent observation (standalone entity). + + Creates: + - Entity declaration with prov:Entity and tg:Observation types + - wasDerivedFrom link to the iteration (Analysis+ToolUse) + - Document reference to librarian (if provided) + + Args: + observation_uri: URI of the observation entity + iteration_uri: URI of the iteration this observation derives from + document_id: Librarian document ID for the observation content + + Returns: + List of Triple objects + """ + triples = [ + _triple(observation_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)), + _triple(observation_uri, RDFS_LABEL, _literal("Observation")), + _triple(observation_uri, PROV_WAS_DERIVED_FROM, _iri(iteration_uri)), + ] + + if document_id: + triples.append( + _triple(observation_uri, TG_DOCUMENT, _iri(document_id)) + ) return triples @@ -199,7 +217,7 @@ def agent_final_triples( if question_uri: triples.append( - _triple(final_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)) + _triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri)) ) elif previous_uri: triples.append( @@ -223,7 +241,7 @@ def agent_decomposition_triples( _triple(uri, RDF_TYPE, _iri(TG_DECOMPOSITION)), _triple(uri, RDFS_LABEL, _literal(f"Decomposed into {len(goals)} research threads")), - _triple(uri, PROV_WAS_GENERATED_BY, _iri(session_uri)), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)), ] for goal in goals: triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal))) @@ -261,7 +279,7 @@ def agent_plan_triples( _triple(uri, RDF_TYPE, _iri(TG_PLAN_TYPE)), _triple(uri, RDFS_LABEL, _literal(f"Plan with {len(steps)} steps")), - _triple(uri, PROV_WAS_GENERATED_BY, _iri(session_uri)), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)), ] for step in steps: triples.append(_triple(uri, TG_PLAN_STEP, _literal(step))) diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 69134dfb..9e7fbb2d 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -105,6 +105,7 @@ TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion, F TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation) TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result +TG_TOOL_USE = TG + "ToolUse" # Analysis+ToolUse mixin # Question subtypes (to distinguish retrieval mechanism) TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 407cab31..f2e85eff 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -353,18 +353,21 @@ def question_triples( question_uri: str, query: str, timestamp: Optional[str] = None, + parent_uri: Optional[str] = 
None, ) -> List[Triple]: """ - Build triples for a question activity. + Build triples for a question entity. Creates: - - Activity declaration for the question + - Entity declaration for the question - Query text and timestamp + - Optional wasDerivedFrom link to parent (for sub-traces) Args: question_uri: URI of the question (from question_uri) query: The user's query text timestamp: ISO timestamp (defaults to now) + parent_uri: Optional parent URI to link as wasDerivedFrom (for sub-traces) Returns: List of Triple objects @@ -372,8 +375,8 @@ def question_triples( if timestamp is None: timestamp = datetime.utcnow().isoformat() + "Z" - return [ - _triple(question_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + triples = [ + _triple(question_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(question_uri, RDF_TYPE, _iri(TG_QUESTION)), _triple(question_uri, RDF_TYPE, _iri(TG_GRAPH_RAG_QUESTION)), _triple(question_uri, RDFS_LABEL, _literal("GraphRAG Question")), @@ -381,6 +384,13 @@ def question_triples( _triple(question_uri, TG_QUERY, _literal(query)), ] + if parent_uri: + triples.append( + _triple(question_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)) + ) + + return triples + def grounding_triples( grounding_uri: str, @@ -407,7 +417,7 @@ def grounding_triples( _triple(grounding_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(grounding_uri, RDF_TYPE, _iri(TG_GROUNDING)), _triple(grounding_uri, RDFS_LABEL, _literal("Grounding")), - _triple(grounding_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), + _triple(grounding_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri)), ] for concept in concepts: @@ -575,18 +585,21 @@ def docrag_question_triples( question_uri: str, query: str, timestamp: Optional[str] = None, + parent_uri: Optional[str] = None, ) -> List[Triple]: """ - Build triples for a document RAG question activity. + Build triples for a document RAG question entity. 
Creates: - - Activity declaration with tg:Question type + - Entity declaration with tg:Question type - Query text and timestamp + - Optional wasDerivedFrom link to parent (for sub-traces) Args: question_uri: URI of the question (from docrag_question_uri) query: The user's query text timestamp: ISO timestamp (defaults to now) + parent_uri: Optional parent URI to link as wasDerivedFrom (for sub-traces) Returns: List of Triple objects @@ -594,8 +607,8 @@ def docrag_question_triples( if timestamp is None: timestamp = datetime.utcnow().isoformat() + "Z" - return [ - _triple(question_uri, RDF_TYPE, _iri(PROV_ACTIVITY)), + triples = [ + _triple(question_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(question_uri, RDF_TYPE, _iri(TG_QUESTION)), _triple(question_uri, RDF_TYPE, _iri(TG_DOC_RAG_QUESTION)), _triple(question_uri, RDFS_LABEL, _literal("DocumentRAG Question")), @@ -603,6 +616,13 @@ def docrag_question_triples( _triple(question_uri, TG_QUERY, _literal(query)), ] + if parent_uri: + triples.append( + _triple(question_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)) + ) + + return triples + def docrag_exploration_triples( exploration_uri: str, diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index d4f76655..f5ac73d3 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -18,6 +18,7 @@ class GraphRagQuery: edge_score_limit: int = 0 edge_limit: int = 0 streaming: bool = False + parent_uri: str = "" @dataclass class GraphRagResponse: diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index c82c78f6..1c4b757b 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -12,6 +12,7 @@ from trustgraph.api import ( ProvenanceEvent, Question, Analysis, + Observation, Conclusion, Decomposition, Finding, @@ -206,13 +207,13 @@ def question_explainable( print(f" Time: {entity.timestamp}", file=sys.stderr) elif isinstance(entity, Analysis): - print(f"\n [iteration] {prov_id}", file=sys.stderr) - if entity.action: - print(f" Action: {entity.action}", file=sys.stderr) - if entity.thought: - print(f" Thought: {entity.thought}", file=sys.stderr) - if entity.observation: - print(f" Observation: {entity.observation}", file=sys.stderr) + action_label = f": {entity.action}" if entity.action else "" + print(f"\n [analysis{action_label}] {prov_id}", file=sys.stderr) + + elif isinstance(entity, Observation): + print(f"\n [observation] {prov_id}", file=sys.stderr) + if entity.document: + print(f" Document: {entity.document}", file=sys.stderr) elif isinstance(entity, Decomposition): print(f"\n [decompose] {prov_id}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index c4da0d5a..90c0e452 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -26,6 +26,7 @@ from trustgraph.api import ( Focus, Synthesis, Analysis, + Observation, Conclusion, Decomposition, Finding, @@ -379,11 +380,13 @@ def print_agent_text(trace, explain_client, api, user): print(f" {line}") except Exception: print(f" Arguments: {step.arguments}") + print() - obs = step.observation or 'N/A' - if obs and len(obs) > 200: - obs = obs[:200] + "... 
[truncated]" - print(f" Observation: {obs}") + elif isinstance(step, Observation): + print("--- Observation ---") + _print_document_content( + explain_client, api, user, step.document, "Content", + ) print() elif isinstance(step, Synthesis): @@ -437,6 +440,12 @@ def trace_to_dict(trace, trace_type): "step": step.step, "document": step.document, } + elif isinstance(step, Observation): + return { + "type": "observation", + "id": step.uri, + "document": step.document, + } elif isinstance(step, Analysis): return { "type": "analysis", @@ -444,7 +453,6 @@ def trace_to_dict(trace, trace_type): "action": step.action, "arguments": step.arguments, "thought": step.thought, - "observation": step.observation, } elif isinstance(step, Synthesis): return { diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index 4faa7ce6..f999b132 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -27,6 +27,7 @@ from trustgraph.provenance import ( agent_synthesis_uri, agent_session_triples, agent_iteration_triples, + agent_observation_triples, agent_final_triples, agent_decomposition_triples, agent_finding_triples, @@ -46,9 +47,12 @@ logger = logging.getLogger(__name__) class UserAwareContext: """Wraps flow interface to inject user context for tools that need it.""" - def __init__(self, flow, user): + def __init__(self, flow, user, respond=None, streaming=False): self._flow = flow self._user = user + self.respond = respond + self.streaming = streaming + self.current_explain_uri = None def __call__(self, service_name): client = self._flow(service_name) @@ -120,9 +124,9 @@ class PatternBase: current_state=getattr(request, 'state', None), ) - def make_context(self, flow, user): + def make_context(self, flow, user, respond=None, streaming=False): """Create a user-aware context wrapper.""" - return UserAwareContext(flow, user) + return UserAwareContext(flow, user, respond=respond, streaming=streaming) def build_history(self, request): """Convert AgentStep history into Action objects.""" @@ -140,7 +144,7 @@ class PatternBase: # ---- Streaming callbacks ------------------------------------------------ - def make_think_callback(self, respond, streaming): + def make_think_callback(self, respond, streaming, message_id=""): """Create the think callback for streaming/non-streaming.""" async def think(x, is_final=False): logger.debug(f"Think: {x} (is_final={is_final})") @@ -150,6 +154,7 @@ class PatternBase: content=x, end_of_message=is_final, end_of_dialog=False, + message_id=message_id, ) else: r = AgentResponse( @@ -157,11 +162,12 @@ class PatternBase: content=x, end_of_message=True, end_of_dialog=False, + message_id=message_id, ) await respond(r) return think - def make_observe_callback(self, respond, streaming): + def make_observe_callback(self, respond, streaming, message_id=""): """Create the observe callback for streaming/non-streaming.""" async def observe(x, is_final=False): logger.debug(f"Observe: {x} (is_final={is_final})") @@ -171,6 +177,7 @@ class PatternBase: content=x, end_of_message=is_final, end_of_dialog=False, + message_id=message_id, ) else: r = AgentResponse( @@ -178,6 +185,7 @@ class PatternBase: content=x, end_of_message=True, end_of_dialog=False, + message_id=message_id, ) await respond(r) return observe @@ -223,23 +231,23 @@ class PatternBase: )) logger.debug(f"Emitted session triples for {session_uri}") - if streaming: - await 
respond(AgentResponse( - chunk_type="explain", - content="", - explain_id=session_uri, - explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=session_uri, + explain_graph=GRAPH_RETRIEVAL, + )) async def emit_iteration_triples(self, flow, session_id, iteration_num, session_uri, act, request, respond, streaming): - """Emit provenance triples for an iteration and save to librarian.""" + """Emit provenance triples for an iteration (Analysis+ToolUse).""" iteration_uri = agent_iteration_uri(session_id, iteration_num) if iteration_num > 1: + # Chain through previous Observation (last entity in prior cycle) iter_question_uri = None - iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + iter_previous_uri = agent_observation_uri(session_id, iteration_num - 1) else: iter_question_uri = session_uri iter_previous_uri = None @@ -261,25 +269,7 @@ class PatternBase: logger.warning(f"Failed to save thought to librarian: {e}") thought_doc_id = None - # Save observation to librarian - observation_doc_id = None - if act.observation: - observation_doc_id = ( - f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" - ) - try: - await self.processor.save_answer_content( - doc_id=observation_doc_id, - user=request.user, - content=act.observation, - title=f"Agent Observation: {act.name}", - ) - except Exception as e: - logger.warning(f"Failed to save observation to librarian: {e}") - observation_doc_id = None - thought_entity_uri = agent_thought_uri(session_id, iteration_num) - observation_entity_uri = agent_observation_uri(session_id, iteration_num) iter_triples = set_graph( agent_iteration_triples( @@ -290,8 +280,6 @@ class PatternBase: arguments=act.arguments, thought_uri=thought_entity_uri if thought_doc_id else None, thought_document_id=thought_doc_id, - observation_uri=observation_entity_uri if observation_doc_id else None, - observation_document_id=observation_doc_id, ), GRAPH_RETRIEVAL, ) @@ -305,13 +293,60 @@ class PatternBase: )) logger.debug(f"Emitted iteration triples for {iteration_uri}") - if streaming: - await respond(AgentResponse( - chunk_type="explain", - content="", - explain_id=iteration_uri, - explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=iteration_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + + async def emit_observation_triples(self, flow, session_id, iteration_num, + observation_text, request, respond): + """Emit provenance triples for a standalone Observation entity.""" + iteration_uri = agent_iteration_uri(session_id, iteration_num) + observation_entity_uri = agent_observation_uri(session_id, iteration_num) + + # Save observation to librarian + observation_doc_id = None + if observation_text: + observation_doc_id = ( + f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" + ) + try: + await self.processor.save_answer_content( + doc_id=observation_doc_id, + user=request.user, + content=observation_text, + title=f"Agent Observation", + ) + except Exception as e: + logger.warning(f"Failed to save observation to librarian: {e}") + observation_doc_id = None + + obs_triples = set_graph( + agent_observation_triples( + observation_entity_uri, + iteration_uri, + document_id=observation_doc_id, + ), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=observation_entity_uri, + user=request.user, + collection=getattr(request, 'collection', 'default'), + ), + triples=obs_triples, + )) 
+ logger.debug(f"Emitted observation triples for {observation_entity_uri}") + + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=observation_entity_uri, + explain_graph=GRAPH_RETRIEVAL, + )) async def emit_final_triples(self, flow, session_id, iteration_num, session_uri, answer_text, request, respond, @@ -320,8 +355,9 @@ class PatternBase: final_uri = agent_final_uri(session_id) if iteration_num > 1: + # Chain through last Observation (last entity in prior cycle) final_question_uri = None - final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + final_previous_uri = agent_observation_uri(session_id, iteration_num - 1) else: final_question_uri = session_uri final_previous_uri = None @@ -361,13 +397,12 @@ class PatternBase: )) logger.debug(f"Emitted final triples for {final_uri}") - if streaming: - await respond(AgentResponse( - chunk_type="explain", - content="", - explain_id=final_uri, - explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=final_uri, + explain_graph=GRAPH_RETRIEVAL, + )) # ---- Orchestrator provenance helpers ------------------------------------ @@ -385,11 +420,10 @@ class PatternBase: metadata=Metadata(id=uri, user=user, collection=collection), triples=triples, )) - if streaming: - await respond(AgentResponse( - chunk_type="explain", content="", - explain_id=uri, explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) async def emit_finding_triples( self, flow, session_id, index, goal, answer_text, user, collection, @@ -418,11 +452,10 @@ class PatternBase: metadata=Metadata(id=uri, user=user, collection=collection), triples=triples, )) - if streaming: - await respond(AgentResponse( - chunk_type="explain", content="", - explain_id=uri, explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) async def emit_plan_triples( self, flow, session_id, session_uri, steps, user, collection, @@ -438,11 +471,10 @@ class PatternBase: metadata=Metadata(id=uri, user=user, collection=collection), triples=triples, )) - if streaming: - await respond(AgentResponse( - chunk_type="explain", content="", - explain_id=uri, explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) async def emit_step_result_triples( self, flow, session_id, index, goal, answer_text, user, collection, @@ -471,11 +503,10 @@ class PatternBase: metadata=Metadata(id=uri, user=user, collection=collection), triples=triples, )) - if streaming: - await respond(AgentResponse( - chunk_type="explain", content="", - explain_id=uri, explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) async def emit_synthesis_triples( self, flow, session_id, previous_uri, answer_text, user, collection, @@ -503,11 +534,10 @@ class PatternBase: metadata=Metadata(id=uri, user=user, collection=collection), triples=triples, )) - if streaming: - await respond(AgentResponse( - chunk_type="explain", content="", - explain_id=uri, explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + )) # ---- Response helpers 
--------------------------------------------------- diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index d6abb058..4775212e 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -11,7 +11,11 @@ import uuid from ... schema import AgentRequest, AgentResponse, AgentStep, PlanStep - +from trustgraph.provenance import ( + agent_step_result_uri as make_step_result_uri, + agent_thought_uri, + agent_observation_uri, +) from . pattern_base import PatternBase @@ -101,7 +105,10 @@ class PlanThenExecutePattern(PatternBase): tools = self.filter_tools(self.processor.agent.tools, request) framing = getattr(request, 'framing', '') - context = self.make_context(flow, request.user) + context = self.make_context( + flow, request.user, + respond=respond, streaming=streaming, + ) client = context("prompt-request") # Use the plan-create prompt template @@ -198,8 +205,11 @@ class PlanThenExecutePattern(PatternBase): logger.info(f"Executing plan step {pending_idx}: {goal}") - think = self.make_think_callback(respond, streaming) - observe = self.make_observe_callback(respond, streaming) + thought_msg_id = agent_thought_uri(session_id, iteration_num) + observation_msg_id = agent_observation_uri(session_id, iteration_num) + + think = self.make_think_callback(respond, streaming, message_id=thought_msg_id) + observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id) # Gather results from dependencies previous_results = [] @@ -216,7 +226,16 @@ class PlanThenExecutePattern(PatternBase): }) tools = self.filter_tools(self.processor.agent.tools, request) - context = self.make_context(flow, request.user) + context = self.make_context( + flow, request.user, + respond=respond, streaming=streaming, + ) + + # Set current explain URI so tools can link sub-traces + context.current_explain_uri = make_step_result_uri( + session_id, pending_idx, + ) + client = context("prompt-request") # Single-shot: ask LLM which tool + arguments to use for this goal @@ -316,7 +335,10 @@ class PlanThenExecutePattern(PatternBase): think = self.make_think_callback(respond, streaming) framing = getattr(request, 'framing', '') - context = self.make_context(flow, request.user) + context = self.make_context( + flow, request.user, + respond=respond, streaming=streaming, + ) client = context("prompt-request") # Use the plan-synthesise prompt template @@ -342,8 +364,7 @@ class PlanThenExecutePattern(PatternBase): ) # Emit synthesis provenance (links back to last step result) - from trustgraph.provenance import agent_step_result_uri - last_step_uri = agent_step_result_uri(session_id, len(plan) - 1) + last_step_uri = make_step_result_uri(session_id, len(plan) - 1) await self.emit_synthesis_triples( flow, session_id, last_step_uri, response_text, request.user, collection, respond, streaming, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 32261809..f6af65c2 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -11,6 +11,12 @@ import uuid from ... 
schema import AgentRequest, AgentResponse, AgentStep +from trustgraph.provenance import ( + agent_iteration_uri, + agent_thought_uri, + agent_observation_uri, +) + from ..react.agent_manager import AgentManager from ..react.types import Action, Final from ..tool_filter import get_next_state @@ -51,9 +57,13 @@ class ReactPattern(PatternBase): if len(history) >= self.processor.max_iterations: raise RuntimeError("Too many agent iterations") + # Compute URIs upfront for message_id + thought_msg_id = agent_thought_uri(session_id, iteration_num) + observation_msg_id = agent_observation_uri(session_id, iteration_num) + # Build callbacks - think = self.make_think_callback(respond, streaming) - observe = self.make_observe_callback(respond, streaming) + think = self.make_think_callback(respond, streaming, message_id=thought_msg_id) + observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id) answer_cb = self.make_answer_callback(respond, streaming) # Filter tools @@ -75,7 +85,22 @@ class ReactPattern(PatternBase): additional_context=additional_context, ) - context = self.make_context(flow, request.user) + context = self.make_context( + flow, request.user, + respond=respond, streaming=streaming, + ) + + # Set current explain URI so tools can link sub-traces + context.current_explain_uri = agent_iteration_uri( + session_id, iteration_num, + ) + + # Callback: emit Analysis+ToolUse triples before tool executes + async def on_action(act): + await self.emit_iteration_triples( + flow, session_id, iteration_num, session_uri, + act, request, respond, streaming, + ) act = await temp_agent.react( question=request.question, @@ -85,6 +110,7 @@ class ReactPattern(PatternBase): answer=answer_cb, context=context, streaming=streaming, + on_action=on_action, ) logger.debug(f"Action: {act}") @@ -110,10 +136,10 @@ class ReactPattern(PatternBase): ) return - # Not final — emit iteration provenance and send next request - await self.emit_iteration_triples( - flow, session_id, iteration_num, session_uri, - act, request, respond, streaming, + # Emit observation provenance after tool execution + await self.emit_observation_triples( + flow, session_id, iteration_num, + act.observation, request, respond, ) history.append(act) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 8588e400..951063cf 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -86,7 +86,10 @@ class SupervisorPattern(PatternBase): tools = self.filter_tools(self.processor.agent.tools, request) - context = self.make_context(flow, request.user) + context = self.make_context( + flow, request.user, + respond=respond, streaming=streaming, + ) client = context("prompt-request") # Use the supervisor-decompose prompt template @@ -182,7 +185,10 @@ class SupervisorPattern(PatternBase): logger.warning("Synthesis called with no subagent results") subagent_results = {"(no results)": "No subagent results available"} - context = self.make_context(flow, request.user) + context = self.make_context( + flow, request.user, + respond=respond, streaming=streaming, + ) client = context("prompt-request") await think("Synthesising final answer from sub-agent results", is_final=True) diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index 18598b38..e86a2d6c 100644 --- 
a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -291,7 +291,8 @@ class AgentManager: logger.error(f"Response was: {response_text}") raise RuntimeError(f"Failed to parse agent response: {e}") - async def react(self, question, history, think, observe, context, streaming=False, answer=None): + async def react(self, question, history, think, observe, context, + streaming=False, answer=None, on_action=None): act = await self.reason( question = question, @@ -325,6 +326,10 @@ class AgentManager: else: raise RuntimeError(f"No action for {act.name}!") + # Notify caller before tool execution (for provenance) + if on_action: + await on_action(act) + resp = await action.implementation(context).invoke( **act.arguments ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 6c06f71a..af088ec9 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -36,6 +36,7 @@ from trustgraph.provenance import ( agent_final_uri, agent_session_triples, agent_iteration_triples, + agent_observation_triples, agent_final_triples, set_graph, GRAPH_RETRIEVAL, @@ -465,13 +466,12 @@ class Processor(AgentService): logger.debug(f"Emitted session triples for {session_uri}") # Send explain event for session - if streaming: - await respond(AgentResponse( - chunk_type="explain", - content="", - explain_id=session_uri, - explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=session_uri, + explain_graph=GRAPH_RETRIEVAL, + )) logger.info(f"Question: {request.question}") @@ -480,6 +480,9 @@ class Processor(AgentService): logger.debug(f"History: {history}") + thought_msg_id = agent_thought_uri(session_id, iteration_num) + observation_msg_id = agent_observation_uri(session_id, iteration_num) + async def think(x, is_final=False): logger.debug(f"Think: {x} (is_final={is_final})") @@ -490,6 +493,7 @@ class Processor(AgentService): content=x, end_of_message=is_final, end_of_dialog=False, + message_id=thought_msg_id, ) else: r = AgentResponse( @@ -497,6 +501,7 @@ class Processor(AgentService): content=x, end_of_message=True, end_of_dialog=False, + message_id=thought_msg_id, ) await respond(r) @@ -511,6 +516,7 @@ class Processor(AgentService): content=x, end_of_message=is_final, end_of_dialog=False, + message_id=observation_msg_id, ) else: r = AgentResponse( @@ -518,6 +524,7 @@ class Processor(AgentService): content=x, end_of_message=True, end_of_dialog=False, + message_id=observation_msg_id, ) await respond(r) @@ -572,6 +579,62 @@ class Processor(AgentService): client._current_user = self._user return client + # Callback: emit Analysis+ToolUse triples before tool executes + async def on_action(act_decision): + iter_uri = agent_iteration_uri(session_id, iteration_num) + if iteration_num > 1: + iter_q_uri = None + iter_prev_uri = agent_observation_uri(session_id, iteration_num - 1) + else: + iter_q_uri = session_uri + iter_prev_uri = None + + # Save thought to librarian + t_doc_id = None + if act_decision.thought: + t_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" + try: + await self.save_answer_content( + doc_id=t_doc_id, + user=request.user, + content=act_decision.thought, + title=f"Agent Thought: {act_decision.name}", + ) + except Exception as e: + logger.warning(f"Failed to save thought to librarian: {e}") + t_doc_id = None + + t_entity_uri = 
agent_thought_uri(session_id, iteration_num) + + iter_triples = set_graph( + agent_iteration_triples( + iter_uri, + question_uri=iter_q_uri, + previous_uri=iter_prev_uri, + action=act_decision.name, + arguments=act_decision.arguments, + thought_uri=t_entity_uri if t_doc_id else None, + thought_document_id=t_doc_id, + ), + GRAPH_RETRIEVAL + ) + await flow("explainability").send(Triples( + metadata=Metadata( + id=iter_uri, + user=request.user, + collection=collection, + ), + triples=iter_triples, + )) + logger.debug(f"Emitted iteration triples for {iter_uri}") + + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=iter_uri, + explain_graph=GRAPH_RETRIEVAL, + )) + act = await temp_agent.react( question = request.question, history = history, @@ -580,6 +643,7 @@ class Processor(AgentService): answer = answer, context = UserAwareContext(flow, request.user), streaming = streaming, + on_action = on_action, ) logger.debug(f"Action: {act}") @@ -595,10 +659,10 @@ class Processor(AgentService): # Emit final answer provenance triples final_uri = agent_final_uri(session_id) - # No iterations: link to question; otherwise: link to last iteration + # No iterations: link to question; otherwise: link to last observation if iteration_num > 1: final_question_uri = None - final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) + final_previous_uri = agent_observation_uri(session_id, iteration_num - 1) else: final_question_uri = session_uri final_previous_uri = None @@ -639,13 +703,12 @@ class Processor(AgentService): logger.debug(f"Emitted final triples for {final_uri}") # Send explain event for conclusion - if streaming: - await respond(AgentResponse( - chunk_type="explain", - content="", - explain_id=final_uri, - explain_graph=GRAPH_RETRIEVAL, - )) + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=final_uri, + explain_graph=GRAPH_RETRIEVAL, + )) if streaming: # End-of-dialog marker — answer chunks already sent via callback @@ -671,33 +734,9 @@ class Processor(AgentService): logger.debug("Send next...") - # Emit iteration provenance triples + # Emit standalone observation provenance (iteration was emitted in on_action) iteration_uri = agent_iteration_uri(session_id, iteration_num) - # First iteration links to question, subsequent to previous - if iteration_num > 1: - iter_question_uri = None - iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) - else: - iter_question_uri = session_uri - iter_previous_uri = None - - # Save thought to librarian - thought_doc_id = None - if act.thought: - thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" - try: - await self.save_answer_content( - doc_id=thought_doc_id, - user=request.user, - content=act.thought, - title=f"Agent Thought: {act.name}", - ) - logger.debug(f"Saved thought to librarian: {thought_doc_id}") - except Exception as e: - logger.warning(f"Failed to save thought to librarian: {e}") - thought_doc_id = None - - # Save observation to librarian + observation_entity_uri = agent_observation_uri(session_id, iteration_num) observation_doc_id = None if act.observation: observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" @@ -706,48 +745,38 @@ class Processor(AgentService): doc_id=observation_doc_id, user=request.user, content=act.observation, - title=f"Agent Observation: {act.name}", + title=f"Agent Observation", ) logger.debug(f"Saved observation to librarian: {observation_doc_id}") except Exception as e: 
logger.warning(f"Failed to save observation to librarian: {e}") observation_doc_id = None - thought_entity_uri = agent_thought_uri(session_id, iteration_num) - observation_entity_uri = agent_observation_uri(session_id, iteration_num) - - iter_triples = set_graph( - agent_iteration_triples( + obs_triples = set_graph( + agent_observation_triples( + observation_entity_uri, iteration_uri, - question_uri=iter_question_uri, - previous_uri=iter_previous_uri, - action=act.name, - arguments=act.arguments, - thought_uri=thought_entity_uri if thought_doc_id else None, - thought_document_id=thought_doc_id, - observation_uri=observation_entity_uri if observation_doc_id else None, - observation_document_id=observation_doc_id, + document_id=observation_doc_id, ), GRAPH_RETRIEVAL ) await flow("explainability").send(Triples( metadata=Metadata( - id=iteration_uri, + id=observation_entity_uri, user=request.user, collection=collection, ), - triples=iter_triples, + triples=obs_triples, )) - logger.debug(f"Emitted iteration triples for {iteration_uri}") + logger.debug(f"Emitted observation triples for {observation_entity_uri}") - # Send explain event for iteration - if streaming: - await respond(AgentResponse( - chunk_type="explain", - content="", - explain_id=iteration_uri, - explain_graph=GRAPH_RETRIEVAL, - )) + # Send explain event for observation + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=observation_entity_uri, + explain_graph=GRAPH_RETRIEVAL, + )) history.append(act) diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 441c8f38..86b515e1 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -12,7 +12,7 @@ class KnowledgeQueryImpl: def __init__(self, context, collection=None): self.context = context self.collection = collection - + @staticmethod def get_arguments(): return [ @@ -22,13 +22,39 @@ class KnowledgeQueryImpl: description="The question to ask the knowledge base" ) ] - + async def invoke(self, **arguments): client = self.context("graph-rag-request") logger.debug("Graph RAG question...") + + # Build explain_callback to forward sub-trace explain events + # to the agent's response stream + explain_callback = None + parent_uri = "" + + respond = getattr(self.context, 'respond', None) + streaming = getattr(self.context, 'streaming', False) + current_uri = getattr(self.context, 'current_explain_uri', None) + + if respond: + from ... schema import AgentResponse + + async def explain_callback(explain_id, explain_graph): + await respond(AgentResponse( + chunk_type="explain", + content="", + explain_id=explain_id, + explain_graph=explain_graph, + )) + + if current_uri: + parent_uri = current_uri + return await client.rag( arguments.get("question"), - collection=self.collection if self.collection else "default" + collection=self.collection if self.collection else "default", + explain_callback=explain_callback, + parent_uri=parent_uri, ) # This tool implementation knows how to do text completion. 
This uses diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index ea9326a4..704613c6 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -555,6 +555,7 @@ class GraphRag: streaming = False, chunk_callback = None, explain_callback = None, save_answer_callback = None, + parent_uri = "", ): """ Execute a GraphRAG query with real-time explainability tracking. @@ -593,7 +594,10 @@ class GraphRag: # Emit question explainability immediately if explain_callback: q_triples = set_graph( - question_triples(q_uri, query, timestamp), + question_triples( + q_uri, query, timestamp, + parent_uri=parent_uri or None, + ), GRAPH_RETRIEVAL ) await explain_callback(q_triples, q_uri) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index c3244b90..85a7491e 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -342,6 +342,7 @@ class Processor(FlowProcessor): chunk_callback = send_chunk, explain_callback = send_explainability, save_answer_callback = save_answer, + parent_uri = v.parent_uri, ) else: @@ -355,6 +356,7 @@ class Processor(FlowProcessor): edge_limit = edge_limit, explain_callback = send_explainability, save_answer_callback = save_answer, + parent_uri = v.parent_uri, ) # Send chunk with response From 2bcf375103700349b36bf2ddf7e84d8e4aa897d5 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Apr 2026 13:27:41 +0100 Subject: [PATCH 22/37] Wire message_id on all answer chunks, fix DAG structure (#748) Wire message_id on all answer chunks, fix DAG structure message_id: - Add message_id to AgentAnswer dataclass and propagate in socket_client._parse_chunk - Wire message_id into answer callbacks and send_final_response for all three patterns (react, plan-then-execute, supervisor) - Supervisor decomposition thought and synthesis answer chunks now carry message_id DAG structure fixes: - Observation derives from sub-trace Synthesis (not Analysis) when a tool produces a sub-trace; tracked via last_sub_explain_uri on context - Subagent sessions derive from parent's Decomposition via parent_uri on agent_session_triples - Findings derive from subagent Conclusions (not Decomposition) - Synthesis derives from all findings (multiple wasDerivedFrom) ensuring single terminal node - agent_synthesis_triples accepts list of parent URIs - Explainability chain walker follows from sub-trace terminal to find downstream Observation Emit Analysis before tool execution: - Add on_action callback to react() in agent_manager.py, called after reason() but before tool invocation - Orchestrator and old service emit Analysis+ToolUse triples via on_action so sub-traces appear after their parent in the stream --- .../trustgraph/api/explainability.py | 9 ++++ .../trustgraph/api/socket_client.py | 3 +- trustgraph-base/trustgraph/api/types.py | 1 + .../trustgraph/provenance/agent.py | 29 +++++++++-- .../agent/orchestrator/pattern_base.py | 49 ++++++++++++++----- .../agent/orchestrator/plan_pattern.py | 5 ++ .../agent/orchestrator/react_pattern.py | 14 +++++- .../trustgraph/agent/orchestrator/service.py | 3 ++ .../agent/orchestrator/supervisor_pattern.py | 29 ++++++++--- .../trustgraph/agent/react/service.py | 19 ++++++- .../trustgraph/agent/react/tools.py | 1 + 11 files changed, 134 insertions(+), 28 deletions(-) diff 
--git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index fa6c4a0c..08d0b4e7 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -1095,6 +1095,15 @@ class ExplainabilityClient: "trace": sub_trace, }) + # Continue from the sub-trace's terminal entity + # (Observation may derive from Synthesis) + terminal = sub_trace.get("synthesis") + if terminal: + self._follow_provenance_chain( + terminal.uri, trace, graph, user, collection, + max_depth=max_depth - 1, + ) + elif isinstance(entity, (Conclusion, Synthesis)): trace["steps"].append(entity) diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 3b463762..847513d3 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -397,7 +397,8 @@ class SocketClient: return AgentAnswer( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), - end_of_dialog=resp.get("end_of_dialog", False) + end_of_dialog=resp.get("end_of_dialog", False), + message_id=resp.get("message_id", ""), ) elif chunk_type == "action": return AgentThought( diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 3e3f1520..0715293b 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -188,6 +188,7 @@ class AgentAnswer(StreamingChunk): """ chunk_type: str = "final-answer" end_of_dialog: bool = False + message_id: str = "" @dataclasses.dataclass class RAGChunk(StreamingChunk): diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index 4fc1f2b5..7203174e 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -51,6 +51,7 @@ def agent_session_triples( session_uri: str, query: str, timestamp: Optional[str] = None, + parent_uri: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent session start (Question). @@ -58,11 +59,13 @@ def agent_session_triples( Creates: - Activity declaration with tg:Question type - Query text and timestamp + - wasDerivedFrom link to parent (for subagent sessions) Args: session_uri: URI of the session (from agent_session_uri) query: The user's query text timestamp: ISO timestamp (defaults to now) + parent_uri: URI of the parent entity (e.g. Decomposition) for subagents Returns: List of Triple objects @@ -70,7 +73,7 @@ def agent_session_triples( if timestamp is None: timestamp = datetime.utcnow().isoformat() + "Z" - return [ + triples = [ _triple(session_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(session_uri, RDF_TYPE, _iri(TG_QUESTION)), _triple(session_uri, RDF_TYPE, _iri(TG_AGENT_QUESTION)), @@ -79,6 +82,13 @@ def agent_session_triples( _triple(session_uri, TG_QUERY, _literal(query)), ] + if parent_uri: + triples.append( + _triple(session_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)) + ) + + return triples + def agent_iteration_triples( iteration_uri: str, @@ -308,17 +318,28 @@ def agent_step_result_triples( def agent_synthesis_triples( uri: str, - previous_uri: str, + previous_uris, document_id: Optional[str] = None, ) -> List[Triple]: - """Build triples for a synthesis answer.""" + """Build triples for a synthesis answer. 
+ + Args: + uri: URI of the synthesis entity + previous_uris: Single URI string or list of URIs to derive from + document_id: Librarian document ID for the answer content + """ triples = [ _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(uri, RDF_TYPE, _iri(TG_SYNTHESIS)), _triple(uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), _triple(uri, RDFS_LABEL, _literal("Synthesis")), - _triple(uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri)), ] + + if isinstance(previous_uris, str): + previous_uris = [previous_uris] + for prev in previous_uris: + triples.append(_triple(uri, PROV_WAS_DERIVED_FROM, _iri(prev))) + if document_id: triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) return triples diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index f999b132..8849a206 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -53,6 +53,7 @@ class UserAwareContext: self.respond = respond self.streaming = streaming self.current_explain_uri = None + self.last_sub_explain_uri = None def __call__(self, service_name): client = self._flow(service_name) @@ -190,7 +191,7 @@ class PatternBase: await respond(r) return observe - def make_answer_callback(self, respond, streaming): + def make_answer_callback(self, respond, streaming, message_id=""): """Create the answer callback for streaming/non-streaming.""" async def answer(x): logger.debug(f"Answer: {x}") @@ -200,6 +201,7 @@ class PatternBase: content=x, end_of_message=False, end_of_dialog=False, + message_id=message_id, ) else: r = AgentResponse( @@ -207,6 +209,7 @@ class PatternBase: content=x, end_of_message=True, end_of_dialog=False, + message_id=message_id, ) await respond(r) return answer @@ -214,11 +217,15 @@ class PatternBase: # ---- Provenance emission ------------------------------------------------ async def emit_session_triples(self, flow, session_uri, question, user, - collection, respond, streaming): + collection, respond, streaming, + parent_uri=None): """Emit provenance triples for a new session.""" timestamp = datetime.utcnow().isoformat() + "Z" triples = set_graph( - agent_session_triples(session_uri, question, timestamp), + agent_session_triples( + session_uri, question, timestamp, + parent_uri=parent_uri, + ), GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( @@ -301,11 +308,18 @@ class PatternBase: )) async def emit_observation_triples(self, flow, session_id, iteration_num, - observation_text, request, respond): + observation_text, request, respond, + context=None): """Emit provenance triples for a standalone Observation entity.""" iteration_uri = agent_iteration_uri(session_id, iteration_num) observation_entity_uri = agent_observation_uri(session_id, iteration_num) + # Derive from the last sub-trace entity if available (e.g. Synthesis), + # otherwise fall back to the iteration (Analysis+ToolUse). 
+ parent_uri = iteration_uri + if context and getattr(context, 'last_sub_explain_uri', None): + parent_uri = context.last_sub_explain_uri + # Save observation to librarian observation_doc_id = None if observation_text: @@ -326,7 +340,7 @@ class PatternBase: obs_triples = set_graph( agent_observation_triples( observation_entity_uri, - iteration_uri, + parent_uri, document_id=observation_doc_id, ), GRAPH_RETRIEVAL, @@ -427,11 +441,17 @@ class PatternBase: async def emit_finding_triples( self, flow, session_id, index, goal, answer_text, user, collection, - respond, streaming, + respond, streaming, subagent_session_id="", ): """Emit provenance for a subagent finding.""" uri = agent_finding_uri(session_id, index) - decomposition_uri = agent_decomposition_uri(session_id) + + # Derive from the subagent's conclusion if available, + # otherwise fall back to the decomposition. + if subagent_session_id: + parent_uri = agent_final_uri(subagent_session_id) + else: + parent_uri = agent_decomposition_uri(session_id) doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc" try: @@ -445,7 +465,7 @@ class PatternBase: doc_id = None triples = set_graph( - agent_finding_triples(uri, decomposition_uri, goal, doc_id), + agent_finding_triples(uri, parent_uri, goal, doc_id), GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( @@ -509,7 +529,7 @@ class PatternBase: )) async def emit_synthesis_triples( - self, flow, session_id, previous_uri, answer_text, user, collection, + self, flow, session_id, previous_uris, answer_text, user, collection, respond, streaming, ): """Emit provenance for a synthesis answer.""" @@ -527,7 +547,7 @@ class PatternBase: doc_id = None triples = set_graph( - agent_synthesis_triples(uri, previous_uri, doc_id), + agent_synthesis_triples(uri, previous_uris, doc_id), GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( @@ -542,7 +562,7 @@ class PatternBase: # ---- Response helpers --------------------------------------------------- async def prompt_as_answer(self, client, prompt_id, variables, - respond, streaming): + respond, streaming, message_id=""): """Call a prompt template, forwarding chunks as answer AgentResponse messages when streaming is enabled. @@ -559,6 +579,7 @@ class PatternBase: content=text, end_of_message=False, end_of_dialog=False, + message_id=message_id, )) await client.prompt( @@ -576,13 +597,14 @@ class PatternBase: ) async def send_final_response(self, respond, streaming, answer_text, - already_streamed=False): + already_streamed=False, message_id=""): """Send the answer content and end-of-dialog marker. Args: already_streamed: If True, answer chunks were already sent via streaming callbacks (e.g. ReactPattern). Only the end-of-dialog marker is emitted. + message_id: Provenance URI for the answer entity. 
""" if streaming and not already_streamed: # Answer wasn't streamed yet — send it as a chunk first @@ -592,6 +614,7 @@ class PatternBase: content=answer_text, end_of_message=False, end_of_dialog=False, + message_id=message_id, )) if streaming: # End-of-dialog marker @@ -600,6 +623,7 @@ class PatternBase: content="", end_of_message=True, end_of_dialog=True, + message_id=message_id, )) else: await respond(AgentResponse( @@ -607,6 +631,7 @@ class PatternBase: content=answer_text, end_of_message=True, end_of_dialog=True, + message_id=message_id, )) def build_next_request(self, request, history, session_id, collection, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 4775212e..59d22929 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -15,6 +15,7 @@ from trustgraph.provenance import ( agent_step_result_uri as make_step_result_uri, agent_thought_uri, agent_observation_uri, + agent_synthesis_uri, ) from . pattern_base import PatternBase @@ -352,6 +353,8 @@ class PlanThenExecutePattern(PatternBase): await think("Synthesising final answer from plan results", is_final=True) + synthesis_msg_id = agent_synthesis_uri(session_id) + response_text = await self.prompt_as_answer( client, "plan-synthesise", variables={ @@ -361,6 +364,7 @@ class PlanThenExecutePattern(PatternBase): }, respond=respond, streaming=streaming, + message_id=synthesis_msg_id, ) # Emit synthesis provenance (links back to last step result) @@ -375,4 +379,5 @@ class PlanThenExecutePattern(PatternBase): else: await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, + message_id=synthesis_msg_id, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index f6af65c2..67ded823 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -15,6 +15,8 @@ from trustgraph.provenance import ( agent_iteration_uri, agent_thought_uri, agent_observation_uri, + agent_final_uri, + agent_decomposition_uri, ) from ..react.agent_manager import AgentManager @@ -47,9 +49,16 @@ class ReactPattern(PatternBase): # Emit session provenance on first iteration if iteration_num == 1: + # Subagents link back to the parent's decomposition + parent_session_id = getattr(request, 'parent_session_id', '') + parent_uri = ( + agent_decomposition_uri(parent_session_id) + if parent_session_id else None + ) await self.emit_session_triples( flow, session_uri, request.question, request.user, collection, respond, streaming, + parent_uri=parent_uri, ) logger.info(f"ReactPattern iteration {iteration_num}: {request.question}") @@ -60,11 +69,12 @@ class ReactPattern(PatternBase): # Compute URIs upfront for message_id thought_msg_id = agent_thought_uri(session_id, iteration_num) observation_msg_id = agent_observation_uri(session_id, iteration_num) + answer_msg_id = agent_final_uri(session_id) # Build callbacks think = self.make_think_callback(respond, streaming, message_id=thought_msg_id) observe = self.make_observe_callback(respond, streaming, message_id=observation_msg_id) - answer_cb = self.make_answer_callback(respond, streaming) + answer_cb = self.make_answer_callback(respond, streaming, message_id=answer_msg_id) # Filter tools filtered_tools = self.filter_tools( @@ -133,6 +143,7 
@@ class ReactPattern(PatternBase): else: await self.send_final_response( respond, streaming, f, already_streamed=streaming, + message_id=answer_msg_id, ) return @@ -140,6 +151,7 @@ class ReactPattern(PatternBase): await self.emit_observation_triples( flow, session_id, iteration_num, act.observation, request, respond, + context=context, ) history.append(act) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index ed4c3983..ea0afd60 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -458,11 +458,14 @@ class Processor(AgentService): finding_index = len(entry["results"]) - 1 if entry else 0 collection = getattr(template, 'collection', 'default') + subagent_session_id = getattr(request, 'session_id', '') + await self.supervisor_pattern.emit_finding_triples( flow, parent_session_id, finding_index, subagent_goal, answer_text, template.user, collection, respond, template.streaming, + subagent_session_id=subagent_session_id, ) if all_done: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 951063cf..d5537876 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -16,7 +16,11 @@ import uuid from ... schema import AgentRequest, AgentResponse, AgentStep -from trustgraph.provenance import agent_finding_uri +from trustgraph.provenance import ( + agent_finding_uri, + agent_decomposition_uri, + agent_synthesis_uri, +) from . pattern_base import PatternBase @@ -81,7 +85,10 @@ class SupervisorPattern(PatternBase): session_uri, iteration_num): """Decompose the question into sub-goals and fan out subagents.""" - think = self.make_think_callback(respond, streaming) + decompose_msg_id = agent_decomposition_uri(session_id) + think = self.make_think_callback( + respond, streaming, message_id=decompose_msg_id, + ) framing = getattr(request, 'framing', '') tools = self.filter_tools(self.processor.agent.tools, request) @@ -171,7 +178,10 @@ class SupervisorPattern(PatternBase): session_uri, iteration_num): """Synthesise final answer from subagent results.""" - think = self.make_think_callback(respond, streaming) + synthesis_msg_id = agent_synthesis_uri(session_id) + think = self.make_think_callback( + respond, streaming, message_id=synthesis_msg_id, + ) framing = getattr(request, 'framing', '') # Collect subagent results from history @@ -205,17 +215,20 @@ class SupervisorPattern(PatternBase): }, respond=respond, streaming=streaming, + message_id=synthesis_msg_id, ) - # Emit synthesis provenance (links back to last finding) - last_finding_uri = agent_finding_uri( - session_id, len(subagent_results) - 1 - ) + # Emit synthesis provenance (links back to all findings) + finding_uris = [ + agent_finding_uri(session_id, i) + for i in range(len(subagent_results)) + ] await self.emit_synthesis_triples( - flow, session_id, last_finding_uri, + flow, session_id, finding_uris, response_text, request.user, collection, respond, streaming, ) await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, + message_id=synthesis_msg_id, ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index af088ec9..0e783349 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py 
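(Editorial note, not part of the patch.) With the list form of agent_synthesis_triples, the supervisor's terminal Synthesis derives from every finding rather than only the last one, so the DAG converges on a single terminal node. A minimal sketch, assuming three findings:

    # Illustrative sketch only.
    finding_uris = [agent_finding_uri(session_id, i) for i in range(3)]
    synth_uri = agent_synthesis_uri(session_id)
    triples = agent_synthesis_triples(synth_uri, finding_uris, document_id=None)
    # yields, alongside the type/label triples, one derivation per finding:
    #   synth_uri prov:wasDerivedFrom finding_uris[0]
    #   synth_uri prov:wasDerivedFrom finding_uris[1]
    #   synth_uri prov:wasDerivedFrom finding_uris[2]
    # A plain string is still accepted and treated as a one-element list.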
+++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -529,6 +529,8 @@ class Processor(AgentService): await respond(r) + answer_msg_id = agent_final_uri(session_id) + async def answer(x): logger.debug(f"Answer: {x}") @@ -539,6 +541,7 @@ class Processor(AgentService): content=x, end_of_message=False, end_of_dialog=False, + message_id=answer_msg_id, ) else: r = AgentResponse( @@ -546,6 +549,7 @@ class Processor(AgentService): content=x, end_of_message=True, end_of_dialog=False, + message_id=answer_msg_id, ) await respond(r) @@ -571,6 +575,7 @@ class Processor(AgentService): def __init__(self, flow, user): self._flow = flow self._user = user + self.last_sub_explain_uri = None def __call__(self, service_name): client = self._flow(service_name) @@ -635,13 +640,15 @@ class Processor(AgentService): explain_graph=GRAPH_RETRIEVAL, )) + user_context = UserAwareContext(flow, request.user) + act = await temp_agent.react( question = request.question, history = history, think = think, observe = observe, answer = answer, - context = UserAwareContext(flow, request.user), + context = user_context, streaming = streaming, on_action = on_action, ) @@ -717,6 +724,7 @@ class Processor(AgentService): content="", end_of_message=True, end_of_dialog=True, + message_id=answer_msg_id, ) else: r = AgentResponse( @@ -724,6 +732,7 @@ class Processor(AgentService): content=f, end_of_message=True, end_of_dialog=True, + message_id=answer_msg_id, ) await respond(r) @@ -737,6 +746,12 @@ class Processor(AgentService): # Emit standalone observation provenance (iteration was emitted in on_action) iteration_uri = agent_iteration_uri(session_id, iteration_num) observation_entity_uri = agent_observation_uri(session_id, iteration_num) + + # Derive from last sub-trace entity if available, else iteration + obs_parent_uri = iteration_uri + if user_context.last_sub_explain_uri: + obs_parent_uri = user_context.last_sub_explain_uri + observation_doc_id = None if act.observation: observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" @@ -755,7 +770,7 @@ class Processor(AgentService): obs_triples = set_graph( agent_observation_triples( observation_entity_uri, - iteration_uri, + obs_parent_uri, document_id=observation_doc_id, ), GRAPH_RETRIEVAL diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 86b515e1..041558ec 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -40,6 +40,7 @@ class KnowledgeQueryImpl: from ... 
schema import AgentResponse async def explain_callback(explain_id, explain_graph): + self.context.last_sub_explain_uri = explain_id await respond(AgentResponse( chunk_type="explain", content="", From 3ba6a3238fb1d4912596f20c8538c58631f69270 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Apr 2026 13:52:28 +0100 Subject: [PATCH 23/37] Misc test harnesses (#749) Some misc test harnesses for a few features --- dev-tools/tests/agent_dag/analyse_trace.py | 319 ++++++++++++++++++ dev-tools/tests/agent_dag/ws_capture.py | 81 +++++ .../tests/librarian/simple_text_download.py | 67 ++++ .../tests/librarian/simple_text_upload.py | 56 +++ dev-tools/tests/relay/test_rev_gateway.py | 237 +++++++++++++ dev-tools/tests/relay/websocket_relay.py | 210 ++++++++++++ 6 files changed, 970 insertions(+) create mode 100644 dev-tools/tests/agent_dag/analyse_trace.py create mode 100644 dev-tools/tests/agent_dag/ws_capture.py create mode 100644 dev-tools/tests/librarian/simple_text_download.py create mode 100644 dev-tools/tests/librarian/simple_text_upload.py create mode 100644 dev-tools/tests/relay/test_rev_gateway.py create mode 100644 dev-tools/tests/relay/websocket_relay.py diff --git a/dev-tools/tests/agent_dag/analyse_trace.py b/dev-tools/tests/agent_dag/analyse_trace.py new file mode 100644 index 00000000..b71cdebe --- /dev/null +++ b/dev-tools/tests/agent_dag/analyse_trace.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +""" +Analyse a captured agent trace JSON file and check DAG integrity. + +Usage: + python analyse_trace.py react.json + python analyse_trace.py -u http://localhost:8088/ react.json +""" + +import argparse +import asyncio +import json +import os +import sys +import websockets + +DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +DEFAULT_USER = "trustgraph" +DEFAULT_COLLECTION = "default" +DEFAULT_FLOW = "default" +GRAPH = "urn:graph:retrieval" + +# Namespace prefixes +PROV = "http://www.w3.org/ns/prov#" +RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" +RDFS = "http://www.w3.org/2000/01/rdf-schema#" +TG = "https://trustgraph.ai/ns/" + +PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" +RDF_TYPE = RDF + "type" + +TG_ANALYSIS = TG + "Analysis" +TG_TOOL_USE = TG + "ToolUse" +TG_OBSERVATION_TYPE = TG + "Observation" +TG_CONCLUSION = TG + "Conclusion" +TG_SYNTHESIS = TG + "Synthesis" +TG_QUESTION = TG + "Question" + + +def shorten(uri): + """Shorten a URI for display.""" + for prefix, short in [ + (PROV, "prov:"), (RDF, "rdf:"), (RDFS, "rdfs:"), (TG, "tg:"), + ]: + if isinstance(uri, str) and uri.startswith(prefix): + return short + uri[len(prefix):] + return str(uri) + + +async def fetch_triples(ws, flow, subject, user, collection, request_counter): + """Query triples for a given subject URI.""" + request_counter[0] += 1 + req_id = f"q-{request_counter[0]}" + + msg = { + "id": req_id, + "service": "triples", + "flow": flow, + "request": { + "s": {"t": "i", "i": subject}, + "g": GRAPH, + "user": user, + "collection": collection, + "limit": 100, + }, + } + + await ws.send(json.dumps(msg)) + + while True: + raw = await ws.recv() + resp = json.loads(raw) + if resp.get("id") == req_id: + inner = resp.get("response", {}) + if isinstance(inner, dict): + return inner.get("response", []) + return inner + + +def extract_term(term): + """Extract value from wire-format term.""" + if not term: + return "" + t = term.get("t", "") + if t == "i": + return term.get("i", "") + elif t == "l": + return term.get("v", "") + elif t == "t": + tr = term.get("tr", {}) + return { + "s": 
extract_term(tr.get("s", {})), + "p": extract_term(tr.get("p", {})), + "o": extract_term(tr.get("o", {})), + } + return str(term) + + +def parse_triples(wire_triples): + """Convert wire triples to (s, p, o) tuples.""" + result = [] + for t in wire_triples: + s = extract_term(t.get("s", {})) + p = extract_term(t.get("p", {})) + o = extract_term(t.get("o", {})) + result.append((s, p, o)) + return result + + +def get_types(tuples): + """Get rdf:type values from parsed triples.""" + return {o for s, p, o in tuples if p == RDF_TYPE} + + +def get_derived_from(tuples): + """Get prov:wasDerivedFrom targets from parsed triples.""" + return [o for s, p, o in tuples if p == PROV_WAS_DERIVED_FROM] + + +async def analyse(path, url, flow, user, collection): + with open(path) as f: + messages = json.load(f) + + print(f"Total messages: {len(messages)}") + print() + + # ---- Pass 1: collect explain IDs and check streaming chunks ---- + + explain_ids = [] + errors = [] + + for i, msg in enumerate(messages): + resp = msg.get("response", {}) + chunk_type = resp.get("chunk_type", "?") + + if chunk_type == "explain": + explain_id = resp.get("explain_id", "") + explain_ids.append(explain_id) + print(f" {i:3d} {chunk_type} {explain_id}") + else: + print(f" {i:3d} {chunk_type}") + + # Rule 7: message_id on content chunks + if chunk_type in ("thought", "observation", "answer"): + mid = resp.get("message_id", "") + if not mid: + errors.append( + f"[msg {i}] {chunk_type} chunk missing message_id" + ) + + print() + print(f"Explain IDs ({len(explain_ids)}):") + for eid in explain_ids: + print(f" {eid}") + + # ---- Pass 2: fetch triples for each explain ID ---- + + ws_url = url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" + + request_counter = [0] + # entity_id -> parsed triples [(s, p, o), ...] + entities = {} + + print() + print("Fetching triples...") + print() + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as ws: + for eid in explain_ids: + wire = await fetch_triples( + ws, flow, eid, user, collection, request_counter, + ) + + tuples = parse_triples(wire) if isinstance(wire, list) else [] + entities[eid] = tuples + + print(f" {eid}") + for s, p, o in tuples: + o_short = str(o) + if len(o_short) > 80: + o_short = o_short[:77] + "..." 
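(Editorial note, not part of the patch.) For reference, the wire format handled by extract_term and parse_triples above looks roughly like this — a hypothetical triple as returned by the triples query, with a shortened illustrative subject URI:

    # Hypothetical wire triple and its parsed form.
    wire = [{
        "s": {"t": "i", "i": "urn:trustgraph:agent:abc123/i1"},
        "p": {"t": "i", "i": RDF_TYPE},
        "o": {"t": "i", "i": TG_ANALYSIS},
    }]
    parse_triples(wire)
    # -> [('urn:trustgraph:agent:abc123/i1',
    #      'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
    #      'https://trustgraph.ai/ns/Analysis')]
    # Literal objects arrive as {"t": "l", "v": "..."} instead.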
+ print(f" {shorten(p)} = {o_short}") + print() + + # ---- Pass 3: check rules ---- + + all_ids = set(entities.keys()) + + # Collect entity metadata + roots = [] # entities with no wasDerivedFrom + conclusions = [] # tg:Conclusion entities + analyses = [] # tg:Analysis entities + observations = [] # tg:Observation entities + + for eid, tuples in entities.items(): + types = get_types(tuples) + parents = get_derived_from(tuples) + + if not tuples: + errors.append(f"[{eid}] entity has no triples in store") + + if not parents: + roots.append(eid) + + if TG_CONCLUSION in types: + conclusions.append(eid) + if TG_ANALYSIS in types: + analyses.append(eid) + if TG_OBSERVATION_TYPE in types: + observations.append(eid) + + # Rule 4: every non-root entity has wasDerivedFrom + if parents: + for parent in parents: + # Rule 5: parent exists in known entities + if parent not in all_ids: + errors.append( + f"[{eid}] wasDerivedFrom target not in explain set: " + f"{parent}" + ) + + # Rule 6: Analysis entities must have ToolUse type + if TG_ANALYSIS in types and TG_TOOL_USE not in types: + errors.append( + f"[{eid}] Analysis entity missing tg:ToolUse type" + ) + + # Rule 1: exactly one root + if len(roots) == 0: + errors.append("No root entity found (all have wasDerivedFrom)") + elif len(roots) > 1: + errors.append( + f"Multiple roots ({len(roots)}) — expected exactly 1:" + ) + for r in roots: + types = get_types(entities[r]) + type_labels = ", ".join(shorten(t) for t in types) + errors.append(f" root: {r} [{type_labels}]") + + # Rule 2: exactly one terminal node (nothing derives from it) + # Build set of entities that are parents of something + has_children = set() + for eid, tuples in entities.items(): + for parent in get_derived_from(tuples): + has_children.add(parent) + + terminals = [eid for eid in all_ids if eid not in has_children] + if len(terminals) == 0: + errors.append("No terminal entity found (cycle?)") + elif len(terminals) > 1: + errors.append( + f"Multiple terminal entities ({len(terminals)}) — expected exactly 1:" + ) + for t in terminals: + types = get_types(entities[t]) + type_labels = ", ".join(shorten(ty) for ty in types) + errors.append(f" terminal: {t} [{type_labels}]") + + # Rule 8: Observation should not derive from Analysis if a sub-trace + # exists as a sibling. Check: if an Analysis has both a Question child + # and an Observation child, the Observation should derive from the + # sub-trace's Synthesis, not from the Analysis. + for obs_id in observations: + obs_parents = get_derived_from(entities[obs_id]) + for parent in obs_parents: + if parent in entities: + parent_types = get_types(entities[parent]) + if TG_ANALYSIS in parent_types: + # Check if this Analysis also has a Question child + # (i.e. a sub-trace exists) + has_subtrace = False + for other_id, other_tuples in entities.items(): + if other_id == obs_id: + continue + other_parents = get_derived_from(other_tuples) + other_types = get_types(other_tuples) + if (parent in other_parents + and TG_QUESTION in other_types): + has_subtrace = True + break + if has_subtrace: + errors.append( + f"[{obs_id}] Observation derives from Analysis " + f"{parent} which has a sub-trace — should derive " + f"from the sub-trace's Synthesis instead" + ) + + # ---- Report ---- + + print() + print("=" * 60) + if errors: + print(f"ERRORS ({len(errors)}):") + print() + for err in errors: + print(f" !! 
{err}") + else: + print("ALL CHECKS PASSED") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("input", help="JSON trace file") + parser.add_argument("-u", "--url", default=DEFAULT_URL) + parser.add_argument("-f", "--flow", default=DEFAULT_FLOW) + parser.add_argument("-U", "--user", default=DEFAULT_USER) + parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION) + args = parser.parse_args() + + asyncio.run(analyse( + args.input, args.url, args.flow, + args.user, args.collection, + )) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/tests/agent_dag/ws_capture.py b/dev-tools/tests/agent_dag/ws_capture.py new file mode 100644 index 00000000..3002d563 --- /dev/null +++ b/dev-tools/tests/agent_dag/ws_capture.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" +Connect to TrustGraph websocket, run an agent query, capture all +response messages to a JSON file. + +Usage: + python ws_capture.py -q "What is the document about?" -o trace.json + python ws_capture.py -q "..." -u http://localhost:8088/ -o out.json +""" + +import argparse +import asyncio +import json +import os +import websockets + +DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +DEFAULT_USER = "trustgraph" +DEFAULT_COLLECTION = "default" +DEFAULT_FLOW = "default" + + +async def capture(url, flow, question, user, collection, output): + + # Convert to ws URL + ws_url = url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=120) as ws: + + request = { + "id": "capture", + "service": "agent", + "flow": flow, + "request": { + "question": question, + "user": user, + "collection": collection, + "streaming": True, + }, + } + + await ws.send(json.dumps(request)) + + messages = [] + + async for raw in ws: + msg = json.loads(raw) + + if msg.get("id") != "capture": + continue + + messages.append(msg) + + if msg.get("complete"): + break + + with open(output, "w") as f: + json.dump(messages, f, indent=2) + + print(f"Captured {len(messages)} messages to {output}") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("-q", "--question", required=True) + parser.add_argument("-o", "--output", default="trace.json") + parser.add_argument("-u", "--url", default=DEFAULT_URL) + parser.add_argument("-f", "--flow", default=DEFAULT_FLOW) + parser.add_argument("-U", "--user", default=DEFAULT_USER) + parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION) + args = parser.parse_args() + + asyncio.run(capture( + args.url, args.flow, args.question, + args.user, args.collection, args.output, + )) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/tests/librarian/simple_text_download.py b/dev-tools/tests/librarian/simple_text_download.py new file mode 100644 index 00000000..6af2a60d --- /dev/null +++ b/dev-tools/tests/librarian/simple_text_download.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Minimal example: download a text document in tiny chunks via websocket API +""" + +import asyncio +import json +import base64 +import websockets + +async def main(): + url = "ws://localhost:8088/api/v1/socket" + + document_id = "test-chunked-doc-001" + chunk_size = 10 # Tiny chunks! 
+ + request_id = 0 + + async def send_request(ws, request): + nonlocal request_id + request_id += 1 + msg = { + "id": f"req-{request_id}", + "service": "librarian", + "request": request + } + await ws.send(json.dumps(msg)) + response = json.loads(await ws.recv()) + if "error" in response: + raise Exception(response["error"]) + return response.get("response", {}) + + async with websockets.connect(url) as ws: + + print(f"Fetching document: {document_id}") + print(f"Chunk size: {chunk_size} bytes") + print() + + chunk_index = 0 + all_content = b"" + + while True: + resp = await send_request(ws, { + "operation": "stream-document", + "user": "trustgraph", + "document-id": document_id, + "chunk-index": chunk_index, + "chunk-size": chunk_size, + }) + + chunk_data = base64.b64decode(resp["content"]) + total_chunks = resp["total-chunks"] + total_bytes = resp["total-bytes"] + + print(f"Chunk {chunk_index}: {chunk_data}") + + all_content += chunk_data + chunk_index += 1 + + if chunk_index >= total_chunks: + break + + print() + print(f"Complete: {all_content}") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev-tools/tests/librarian/simple_text_upload.py b/dev-tools/tests/librarian/simple_text_upload.py new file mode 100644 index 00000000..e21bd185 --- /dev/null +++ b/dev-tools/tests/librarian/simple_text_upload.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +""" +Minimal example: upload a small text document via websocket API +""" + +import asyncio +import json +import base64 +import time +import websockets + +async def main(): + url = "ws://localhost:8088/api/v1/socket" + + # Small text content + content = b"AAAAAAAAAABBBBBBBBBBCCCCCCCCCC" + + request_id = 0 + + async def send_request(ws, request): + nonlocal request_id + request_id += 1 + msg = { + "id": f"req-{request_id}", + "service": "librarian", + "request": request + } + await ws.send(json.dumps(msg)) + response = json.loads(await ws.recv()) + if "error" in response: + raise Exception(response["error"]) + return response.get("response", {}) + + async with websockets.connect(url) as ws: + + print(f"Uploading {len(content)} bytes...") + + resp = await send_request(ws, { + "operation": "add-document", + "document-metadata": { + "id": "test-chunked-doc-001", + "time": int(time.time()), + "kind": "text/plain", + "title": "My Test Document", + "comments": "Small doc for chunk testing", + "user": "trustgraph", + "tags": ["test"], + "metadata": [], + }, + "content": base64.b64encode(content).decode("utf-8"), + }) + + print("Done!") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev-tools/tests/relay/test_rev_gateway.py b/dev-tools/tests/relay/test_rev_gateway.py new file mode 100644 index 00000000..fe200e46 --- /dev/null +++ b/dev-tools/tests/relay/test_rev_gateway.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +WebSocket Test Client + +A simple client to test the reverse gateway through the relay. +Connects to the relay's /in endpoint and allows sending test messages. 
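(Editorial note, not part of the patch.) The messages this client sends use the standard gateway request envelope, as built by send_message further down; an illustrative example — the id prefix is a random per-client value:

    {
      "id": "a1b2c3d4-1",
      "service": "text-completion",
      "flow": "default",
      "request": {
        "system": "You are a helpful assistant.",
        "prompt": "What is 2+2?"
      }
    }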
+ +Usage: + python test_client.py [--uri URI] [--interactive] +""" + +import asyncio +import json +import logging +import argparse +import uuid +from aiohttp import ClientSession, WSMsgType + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("test_client") + +class TestClient: + """Simple WebSocket test client""" + + def __init__(self, uri: str): + self.uri = uri + self.session = None + self.ws = None + self.running = False + self.message_counter = 0 + self.client_id = str(uuid.uuid4())[:8] + + async def connect(self): + """Connect to the WebSocket""" + self.session = ClientSession() + logger.info(f"Connecting to {self.uri}") + self.ws = await self.session.ws_connect(self.uri) + logger.info("Connected successfully") + + async def disconnect(self): + """Disconnect from WebSocket""" + if self.ws and not self.ws.closed: + await self.ws.close() + if self.session and not self.session.closed: + await self.session.close() + logger.info("Disconnected") + + async def send_message(self, service: str, request_data: dict, flow: str = "default"): + """Send a properly formatted TrustGraph message""" + self.message_counter += 1 + message = { + "id": f"{self.client_id}-{self.message_counter}", + "service": service, + "request": request_data, + "flow": flow + } + + message_json = json.dumps(message, indent=2) + logger.info(f"Sending message:\n{message_json}") + await self.ws.send_str(json.dumps(message)) + + async def listen_for_responses(self): + """Listen for incoming messages""" + logger.info("Listening for responses...") + + async for msg in self.ws: + if msg.type == WSMsgType.TEXT: + try: + response = json.loads(msg.data) + logger.info(f"Received response:\n{json.dumps(response, indent=2)}") + except json.JSONDecodeError: + logger.info(f"Received text: {msg.data}") + elif msg.type == WSMsgType.BINARY: + logger.info(f"Received binary data: {len(msg.data)} bytes") + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error: {self.ws.exception()}") + break + else: + logger.info(f"Connection closed: {msg.type}") + break + + async def interactive_mode(self): + """Interactive mode for manual testing""" + print("\n=== Interactive Test Client ===") + print("Available commands:") + print(" text-completion - Test text completion service") + print(" agent - Test agent service") + print(" embeddings - Test embeddings service") + print(" custom - Send custom message") + print(" quit - Exit") + print() + + # Start response listener + listen_task = asyncio.create_task(self.listen_for_responses()) + + try: + while True: + try: + command = input("Command> ").strip().lower() + + if command == "quit": + break + elif command == "text-completion": + await self.send_message("text-completion", { + "system": "You are a helpful assistant.", + "prompt": "What is 2+2?" + }) + elif command == "agent": + await self.send_message("agent", { + "question": "What is the capital of France?" 
+ }) + elif command == "embeddings": + await self.send_message("embeddings", { + "text": "Hello world" + }) + elif command == "custom": + service = input("Service name> ").strip() + request_json = input("Request JSON> ").strip() + try: + request_data = json.loads(request_json) + await self.send_message(service, request_data) + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + elif command == "": + continue + else: + print(f"Unknown command: {command}") + + except KeyboardInterrupt: + break + except EOFError: + break + except Exception as e: + logger.error(f"Error in interactive mode: {e}") + + finally: + listen_task.cancel() + try: + await listen_task + except asyncio.CancelledError: + pass + + async def run_predefined_tests(self): + """Run a series of predefined tests""" + print("\n=== Running Predefined Tests ===") + + # Start response listener + listen_task = asyncio.create_task(self.listen_for_responses()) + + try: + # Test 1: Text completion + print("\n1. Testing text-completion service...") + await self.send_message("text-completion", { + "system": "You are a helpful assistant.", + "prompt": "What is 2+2?" + }) + await asyncio.sleep(2) + + # Test 2: Agent + print("\n2. Testing agent service...") + await self.send_message("agent", { + "question": "What is the capital of France?" + }) + await asyncio.sleep(2) + + # Test 3: Embeddings + print("\n3. Testing embeddings service...") + await self.send_message("embeddings", { + "text": "Hello world" + }) + await asyncio.sleep(2) + + # Test 4: Invalid service + print("\n4. Testing invalid service...") + await self.send_message("nonexistent-service", { + "test": "data" + }) + await asyncio.sleep(2) + + print("\nTests completed. Waiting for any remaining responses...") + await asyncio.sleep(3) + + finally: + listen_task.cancel() + try: + await listen_task + except asyncio.CancelledError: + pass + +async def main(): + parser = argparse.ArgumentParser( + description="WebSocket Test Client for Reverse Gateway" + ) + parser.add_argument( + '--uri', + default='ws://localhost:8080/in', + help='WebSocket URI to connect to (default: ws://localhost:8080/in)' + ) + parser.add_argument( + '--interactive', '-i', + action='store_true', + help='Run in interactive mode' + ) + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Enable verbose logging' + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + client = TestClient(args.uri) + + try: + await client.connect() + + if args.interactive: + await client.interactive_mode() + else: + await client.run_predefined_tests() + + except KeyboardInterrupt: + print("\nShutdown requested by user") + except Exception as e: + logger.error(f"Client error: {e}") + finally: + await client.disconnect() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/dev-tools/tests/relay/websocket_relay.py b/dev-tools/tests/relay/websocket_relay.py new file mode 100644 index 00000000..d537f7da --- /dev/null +++ b/dev-tools/tests/relay/websocket_relay.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" +WebSocket Relay Test Harness + +This script creates a relay server with two WebSocket endpoints: +- /in - for test clients to connect to +- /out - for reverse gateway to connect to + +Messages are bidirectionally relayed between the two connections. 
+ +Usage: + python websocket_relay.py [--port PORT] [--host HOST] +""" + +import asyncio +import logging +import argparse +from aiohttp import web, WSMsgType +import weakref +from typing import Optional, Set + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("websocket_relay") + +class WebSocketRelay: + """WebSocket relay that forwards messages between 'in' and 'out' connections""" + + def __init__(self): + self.in_connections: Set = weakref.WeakSet() + self.out_connections: Set = weakref.WeakSet() + + async def handle_in_connection(self, request): + """Handle incoming connections on /in endpoint""" + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.in_connections.add(ws) + logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + data = msg.data + logger.info(f"IN → OUT: {data}") + await self._forward_to_out(data) + elif msg.type == WSMsgType.BINARY: + data = msg.data + logger.info(f"IN → OUT: {len(data)} bytes (binary)") + await self._forward_to_out(data, binary=True) + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") + break + else: + break + except Exception as e: + logger.error(f"Error in 'in' connection handler: {e}") + finally: + logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + return ws + + async def handle_out_connection(self, request): + """Handle outgoing connections on /out endpoint""" + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.out_connections.add(ws) + logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + data = msg.data + logger.info(f"OUT → IN: {data}") + await self._forward_to_in(data) + elif msg.type == WSMsgType.BINARY: + data = msg.data + logger.info(f"OUT → IN: {len(data)} bytes (binary)") + await self._forward_to_in(data, binary=True) + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error on 'out' connection: {ws.exception()}") + break + else: + break + except Exception as e: + logger.error(f"Error in 'out' connection handler: {e}") + finally: + logger.info(f"'out' connection closed. 
Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") + + return ws + + async def _forward_to_out(self, data, binary=False): + """Forward message from 'in' to all 'out' connections""" + if not self.out_connections: + logger.warning("No 'out' connections available to forward message") + return + + closed_connections = [] + for ws in list(self.out_connections): + try: + if ws.closed: + closed_connections.append(ws) + continue + + if binary: + await ws.send_bytes(data) + else: + await ws.send_str(data) + except Exception as e: + logger.error(f"Error forwarding to 'out' connection: {e}") + closed_connections.append(ws) + + # Clean up closed connections + for ws in closed_connections: + if ws in self.out_connections: + self.out_connections.discard(ws) + + async def _forward_to_in(self, data, binary=False): + """Forward message from 'out' to all 'in' connections""" + if not self.in_connections: + logger.warning("No 'in' connections available to forward message") + return + + closed_connections = [] + for ws in list(self.in_connections): + try: + if ws.closed: + closed_connections.append(ws) + continue + + if binary: + await ws.send_bytes(data) + else: + await ws.send_str(data) + except Exception as e: + logger.error(f"Error forwarding to 'in' connection: {e}") + closed_connections.append(ws) + + # Clean up closed connections + for ws in closed_connections: + if ws in self.in_connections: + self.in_connections.discard(ws) + +async def create_app(relay): + """Create the web application with routes""" + app = web.Application() + + # Add routes + app.router.add_get('/in', relay.handle_in_connection) + app.router.add_get('/out', relay.handle_out_connection) + + # Add a simple status endpoint + async def status(request): + status_info = { + 'in_connections': len(relay.in_connections), + 'out_connections': len(relay.out_connections), + 'status': 'running' + } + return web.json_response(status_info) + + app.router.add_get('/status', status) + app.router.add_get('/', status) # Root also shows status + + return app + +def main(): + parser = argparse.ArgumentParser( + description="WebSocket Relay Test Harness" + ) + parser.add_argument( + '--host', + default='localhost', + help='Host to bind to (default: localhost)' + ) + parser.add_argument( + '--port', + type=int, + default=8080, + help='Port to bind to (default: 8080)' + ) + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Enable verbose logging' + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + relay = WebSocketRelay() + + print(f"Starting WebSocket Relay on {args.host}:{args.port}") + print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") + print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") + print(f" Status: http://{args.host}:{args.port}/status") + print() + print("Usage:") + print(f" Test client connects to: ws://{args.host}:{args.port}/in") + print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") + + web.run_app(create_app(relay), host=args.host, port=args.port) + +if __name__ == "__main__": + main() \ No newline at end of file From dbf8daa74af86c26c80eac74d6cd4c490e856ed0 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Apr 2026 13:59:34 +0100 Subject: [PATCH 24/37] Additional agent DAG tests (#750) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_agent_provenance.py: test_session_parent_uri, test_session_no_parent_uri, and 6 synthesis tests (types, 
single/multiple parents, document, label) - test_on_action_callback.py: 3 tests — fires before tool, skipped for Final, works when None - test_callback_message_id.py: 7 tests — message_id on think/observe/ answer callbacks (streaming + non-streaming) and send_final_response - test_parse_chunk_message_id.py (5 tests) - _parse_chunk propagates message_id for thought, observation, answer; handles missing gracefully - test_explainability_parsing.py (+1) - test_dispatches_analysis_with_tooluse - Analysis+ToolUse mixin still dispatches to Analysis - test_explainability.py (+1) - test_observation_found_via_subtrace_synthesis - chain walker follows from sub-trace Synthesis to find Observation and Conclusion in correct order - test_agent_provenance.py (+8) - session parent_uri (2), synthesis single/multiple parents, types, document, label (6) --- dev-tools/tests/triples/load_test_triples.py | 227 ++++++++++++++++++ .../test_agent/test_callback_message_id.py | 122 ++++++++++ .../test_agent/test_explainability_parsing.py | 8 + .../test_agent/test_on_action_callback.py | 132 ++++++++++ .../test_agent/test_parse_chunk_message_id.py | 74 ++++++ .../test_provenance/test_agent_provenance.py | 78 +++++- .../test_provenance/test_explainability.py | 93 +++++++ 7 files changed, 733 insertions(+), 1 deletion(-) create mode 100755 dev-tools/tests/triples/load_test_triples.py create mode 100644 tests/unit/test_agent/test_callback_message_id.py create mode 100644 tests/unit/test_agent/test_on_action_callback.py create mode 100644 tests/unit/test_agent/test_parse_chunk_message_id.py diff --git a/dev-tools/tests/triples/load_test_triples.py b/dev-tools/tests/triples/load_test_triples.py new file mode 100755 index 00000000..a147d041 --- /dev/null +++ b/dev-tools/tests/triples/load_test_triples.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Load test triples into the triple store for testing tg-query-graph. + +Tests all graph features: +- SPO with IRI objects +- SPO with literal objects +- Literals with XML datatypes +- Literals with language tags +- Quoted triples (RDF-star) +- Named graphs +""" + +import asyncio +import json +import os +import websockets + +# Configuration +API_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None) +FLOW = "default" +USER = "trustgraph" +COLLECTION = "default" +DOCUMENT_ID = "test-triples-001" + +# Namespaces +EX = "http://example.org/" +RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" +RDFS = "http://www.w3.org/2000/01/rdf-schema#" +XSD = "http://www.w3.org/2001/XMLSchema#" +TG = "https://trustgraph.ai/ns/" + + +def iri(value): + """Build IRI term.""" + return {"t": "i", "i": value} + + +def literal(value, datatype=None, language=None): + """Build literal term with optional datatype or language.""" + term = {"t": "l", "v": value} + if datatype: + term["dt"] = datatype + if language: + term["ln"] = language + return term + + +def quoted_triple(s, p, o): + """Build quoted triple term (RDF-star).""" + return { + "t": "t", + "tr": {"s": s, "p": p, "o": o} + } + + +def triple(s, p, o, g=None): + """Build a complete triple dict.""" + t = {"s": s, "p": p, "o": o} + if g: + t["g"] = g + return t + + +# Test triples covering all features +TEST_TRIPLES = [ + # 1. Basic SPO with IRI object + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDF}type"), + iri(f"{EX}Scientist") + ), + + # 2. SPO with IRI object (relationship) + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}discovered"), + iri(f"{EX}radium") + ), + + # 3. 
Simple literal (no datatype/language) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie") + ), + + # 4. Literal with language tag (English) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie", language="en") + ), + + # 5. Literal with language tag (French) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie", language="fr") + ), + + # 6. Literal with language tag (Polish) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Maria Sk\u0142odowska-Curie", language="pl") + ), + + # 7. Literal with xsd:integer datatype + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}birthYear"), + literal("1867", datatype=f"{XSD}integer") + ), + + # 8. Literal with xsd:date datatype + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}birthDate"), + literal("1867-11-07", datatype=f"{XSD}date") + ), + + # 9. Literal with xsd:boolean datatype + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}nobelLaureate"), + literal("true", datatype=f"{XSD}boolean") + ), + + # 10. Quoted triple in object position (RDF 1.2 style) + # "Wikipedia asserts that Marie Curie discovered radium" + triple( + iri(f"{EX}wikipedia"), + iri(f"{TG}asserts"), + quoted_triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}discovered"), + iri(f"{EX}radium") + ) + ), + + # 11. Quoted triple with literal inside (object position) + # "NLP-v1.0 extracted that Marie Curie has label Marie Curie" + triple( + iri(f"{EX}nlp-v1"), + iri(f"{TG}extracted"), + quoted_triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie") + ) + ), + + # 12. Triple in a named graph (g is plain string, not Term) + triple( + iri(f"{EX}radium"), + iri(f"{RDF}type"), + iri(f"{EX}Element"), + g=f"{EX}chemistry-graph" + ), + + # 13. Another triple in the same named graph + triple( + iri(f"{EX}radium"), + iri(f"{EX}atomicNumber"), + literal("88", datatype=f"{XSD}integer"), + g=f"{EX}chemistry-graph" + ), + + # 14. Triple in a different named graph + triple( + iri(f"{EX}pierre-curie"), + iri(f"{EX}spouseOf"), + iri(f"{EX}marie-curie"), + g=f"{EX}biography-graph" + ), +] + + +async def load_triples(): + """Load test triples via WebSocket bulk import.""" + # Convert HTTP URL to WebSocket URL + ws_url = API_URL.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url.rstrip('/')}/api/v1/flow/{FLOW}/import/triples" + if TOKEN: + ws_url = f"{ws_url}?token={TOKEN}" + + metadata = { + "id": DOCUMENT_ID, + "metadata": [], + "user": USER, + "collection": COLLECTION + } + + print(f"Connecting to {ws_url}...") + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as websocket: + message = { + "metadata": metadata, + "triples": TEST_TRIPLES + } + print(f"Sending {len(TEST_TRIPLES)} test triples...") + await websocket.send(json.dumps(message)) + print("Triples sent successfully!") + + print("\nTest triples loaded:") + print(" - 2 basic IRI triples (type, relationship)") + print(" - 4 literal triples (plain + 3 languages: en, fr, pl)") + print(" - 3 typed literal triples (xsd:integer, xsd:date, xsd:boolean)") + print(" - 2 quoted triples (RDF-star provenance)") + print(" - 3 triples in named graphs (chemistry-graph, biography-graph)") + print(f"\nTotal: {len(TEST_TRIPLES)} triples") + print(f"User: {USER}, Collection: {COLLECTION}") + + +def main(): + print("Loading test triples for tg-query-graph testing\n") + asyncio.run(load_triples()) + print("\nDone! 
Now test with:") + print(" tg-query-graph -s http://example.org/marie-curie") + print(" tg-query-graph -p http://www.w3.org/2000/01/rdf-schema#label") + print(" tg-query-graph -o 'Marie Curie' --object-language en") + print(" tg-query-graph --format json | jq .") + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_agent/test_callback_message_id.py b/tests/unit/test_agent/test_callback_message_id.py new file mode 100644 index 00000000..7cb0ee54 --- /dev/null +++ b/tests/unit/test_agent/test_callback_message_id.py @@ -0,0 +1,122 @@ +""" +Tests that streaming callbacks set message_id on AgentResponse. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.agent.orchestrator.pattern_base import PatternBase +from trustgraph.schema import AgentResponse + + +@pytest.fixture +def pattern(): + processor = MagicMock() + return PatternBase(processor) + + +class TestThinkCallbackMessageId: + + @pytest.mark.asyncio + async def test_streaming_think_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/i1/thought" + think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id) + await think("hello", is_final=False) + + assert len(responses) == 1 + assert responses[0].message_id == msg_id + assert responses[0].chunk_type == "thought" + + @pytest.mark.asyncio + async def test_non_streaming_think_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/i1/thought" + think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id) + await think("hello") + + assert responses[0].message_id == msg_id + assert responses[0].end_of_message is True + + +class TestObserveCallbackMessageId: + + @pytest.mark.asyncio + async def test_streaming_observe_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/i1/observation" + observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id) + await observe("result", is_final=True) + + assert responses[0].message_id == msg_id + assert responses[0].chunk_type == "observation" + + +class TestAnswerCallbackMessageId: + + @pytest.mark.asyncio + async def test_streaming_answer_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/final" + answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id) + await answer("the answer") + + assert responses[0].message_id == msg_id + assert responses[0].chunk_type == "answer" + + @pytest.mark.asyncio + async def test_no_message_id_default(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + answer = pattern.make_answer_callback(capture, streaming=True) + await answer("the answer") + + assert responses[0].message_id == "" + + +class TestSendFinalResponseMessageId: + + @pytest.mark.asyncio + async def test_streaming_final_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/final" + await pattern.send_final_response( + capture, streaming=True, answer_text="answer", + message_id=msg_id, + ) + + # Should get content chunk + end-of-dialog marker + assert all(r.message_id == msg_id for r in responses) + + @pytest.mark.asyncio + async def test_non_streaming_final_has_message_id(self, pattern): + 
responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/final" + await pattern.send_final_response( + capture, streaming=False, answer_text="answer", + message_id=msg_id, + ) + + assert len(responses) == 1 + assert responses[0].message_id == msg_id + assert responses[0].end_of_dialog is True diff --git a/tests/unit/test_agent/test_explainability_parsing.py b/tests/unit/test_agent/test_explainability_parsing.py index 7035318d..d75ea604 100644 --- a/tests/unit/test_agent/test_explainability_parsing.py +++ b/tests/unit/test_agent/test_explainability_parsing.py @@ -22,6 +22,7 @@ from trustgraph.api.explainability import ( TG_SYNTHESIS, TG_ANSWER_TYPE, TG_OBSERVATION_TYPE, + TG_TOOL_USE, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, @@ -76,6 +77,13 @@ class TestFromTriplesDispatch: entity = ExplainEntity.from_triples("urn:a", triples) assert isinstance(entity, Analysis) + def test_dispatches_analysis_with_tooluse(self): + """Analysis+ToolUse mixin still dispatches to Analysis.""" + triples = _make_triples("urn:a", + [PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE]) + entity = ExplainEntity.from_triples("urn:a", triples) + assert isinstance(entity, Analysis) + def test_dispatches_observation(self): triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE]) entity = ExplainEntity.from_triples("urn:o", triples) diff --git a/tests/unit/test_agent/test_on_action_callback.py b/tests/unit/test_agent/test_on_action_callback.py new file mode 100644 index 00000000..4a1c0c3b --- /dev/null +++ b/tests/unit/test_agent/test_on_action_callback.py @@ -0,0 +1,132 @@ +""" +Tests for the on_action callback in react() — verifies that it fires +after action selection but before tool execution. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.agent.react.agent_manager import AgentManager +from trustgraph.agent.react.types import Action, Final, Tool, Argument + + +class TestOnActionCallback: + + @pytest.mark.asyncio + async def test_on_action_called_for_tool_use(self): + """on_action fires when react() selects a tool (not Final).""" + call_log = [] + + async def fake_on_action(act): + call_log.append(("on_action", act.name)) + + # Tool that records when it's invoked + async def tool_invoke(**kwargs): + call_log.append(("tool_invoke",)) + return "tool result" + + tool_impl = MagicMock() + tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke) + + tools = { + "search": Tool( + name="search", + description="Search", + implementation=tool_impl, + arguments=[Argument(name="query", type="string", description="q")], + config={}, + ), + } + + agent = AgentManager(tools=tools) + + # Mock reason() to return an Action + action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="") + agent.reason = AsyncMock(return_value=action) + + think = AsyncMock() + observe = AsyncMock() + context = MagicMock() + + await agent.react( + question="test", + history=[], + think=think, + observe=observe, + context=context, + on_action=fake_on_action, + ) + + # on_action should fire before tool_invoke + assert len(call_log) == 2 + assert call_log[0] == ("on_action", "search") + assert call_log[1] == ("tool_invoke",) + + @pytest.mark.asyncio + async def test_on_action_not_called_for_final(self): + """on_action does not fire when react() returns Final.""" + called = [] + + async def fake_on_action(act): + called.append(act) + + agent = AgentManager(tools={}) + agent.reason = AsyncMock( + return_value=Final(thought="done", 
final="answer") + ) + + think = AsyncMock() + observe = AsyncMock() + context = MagicMock() + + result = await agent.react( + question="test", + history=[], + think=think, + observe=observe, + context=context, + on_action=fake_on_action, + ) + + assert isinstance(result, Final) + assert len(called) == 0 + + @pytest.mark.asyncio + async def test_on_action_none_accepted(self): + """react() works fine when on_action is None (default).""" + async def tool_invoke(**kwargs): + return "result" + + tool_impl = MagicMock() + tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke) + + tools = { + "search": Tool( + name="search", + description="Search", + implementation=tool_impl, + arguments=[], + config={}, + ), + } + + agent = AgentManager(tools=tools) + agent.reason = AsyncMock( + return_value=Action(thought="t", name="search", arguments={}, observation="") + ) + + think = AsyncMock() + observe = AsyncMock() + context = MagicMock() + + result = await agent.react( + question="test", + history=[], + think=think, + observe=observe, + context=context, + # on_action not passed — defaults to None + ) + + assert isinstance(result, Action) + assert result.observation == "result" diff --git a/tests/unit/test_agent/test_parse_chunk_message_id.py b/tests/unit/test_agent/test_parse_chunk_message_id.py new file mode 100644 index 00000000..38942f1e --- /dev/null +++ b/tests/unit/test_agent/test_parse_chunk_message_id.py @@ -0,0 +1,74 @@ +""" +Tests that _parse_chunk propagates message_id from wire format +to AgentThought, AgentObservation, and AgentAnswer. +""" + +import pytest + +from trustgraph.api.socket_client import SocketClient +from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer + + +@pytest.fixture +def client(): + # We only need _parse_chunk — don't connect + c = object.__new__(SocketClient) + return c + + +class TestParseChunkMessageId: + + def test_thought_message_id(self, client): + resp = { + "chunk_type": "thought", + "content": "thinking...", + "end_of_message": False, + "message_id": "urn:trustgraph:agent:sess/i1/thought", + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentThought) + assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought" + + def test_observation_message_id(self, client): + resp = { + "chunk_type": "observation", + "content": "result", + "end_of_message": True, + "message_id": "urn:trustgraph:agent:sess/i1/observation", + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentObservation) + assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation" + + def test_answer_message_id(self, client): + resp = { + "chunk_type": "answer", + "content": "the answer", + "end_of_message": False, + "end_of_dialog": False, + "message_id": "urn:trustgraph:agent:sess/final", + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentAnswer) + assert chunk.message_id == "urn:trustgraph:agent:sess/final" + + def test_thought_missing_message_id(self, client): + resp = { + "chunk_type": "thought", + "content": "thinking...", + "end_of_message": False, + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentThought) + assert chunk.message_id == "" + + def test_answer_missing_message_id(self, client): + resp = { + "chunk_type": "answer", + "content": "answer", + "end_of_message": True, + "end_of_dialog": True, + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentAnswer) + assert chunk.message_id == "" diff --git 
a/tests/unit/test_provenance/test_agent_provenance.py b/tests/unit/test_provenance/test_agent_provenance.py index d3f0ef8c..c548ef9d 100644 --- a/tests/unit/test_provenance/test_agent_provenance.py +++ b/tests/unit/test_provenance/test_agent_provenance.py @@ -12,6 +12,7 @@ from trustgraph.provenance.agent import ( agent_iteration_triples, agent_observation_triples, agent_final_triples, + agent_synthesis_triples, ) from trustgraph.provenance.namespaces import ( @@ -21,7 +22,7 @@ from trustgraph.provenance.namespaces import ( TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, - TG_TOOL_USE, + TG_TOOL_USE, TG_SYNTHESIS, TG_AGENT_QUESTION, ) @@ -105,6 +106,25 @@ class TestAgentSessionTriples: ) assert len(triples) == 6 + def test_session_parent_uri(self): + """Subagent sessions derive from a parent entity (e.g. Decomposition).""" + parent = "urn:trustgraph:agent:parent/decompose" + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-01-01T00:00:00Z", + parent_uri=parent, + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI) + assert derived is not None + assert derived.o.iri == parent + + def test_session_no_parent_uri(self): + """Top-level sessions have no wasDerivedFrom.""" + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-01-01T00:00:00Z" + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI) + assert derived is None + # --------------------------------------------------------------------------- # agent_iteration_triples @@ -358,3 +378,59 @@ class TestAgentFinalTriples: ) doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) assert doc is None + + +# --------------------------------------------------------------------------- +# agent_synthesis_triples +# --------------------------------------------------------------------------- + +class TestAgentSynthesisTriples: + + SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis" + FINDING_0 = "urn:trustgraph:agent:test-session/finding/0" + FINDING_1 = "urn:trustgraph:agent:test-session/finding/1" + FINDING_2 = "urn:trustgraph:agent:test-session/finding/2" + + def test_synthesis_types(self): + triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0) + assert has_type(triples, self.SYNTH_URI, PROV_ENTITY) + assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS) + assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE) + + def test_synthesis_single_parent_string(self): + """Single parent passed as string.""" + triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0) + derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI) + assert len(derived) == 1 + assert derived[0].o.iri == self.FINDING_0 + + def test_synthesis_multiple_parents(self): + """Multiple parents for supervisor fan-in.""" + parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2] + triples = agent_synthesis_triples(self.SYNTH_URI, parents) + derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI) + assert len(derived) == 3 + derived_uris = {t.o.iri for t in derived} + assert derived_uris == set(parents) + + def test_synthesis_single_parent_as_list(self): + """Single parent passed as list.""" + triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0]) + derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI) + assert len(derived) == 1 + assert derived[0].o.iri == self.FINDING_0 + + def 
test_synthesis_document(self): + triples = agent_synthesis_triples( + self.SYNTH_URI, self.FINDING_0, + document_id="urn:doc:synth", + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI) + assert doc is not None + assert doc.o.iri == "urn:doc:synth" + + def test_synthesis_label(self): + triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0) + label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI) + assert label is not None + assert label.o.value == "Synthesis" diff --git a/tests/unit/test_provenance/test_explainability.py b/tests/unit/test_provenance/test_explainability.py index e2c7fcd1..a6d655a7 100644 --- a/tests/unit/test_provenance/test_explainability.py +++ b/tests/unit/test_provenance/test_explainability.py @@ -558,3 +558,96 @@ class TestExplainabilityClientDetectSessionType: mock_flow = MagicMock() client = ExplainabilityClient(mock_flow, retry_delay=0.0) assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag" + + +class TestChainWalkerFollowsSubTraceTerminal: + """Test that _follow_provenance_chain continues from a sub-trace's + Synthesis to find downstream entities like Observation.""" + + def test_observation_found_via_subtrace_synthesis(self): + """ + DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation + The walker should find Analysis, the sub-trace, then follow from + Synthesis to discover Observation. + """ + # Entity triples (s, p, o) + entity_data = { + "urn:agent:q": [ + ("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION), + ("urn:agent:q", TG_QUERY, "test"), + ], + "urn:agent:analysis": [ + ("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS), + ("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"), + ], + "urn:graphrag:q": [ + ("urn:graphrag:q", RDF_TYPE, TG_QUESTION), + ("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:graphrag:q", TG_QUERY, "test"), + ("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"), + ], + "urn:graphrag:synth": [ + ("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS), + ("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"), + ], + "urn:agent:obs": [ + ("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE), + ("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"), + ], + "urn:agent:conclusion": [ + ("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION), + ("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"), + ], + } + + # Build a mock flow that answers triples queries + # Query by s= returns that entity's triples + # Query by p=wasDerivedFrom, o=X returns entities derived from X + def mock_triples_query(s=None, p=None, o=None, **kwargs): + if s and not p: + # Fetch entity triples + tuples = entity_data.get(s, []) + return _make_wire_triples(tuples) + elif p == PROV_WAS_DERIVED_FROM and o: + # Find entities derived from o + results = [] + for uri, tuples in entity_data.items(): + for _, pred, obj in tuples: + if pred == PROV_WAS_DERIVED_FROM and obj == o: + results.append((uri, pred, obj)) + return _make_wire_triples(results) + return [] + + mock_flow = MagicMock() + mock_flow.triples_query.side_effect = mock_triples_query + + client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2) + + # Mock fetch_graphrag_trace to return a trace with a synthesis + synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis") + client.fetch_graphrag_trace = MagicMock(return_value={ + "question": Question(uri="urn:graphrag:q", entity_type="question", + question_type="graph-rag"), + "synthesis": synth_entity, + }) + + trace = 
client.fetch_agent_trace( + "urn:agent:q", + graph="urn:graph:retrieval", + ) + + # Should have found all steps + step_types = [ + type(s).__name__ if not isinstance(s, dict) else s.get("type") + for s in trace["steps"] + ] + + assert "Analysis" in step_types, f"Missing Analysis in {step_types}" + assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}" + assert "Observation" in step_types, f"Missing Observation in {step_types}" + assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}" + + # Observation should come after the sub-trace + subtrace_idx = step_types.index("sub-trace") + obs_idx = step_types.index("Observation") + assert obs_idx > subtrace_idx, "Observation should appear after sub-trace" From 4fb0b4d8e80abd766e0b94e95b234c7a7117d18f Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Apr 2026 20:16:53 +0100 Subject: [PATCH 25/37] Pub/sub abstraction: decouple from Pulsar (#751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove Pulsar-specific concepts from application code so that the pub/sub backend is swappable via configuration. Rename translators: - to_pulsar/from_pulsar → decode/encode across all translator classes, dispatch handlers, and tests (55+ files) - from_response_with_completion → encode_with_completion - Remove pulsar.schema.Record from translator base class Queue naming (CLASS:TOPICSPACE:TOPIC): - Replace topic() helper with queue() using new format: flow:tg:name, request:tg:name, response:tg:name, state:tg:name - Queue class implies persistence/TTL (no QoS in names) - Update Pulsar backend map_topic() to parse new format - Librarian queues use flow class (persistent, for chunking) - Config push uses state class (persistent, last-value) - Remove 15 dead topic imports from schema files - Update init_trustgraph.py namespace: config → state Confine Pulsar to pulsar_backend.py: - Delete legacy PulsarClient class from pubsub.py - Move add_args to add_pubsub_args() with standalone flag for CLI tools (defaults to localhost) - PulsarBackendConsumer.receive() catches _pulsar.Timeout, raises standard TimeoutError - Remove Pulsar imports from: async_processor, flow_processor, log_level, all 11 client files, 4 storage writers, gateway service, gateway config receiver - Remove log_level/LoggerLevel from client API - Rewrite tg-monitor-prompts to use backend abstraction - Update tg-dump-queues to use add_pubsub_args Also: pubsub-abstraction.md tech spec covering problem statement, design goals, as-is requirements, candidate broker assessment, approach, and implementation order. 
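As a rough illustration of the CLASS:TOPICSPACE:TOPIC naming described above — a sketch only: the actual queue()/map_topic() signatures in this patch may differ, and which queue classes imply persistence is an assumption here, not the patch's rule:

```python
# Sketch of the queue naming idea above; names/signatures are illustrative,
# and the persistence mapping is an assumption, not the patch's actual rule.
PERSISTENT_CLASSES = {"state", "flow"}

def queue(klass: str, name: str, topicspace: str = "tg") -> str:
    # e.g. queue("request", "prompt") -> "request:tg:prompt"
    return f"{klass}:{topicspace}:{name}"

def map_topic(q: str) -> str:
    # One possible Pulsar mapping: persistence derived from the queue class,
    # so application code never carries Pulsar URI formats.
    klass, space, name = q.split(":", 2)
    scheme = "persistent" if klass in PERSISTENT_CLASSES else "non-persistent"
    return f"{scheme}://{space}/{klass}/{name}"
```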
--- docs/tech-specs/pubsub-abstraction.md | 551 ++++++++++++++++++ .../test_document_embeddings_contract.py | 40 +- .../test_translator_completion_flags.py | 24 +- .../test_sync_document_embeddings_client.py | 4 +- .../unit/test_gateway/test_dispatch_config.py | 8 +- .../test_streaming_translators.py | 56 +- .../test_text_document_translator.py | 12 +- tests/unit/test_pubsub/test_queue_naming.py | 133 +++++ tests/unit/test_rdf/test_rdf_wire_format.py | 110 ++-- .../test_metadata_preservation.py | 30 +- .../test_message_translation.py | 16 +- trustgraph-base/trustgraph/base/__init__.py | 2 +- .../trustgraph/base/async_processor.py | 17 +- .../trustgraph/base/flow_processor.py | 2 - trustgraph-base/trustgraph/base/pubsub.py | 118 ++-- .../trustgraph/base/pulsar_backend.py | 47 +- .../trustgraph/clients/agent_client.py | 8 +- trustgraph-base/trustgraph/clients/base.py | 14 +- .../trustgraph/clients/config_client.py | 8 +- .../clients/document_embeddings_client.py | 8 +- .../trustgraph/clients/document_rag_client.py | 7 - .../trustgraph/clients/embeddings_client.py | 9 +- .../clients/graph_embeddings_client.py | 8 +- .../trustgraph/clients/graph_rag_client.py | 7 - .../trustgraph/clients/llm_client.py | 8 +- .../trustgraph/clients/prompt_client.py | 8 +- .../clients/row_embeddings_client.py | 8 +- .../clients/triples_query_client.py | 8 +- trustgraph-base/trustgraph/log_level.py | 10 +- .../trustgraph/messaging/translators/agent.py | 12 +- .../trustgraph/messaging/translators/base.py | 47 +- .../messaging/translators/collection.py | 8 +- .../messaging/translators/config.py | 12 +- .../messaging/translators/diagnosis.py | 12 +- .../messaging/translators/document_loading.py | 16 +- .../messaging/translators/embeddings.py | 12 +- .../messaging/translators/embeddings_query.py | 38 +- .../trustgraph/messaging/translators/flow.py | 12 +- .../messaging/translators/knowledge.py | 24 +- .../messaging/translators/library.py | 26 +- .../messaging/translators/metadata.py | 12 +- .../messaging/translators/nlp_query.py | 12 +- .../messaging/translators/primitives.py | 56 +- .../messaging/translators/prompt.py | 12 +- .../messaging/translators/retrieval.py | 24 +- .../messaging/translators/rows_query.py | 12 +- .../messaging/translators/structured_query.py | 12 +- .../messaging/translators/text_completion.py | 12 +- .../trustgraph/messaging/translators/tool.py | 12 +- .../messaging/translators/triples.py | 26 +- .../trustgraph/schema/core/topic.py | 31 +- .../trustgraph/schema/knowledge/document.py | 1 - .../trustgraph/schema/knowledge/embeddings.py | 1 - .../trustgraph/schema/knowledge/graph.py | 1 - .../trustgraph/schema/knowledge/knowledge.py | 10 +- .../trustgraph/schema/knowledge/nlp.py | 1 - .../trustgraph/schema/knowledge/object.py | 1 - .../trustgraph/schema/knowledge/rows.py | 1 - .../trustgraph/schema/knowledge/structured.py | 1 - .../trustgraph/schema/services/agent.py | 1 - .../trustgraph/schema/services/collection.py | 10 +- .../trustgraph/schema/services/config.py | 14 +- .../trustgraph/schema/services/flow.py | 10 +- .../trustgraph/schema/services/library.py | 10 +- .../trustgraph/schema/services/llm.py | 1 - .../trustgraph/schema/services/lookup.py | 1 - .../trustgraph/schema/services/nlp_query.py | 1 - .../trustgraph/schema/services/prompt.py | 1 - .../trustgraph/schema/services/query.py | 18 +- .../trustgraph/schema/services/retrieval.py | 1 - .../trustgraph/schema/services/rows_query.py | 1 - .../schema/services/structured_query.py | 1 - trustgraph-cli/trustgraph/cli/dump_queues.py | 51 +- 
.../trustgraph/cli/init_trustgraph.py | 2 +- .../trustgraph/cli/monitor_prompts.py | 94 ++- .../trustgraph/extract/kg/rows/processor.py | 2 +- .../trustgraph/gateway/config/receiver.py | 1 - .../trustgraph/gateway/dispatch/agent.py | 4 +- .../gateway/dispatch/collection_management.py | 4 +- .../trustgraph/gateway/dispatch/config.py | 4 +- .../dispatch/document_embeddings_import.py | 2 +- .../dispatch/document_embeddings_query.py | 4 +- .../gateway/dispatch/document_load.py | 2 +- .../gateway/dispatch/document_rag.py | 4 +- .../trustgraph/gateway/dispatch/embeddings.py | 4 +- .../trustgraph/gateway/dispatch/flow.py | 4 +- .../dispatch/graph_embeddings_query.py | 4 +- .../trustgraph/gateway/dispatch/graph_rag.py | 4 +- .../trustgraph/gateway/dispatch/knowledge.py | 4 +- .../trustgraph/gateway/dispatch/librarian.py | 4 +- .../trustgraph/gateway/dispatch/mcp_tool.py | 4 +- .../trustgraph/gateway/dispatch/nlp_query.py | 4 +- .../trustgraph/gateway/dispatch/prompt.py | 4 +- .../gateway/dispatch/row_embeddings_query.py | 4 +- .../trustgraph/gateway/dispatch/rows_query.py | 4 +- .../trustgraph/gateway/dispatch/serialize.py | 8 +- .../gateway/dispatch/structured_diag.py | 4 +- .../gateway/dispatch/structured_query.py | 4 +- .../gateway/dispatch/text_completion.py | 4 +- .../trustgraph/gateway/dispatch/text_load.py | 2 +- .../gateway/dispatch/triples_query.py | 4 +- trustgraph-flow/trustgraph/gateway/service.py | 1 - .../storage/triples/cassandra/write.py | 1 - .../storage/triples/falkordb/write.py | 1 - .../storage/triples/memgraph/write.py | 1 - .../trustgraph/storage/triples/neo4j/write.py | 1 - 106 files changed, 1269 insertions(+), 788 deletions(-) create mode 100644 docs/tech-specs/pubsub-abstraction.md create mode 100644 tests/unit/test_pubsub/test_queue_naming.py diff --git a/docs/tech-specs/pubsub-abstraction.md b/docs/tech-specs/pubsub-abstraction.md new file mode 100644 index 00000000..722b3b47 --- /dev/null +++ b/docs/tech-specs/pubsub-abstraction.md @@ -0,0 +1,551 @@ +# Pub/Sub Abstraction: Broker-Independent Messaging + +## Problem + +TrustGraph's messaging infrastructure is deeply coupled to Apache Pulsar in ways that go beyond the transport layer. This coupling creates several concrete problems. + +### 1. Schema system is Pulsar-native + +Every message type in the system is defined as a `pulsar.schema.Record` subclass using Pulsar field types (`String()`, `Integer()`, `Boolean()`, etc.). This means: + +- The `pulsar` Python package is a build dependency for `trustgraph-base`, even though `trustgraph-base` contains no transport logic +- Any code that imports a message schema transitively depends on Pulsar +- The schema definitions cannot be reused with a different broker without the Pulsar library installed +- What's actually happening on the wire is JSON serialisation — the Pulsar schema machinery adds complexity without adding value over plain JSON encode/decode + +### 2. Translators are named after the broker + +The translator layer that converts between internal Python objects and wire format uses methods called `to_pulsar()` and `from_pulsar()`. These are really just JSON encode/decode operations — they have nothing to do with Pulsar specifically. The naming creates a false impression that the translation is broker-specific, when in reality any broker that carries JSON payloads would use identical logic. + +### 3. 
Queue names use Pulsar URI format + +Queue identifiers throughout the codebase use Pulsar's `persistent://tenant/namespace/topic` or `non-persistent://tenant/namespace/topic` URI format. These are hardcoded in schema definitions and referenced across services. RabbitMQ, Redis Streams, or any other broker would use completely different naming conventions. There is no abstraction between the logical identity of a queue and its broker-specific address. + +### 4. Broker selection is not configurable + +There is no mechanism to select a different pub/sub backend at deployment time. The Pulsar client is instantiated directly in the gateway and via `PulsarClient` in the base processor. Switching to a different broker would require code changes across multiple packages, not a configuration change. + +### 5. Architectural requirements are implicit + +TrustGraph relies on specific pub/sub behaviours — shared subscriptions for load balancing, message acknowledgement for reliability, message properties for correlation — but these requirements are not documented. This makes it difficult to evaluate whether a candidate broker (RabbitMQ, Redis Streams, NATS, etc.) actually satisfies the system's needs, or where the gaps would be. + +## Design Goals + +### Goal 1: Remove the link between Pulsar schemas and application code + +Message types should be plain Python objects (dataclasses) that know how to serialise to and from JSON. The `pulsar.schema.Record` base class and Pulsar field types should not appear in schema definitions. The pub/sub transport layer sends and receives JSON bytes; the schema layer handles the mapping between JSON and typed Python objects independently. + +### Goal 2: Remove `to_pulsar` / `from_pulsar` naming + +The translator methods should reflect what they actually do: encode a Python object to a JSON-compatible dict, and decode a JSON-compatible dict back to a Python object. The naming should be broker-neutral (e.g. `encode` / `decode`, or `to_dict` / `from_dict`). + +### Goal 3: Schema objects provide encode/decode + +Each message type should be a Python dataclass (or similar) with a well-defined mapping to and from JSON. For example: + +```python +@dataclass +class TextCompletionRequest: + system: str + prompt: str + streaming: bool = False +``` + +Given `{"system": "You are helpful", "prompt": "Hello", "streaming": false}` on the wire, decoding produces an object where `request.system` is `"You are helpful"`, `request.prompt` is `"Hello"`, and `request.streaming` is `False`. Encoding does the reverse. This is the schema's concern, not the broker's. + +### Goal 4: Abstract queue naming + +Queue identifiers should not use Pulsar URI format (`persistent://tg/flow/topic`). A broker-neutral naming scheme is needed so that each backend can map logical queue names to its native format. The right approach here is not yet clear and needs to be worked through — considerations include how to express quality-of-service, multi-tenancy, and namespace separation without leaking broker concepts. + +### Goal 5: Document pub/sub architectural requirements + +TrustGraph's actual requirements from the pub/sub layer need to be formally specified. This includes: + +- **Delivery semantics**: Which queues need at-least-once delivery? Are any fire-and-forget? 
+- **Consumer patterns**: Shared subscriptions (competing consumers for load balancing), exclusive subscriptions, fan-out/broadcast +- **Message acknowledgement**: Positive ack, negative ack (redelivery), timeout-based redelivery +- **Message properties**: Key-value metadata on messages used for correlation (e.g. request IDs, flow routing) +- **Ordering guarantees**: Per-topic ordering, per-key ordering, or no ordering required +- **Message size**: Typical and maximum message sizes (some payloads include base64-encoded documents) +- **Persistence**: Which messages must survive broker restarts +- **Consumer positioning**: Ability to consume from earliest (replay) vs latest (live tail) +- **Connection model**: Long-lived connections with reconnection, or transient + +Documenting these requirements makes it possible to evaluate RabbitMQ or any other candidate against concrete criteria rather than discovering gaps during implementation. + +## Pub/Sub Architectural Requirements (As-Is) + +This section documents what TrustGraph currently needs from its pub/sub layer. These are the as-is requirements — some may be revisited or relaxed in a future design if it makes broker portability easier. + +### Consumer model + +All consumers use **shared subscriptions** (competing consumers). Multiple instances of the same processor read from the same subscription, and each message is delivered to exactly one instance. This is the load-balancing mechanism. + +No exclusive or failover subscriptions are used anywhere in the codebase, despite infrastructure support for them. + +Consumers support configurable concurrency — multiple async tasks within a single process can independently call `receive()` on the same subscription. + +### Delivery semantics + +Almost all queues are **non-persistent / best-effort (q0)**. The only persistent queue is `config_push_queue` (q2, exactly-once), which pushes full configuration state to processors. Since config pushes are idempotent (full state, not deltas), the persistence requirement here is about surviving broker restarts, not about exactly-once semantics per se. + +Flow processing queues (request/response pairs for LLM, RAG, agent, etc.) are all non-persistent. Messages in flight are lost on broker restart. This is acceptable because: + +- Requests originate from a client that will time out and retry +- There is no durable work-in-progress that would be corrupted by message loss +- The system is designed for real-time query processing, not batch pipelines + +### Message acknowledgement + +**Positive acknowledgement**: After successful handler execution, the message is acknowledged. This removes it from the subscription. + +**Negative acknowledgement**: On handler failure (unhandled exception or rate-limit timeout), the message is negatively acknowledged, which triggers redelivery by the broker. Rate-limited messages retry for up to 7200 seconds before giving up and negatively acknowledging. + +**Orphaned messages**: In the request-response subscriber pattern, messages that arrive with no matching waiter (e.g. the requester timed out) are positively acknowledged and discarded. This prevents redelivery storms. + +### Message properties + +Messages carry a small set of key-value string properties as metadata, separate from the payload. The primary use is a `"id"` property for request-response correlation — the requester generates a unique ID, attaches it as a property, and the responder echoes it back so the subscriber can match responses to waiters. 
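A minimal sketch of this correlation pattern — illustrative only: the `Waiters` class, the `producer.send(payload, properties=...)` signature, and the timeout value are assumptions for the sketch, not TrustGraph's actual API:

```python
# Illustrative sketch of "id"-property correlation; class and method names,
# and the producer.send(properties=...) signature, are assumptions.
import asyncio
import uuid

class Waiters:

    def __init__(self):
        self.pending: dict[str, asyncio.Future] = {}

    async def request(self, producer, payload: dict) -> dict:
        corr_id = str(uuid.uuid4())
        fut = asyncio.get_running_loop().create_future()
        self.pending[corr_id] = fut
        # The correlation ID travels as a message property, not in the payload
        await producer.send(payload, properties={"id": corr_id})
        try:
            return await asyncio.wait_for(fut, timeout=60)
        finally:
            self.pending.pop(corr_id, None)

    def on_response(self, properties: dict, payload: dict):
        # Orphaned responses (no matching waiter) are simply dropped here;
        # the transport layer still acks them to avoid redelivery storms.
        fut = self.pending.get(properties.get("id", ""))
        if fut is not None and not fut.done():
            fut.set_result(payload)
```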
+ +Agent orchestration correlation (`correlation_id`, `parent_session_id`) is carried in the message payload, not in properties. + +### Consumer positioning + +Two modes are used: + +- **Earliest**: The configuration consumer starts from the beginning of the topic to receive full configuration history on startup. This is the only use of earliest positioning. +- **Latest** (default): All flow consumers start from the current position, processing only new messages. + +### Message ordering + +**Not required.** The codebase explicitly does not depend on message ordering: + +- Shared subscriptions distribute messages across consumers without ordering guarantees +- Concurrent handler tasks within a consumer process messages in arbitrary order +- Request-response correlation uses IDs, not positional ordering +- The supervisor fan-out/fan-in pattern collects results in a dictionary, order-independent +- Configuration pushes are full state snapshots, not ordered deltas + +### Message sizes + +Most messages are small JSON payloads (< 10KB). The exceptions: + +- **Document content**: Large documents (PDFs, text files) can be sent through the chunking service with base64 encoding. Pulsar's chunking feature (`chunking_enabled`) handles automatic splitting of oversized messages. +- **Agent observations**: LLM-generated text can be several KB but rarely exceeds typical message size limits. + +A replacement broker needs to either support large messages natively or provide a chunking/streaming mechanism. Alternatively, the large-document path could be refactored to use a side-channel (e.g. object store reference) instead of inline payload. + +### Fan-out patterns + +**Supervisor fan-out**: One supervisor request decomposes into N independent sub-agent requests, each emitted as a separate message on the agent request queue. Different agent instances pick them up via the shared subscription. A correlation ID links the completions back to the original decomposition. This is not pub/sub fan-out (one message to many consumers) — it's application-level fan-out (many messages to one queue). + +**Request-response isolation**: Each client creates a unique subscription name on response queues so it only receives its own responses. This means the response queue effectively has many independent subscribers, each seeing a filtered subset of messages based on the `"id"` property match. + +### Reconnection and resilience + +Reconnection logic lives in the Consumer/Producer/Publisher/Subscriber classes, not in the broker client. These classes handle: + +- Automatic reconnection on connection loss +- Retry loops with backoff +- Graceful shutdown (unsubscribe, close) + +The broker client itself is expected to provide a basic connection that can fail, and the wrapper classes handle recovery. This is important for the abstraction — the backend interface can be simple because resilience is handled above it. 
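For illustration, a minimal receive loop that keeps resilience above a simple backend interface — the method names (`connect`, `receive`, `ack`, `nack`) and the backoff value are assumptions for the sketch, not the exact project interface:

```python
# Illustrative reconnection loop; backend method names are assumptions.
import asyncio
import logging

logger = logging.getLogger("consumer")

async def consume_forever(backend, queue_name, subscription, handler):
    while True:
        try:
            consumer = await backend.connect(queue_name, subscription)
            while True:
                msg = await consumer.receive()
                try:
                    await handler(msg.value())
                    await consumer.ack(msg)    # success: remove from queue
                except Exception:
                    logger.exception("handler failed; requesting redelivery")
                    await consumer.nack(msg)   # failure: broker redelivers
        except Exception:
            logger.exception("connection lost; reconnecting")
            await asyncio.sleep(2)             # simple backoff before retry
```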
+ +### Queue inventory + +| Queue | Persistence | Purpose | +|-------|-------------|---------| +| config push | Persistent (q2) | Full configuration state broadcast | +| config request/response | Non-persistent | Configuration queries | +| flow request/response | Non-persistent | Flow management | +| knowledge request/response | Non-persistent | Knowledge graph operations | +| librarian request/response | Non-persistent | Document storage operations | +| document embeddings request/response | Non-persistent | Document vector queries | +| row embeddings request/response | Non-persistent | Row vector queries | +| collection request/response | Non-persistent | Collection management | + +Additionally, each processing service (LLM, RAG, agent, prompt, embeddings, etc.) has dynamically defined request/response queue pairs configured at deployment time. + +### Summary of hard requirements for a replacement broker + +1. **Shared subscription / competing consumers** — multiple consumers on one queue, each message delivered to exactly one +2. **Message acknowledgement** — positive ack (remove from queue) and negative ack (trigger redelivery) +3. **Message properties** — key-value metadata on messages, at minimum a string `"id"` field +4. **Two consumer start positions** — from beginning of topic and from current position +5. **Persistence for at least one queue** — config state must survive broker restart +6. **Messages up to several MB** — or a chunking mechanism for large payloads +7. **No ordering requirement** — simplifies broker selection significantly + +## Candidate Brokers + +A quick assessment of alternatives against the hard requirements above. + +### RabbitMQ + +The primary candidate. Mature, widely deployed, well understood. + +- **Competing consumers**: Yes — multiple consumers on a queue, round-robin delivery. This is RabbitMQ's native model. +- **Acknowledgement**: Yes — `basic.ack` and `basic.nack` with requeue flag. +- **Message properties**: Yes — headers and properties on every message. The `correlation_id` and `message_id` fields are first-class concepts. +- **Consumer positioning**: Yes, via RabbitMQ Streams (3.9+). Streams are append-only logs that support reading from any offset — beginning, end, or timestamp. Classic queues are consumed destructively (no replay), but streams solve this cleanly. The `state` queue class maps to a RabbitMQ stream. Additionally, the Last Value Cache Exchange plugin can retain the most recent message per routing key for new consumers. +- **Persistence**: Yes — durable queues and persistent messages survive broker restart. +- **Large messages**: No hard limit but not designed for very large payloads. Practical limit around 128MB with default config. Adequate for current use. +- **Ordering**: FIFO per queue (stronger than required). +- **Operational complexity**: Low. Single binary, no ZooKeeper/BookKeeper dependencies. Significantly simpler to operate than Pulsar. +- **Ecosystem**: Excellent client libraries, management UI, mature tooling. + +**Gaps**: None significant. RabbitMQ Streams cover the replay/earliest positioning requirement. + +### Apache Kafka + +High-throughput distributed log. More infrastructure than TrustGraph likely needs. + +- **Competing consumers**: Yes — consumer groups with partition assignment. +- **Acknowledgement**: Yes — offset commits. No per-message negative ack; failed messages require application-level retry or dead-letter handling. +- **Message properties**: Yes — message headers (key-value byte arrays). 
+- **Consumer positioning**: Yes — seek to earliest or latest offset. Supports full replay. +- **Persistence**: Yes — all messages are persisted to the log by default. +- **Large messages**: Configurable (`max.message.bytes`), default 1MB, can be increased. Large payloads are discouraged by design. +- **Ordering**: Per-partition ordering (stronger than required). +- **Operational complexity**: High. Requires ZooKeeper (or KRaft), partition management, replication config. Overkill for typical TrustGraph deployments. +- **Ecosystem**: Excellent client libraries, schema registry, Connect framework. + +**Gaps**: No native negative acknowledgement. Operational complexity is high for small-to-medium deployments. Partition count must be planned upfront for parallelism. + +### Redis Streams + +Lightweight option using Redis as a message broker. + +- **Competing consumers**: Yes — consumer groups with `XREADGROUP`. +- **Acknowledgement**: Yes — `XACK`. Pending entries list tracks unacknowledged messages. No explicit negative ack but unacknowledged messages can be claimed after timeout via `XAUTOCLAIM`. +- **Message properties**: No native separation between properties and payload. Would need to encode properties as fields within the stream entry or in the payload. +- **Consumer positioning**: Yes — `0` (earliest) or `$` (latest) on group creation. +- **Persistence**: Yes — Redis persistence (RDB/AOF), though Redis is primarily an in-memory system. +- **Large messages**: Practical limit tied to Redis memory. Not suited for large payloads. +- **Ordering**: Per-stream ordering (stronger than required). +- **Operational complexity**: Low if Redis is already in the stack. No additional infrastructure. + +**Gaps**: No native message properties. Memory-bound. Persistence depends on Redis configuration. Not a natural fit for message broker patterns. + +### NATS / NATS JetStream + +Lightweight, high-performance messaging. JetStream adds persistence. + +- **Competing consumers**: Yes — queue groups in core NATS; consumer groups in JetStream. +- **Acknowledgement**: JetStream only — `Ack`, `Nak` (with redelivery), `InProgress` (extend timeout). +- **Message properties**: Yes — message headers (key-value). +- **Consumer positioning**: JetStream — deliver all, deliver last, deliver new, deliver by sequence/time. +- **Persistence**: JetStream only. Core NATS is fire-and-forget. +- **Large messages**: Default 1MB, configurable up to 64MB. +- **Ordering**: Per-subject ordering. +- **Operational complexity**: Very low. Single binary, no dependencies. Clustering is straightforward. + +**Gaps**: Requires JetStream for persistence and acknowledgement. Smaller ecosystem than RabbitMQ/Kafka. + +### Assessment Summary + +| Requirement | RabbitMQ | Kafka | Redis Streams | NATS JetStream | +|---|---|---|---|---| +| Competing consumers | Yes | Yes | Yes | Yes | +| Positive/negative ack | Yes | Partial | Partial | Yes | +| Message properties | Yes | Yes | No | Yes | +| Earliest positioning | Yes (Streams) | Yes | Yes | Yes | +| Persistence | Yes | Yes | Partial | Yes | +| Large messages | Yes | Configurable | No | Configurable | +| Operational simplicity | Good | Poor | Good | Good | + +**RabbitMQ** is the strongest candidate given TrustGraph's requirements and deployment profile. The only gap (earliest consumer positioning for config) has known workarounds. Operational simplicity is a significant advantage over Pulsar. + +## Approach + +### Current state + +The codebase has already undergone a partial abstraction. 
The picture is better than the problem statement might suggest: + +- **Backend abstraction exists**: `backend.py` defines Protocol-based interfaces (`PubSubBackend`, `BackendProducer`, `BackendConsumer`, `Message`). The Pulsar implementation lives in `pulsar_backend.py`. +- **Schemas are already dataclasses**: Message types in `schema/services/*.py` are plain Python dataclasses with type hints, not Pulsar `Record` subclasses. This was the hardest part of the old spec and it's done. +- **Serialization is JSON-based**: `pulsar_backend.py` contains `dataclass_to_dict()` and `dict_to_dataclass()` helpers that handle the round-trip. The wire format is JSON. +- **Factory pattern exists**: `pubsub.py` has `get_pubsub()` which creates a backend from configuration. Currently only Pulsar is implemented. +- **Consumer/Producer/Publisher/Subscriber are backend-agnostic**: These classes accept a `backend` parameter and delegate transport operations to it. They own retry, reconnection, metrics, and concurrency. + +What remains is cleanup, not a rewrite. + +### What needs to change + +#### 1. Rename translator methods + +The translator base class (`messaging/translators/base.py`) defines `to_pulsar()` and `from_pulsar()` as abstract methods. Every translator implements these. The methods convert between external API dicts and internal dataclass objects — nothing Pulsar-specific happens in them. + +**Change**: Rename to `decode()` (external dict → dataclass) and `encode()` (dataclass → external dict). Update all translator subclasses and all call sites. + +This is a mechanical rename. The method bodies don't change. + +#### 2. Rename translator base classes + +The base classes `Translator`, `MessageTranslator`, and `SendTranslator` reference "pulsar" in docstrings and parameter names. Clean these up so the naming reflects what the layer actually does: translating between the external API representation (JSON dicts from HTTP/WebSocket) and the internal schema (dataclasses). + +#### 3. Move serialization out of the Pulsar backend + +`dataclass_to_dict()` and `dict_to_dataclass()` currently live in `pulsar_backend.py` but are not Pulsar-specific. They handle the conversion between dataclasses and JSON-compatible dicts, which every backend needs. + +**Change**: Move these to a shared location (e.g. `trustgraph/base/serialization.py` or alongside the schema definitions). The backend interface sends and receives dicts; serialization to/from dataclasses happens at a layer above. + +This means the backend Protocol simplifies: `send()` accepts a dict and properties, `value()` returns a dict. The Consumer/Producer layer handles dataclass ↔ dict conversion using the shared serializers. + +#### 4. Abstract queue naming + +Queue names currently use the format `q0/tg/flow/queue-name` or `q2/tg/config/queue-name`, which the Pulsar backend maps to `non-persistent://tg/flow/queue-name` or `persistent://tg/config/queue-name`. + +This is an open design question. Options: + +**Option A: Simple string names.** Queues are just strings like `"text-completion-request"`. The backend is responsible for mapping to its native format (Pulsar adds `persistent://tg/flow/` prefix, RabbitMQ uses the string as-is or adds a vhost prefix). Persistence and namespace are configuration concerns, not embedded in the name. + +**Option B: Structured queue descriptor.** A small object that carries the logical name plus metadata: + +```python +@dataclass +class QueueDescriptor: + name: str # e.g. 
"text-completion-request" + namespace: str = "flow" # logical grouping + persistent: bool = False # must survive broker restart +``` + +The backend maps this to its native format. + +**Option C: Keep the current format** (`q0/tg/flow/name`) but document it as a TrustGraph convention, not a Pulsar convention. Backends parse it. + +Option B is the most explicit. Option A is the simplest. Either is workable. The key constraint is that persistence is a property of the queue definition, not a runtime choice — the config push queue is persistent, everything else is not. + +#### 5. Implement RabbitMQ backend + +Write `rabbitmq_backend.py` implementing the `PubSubBackend` Protocol: + +- **`create_producer()`**: Creates a channel and declares the target queue. `send()` publishes to the default exchange with the queue name as routing key. Properties map to AMQP basic properties (specifically `message_id` for the `"id"` property). +- **`create_consumer()`**: Declares the queue and starts consuming with `basic_consume`. Shared subscription is the default RabbitMQ model — multiple consumers on one queue get round-robin delivery. `acknowledge()` maps to `basic_ack`, `negative_acknowledge()` maps to `basic_nack` with `requeue=True`. +- **Persistence**: For persistent queues, declare as durable with `delivery_mode=2` on messages. For non-persistent queues, declare as non-durable. +- **Consumer positioning**: RabbitMQ queues are consumed destructively, so "earliest" doesn't apply in the Pulsar sense. For the config push use case, use a **fanout exchange with per-consumer exclusive queues** — each new processor gets its own queue that receives all config publishes, plus the last-value can be handled by having the config service re-publish on startup. +- **Large messages**: RabbitMQ handles messages up to `rabbit.max_message_size` (default 128MB). No chunking needed. + +The factory in `pubsub.py` gets a new branch: + +```python +if backend_type == 'rabbitmq': + return RabbitMQBackend( + host=config.get('rabbitmq_host'), + port=config.get('rabbitmq_port'), + username=config.get('rabbitmq_username'), + password=config.get('rabbitmq_password'), + vhost=config.get('rabbitmq_vhost', '/'), + ) +``` + +Backend selection via `PUBSUB_BACKEND=rabbitmq` environment variable or `--pubsub-backend rabbitmq` CLI flag. + +#### 6. Clean up remaining Pulsar references + +After the above changes, Pulsar-specific code should be confined to: + +- `pulsar_backend.py` — the Pulsar implementation +- `pubsub.py` — the factory that imports it + +Audit and remove any remaining Pulsar imports, Pulsar exception handling, or Pulsar-specific concepts from: + +- `async_processor.py` (currently catches `_pulsar.Interrupted`) +- `consumer.py`, `subscriber.py` (if any Pulsar exceptions leak through) +- Schema files (should be clean already, but verify) +- Gateway service (currently instantiates Pulsar client directly) + +The gateway is a special case — it currently bypasses the abstraction layer and creates a Pulsar client directly for dispatching API requests. It should use the same `get_pubsub()` factory as everything else. + +### What stays the same + +- **Schema definitions**: Already dataclasses. No changes needed. +- **Consumer/Producer/Publisher/Subscriber**: Already backend-agnostic. No changes to their core logic. +- **FlowProcessor and spec wiring**: Already uses `processor.pubsub` to create backend instances. No changes. +- **Backend Protocol**: The interface in `backend.py` is sound. 
Minor refinement possible (dict vs dataclass at the boundary) but the shape is right. + +### Concrete cleanups + +The following files have Pulsar-specific imports that should not be there after the abstraction is complete. Pulsar imports should be confined to `pulsar_backend.py` and the factory in `pubsub.py`. + +**Dead imports (unused, can just be removed):** + +- `trustgraph-base/trustgraph/base/pubsub.py` — `from pulsar.schema import JsonSchema`, `import pulsar`, `import _pulsar`. The `JsonSchema` import is unused since the switch to `BytesSchema`. The `pulsar`/`_pulsar` imports are only used by the legacy `PulsarClient` class which should be removed (superseded by `PulsarBackend`). +- `trustgraph-base/trustgraph/base/flow_processor.py` — `from pulsar.schema import JsonSchema`. Unused. + +**Legacy `PulsarClient` class:** + +- `trustgraph-base/trustgraph/base/pubsub.py` — The `PulsarClient` class is a leftover from before the backend abstraction. `get_pubsub()` still references `PulsarClient.default_pulsar_host` for defaults. Move the defaults to `PulsarBackend` or to environment variable reads in the factory, then delete `PulsarClient`. + +**Client libraries using Pulsar directly:** + +- `trustgraph-base/trustgraph/clients/base.py` — `import pulsar`, `import _pulsar`, `from pulsar.schema import JsonSchema`. This is the base class for the old synchronous client library. These clients predate the backend abstraction and use Pulsar directly. +- `trustgraph-base/trustgraph/clients/embeddings_client.py` — `from pulsar.schema import JsonSchema`, `import _pulsar`. +- `trustgraph-base/trustgraph/clients/*.py` (agent, config, document_embeddings, document_rag, graph_embeddings, graph_rag, llm, prompt, row_embeddings, triples_query) — all import `_pulsar` for exception handling. + +These clients are the internal request-response clients used by processors. They need to be migrated to use the backend abstraction or their Pulsar exception handling needs to be wrapped behind a backend-agnostic exception type. + +**Translator base class:** + +- `trustgraph-base/trustgraph/messaging/translators/base.py` — `from pulsar.schema import Record`. Used in type hints. Should be removed when `to_pulsar`/`from_pulsar` are renamed. + +**Gateway service (bypasses abstraction):** + +- `trustgraph-flow/trustgraph/gateway/service.py` — `import pulsar`. Creates a Pulsar client directly. +- `trustgraph-flow/trustgraph/gateway/config/receiver.py` — `import pulsar`. Direct Pulsar usage. + +The gateway should use `get_pubsub()` like everything else. + +**Storage writers:** + +- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py` — `import pulsar` +- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py` — `import pulsar` +- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py` — `import pulsar` +- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py` — `import pulsar` + +These need investigation — likely Pulsar exception handling or direct client usage that should go through the abstraction. + +**Log level:** + +- `trustgraph-base/trustgraph/log_level.py` — `import _pulsar`. Used to set Pulsar's log level. Should be moved into `pulsar_backend.py`. + +### Queue naming + +The current scheme encodes QoS, tenant, namespace, and queue name into a slash-separated string (`q0/tg/request/config`) which the Pulsar backend parses and maps to a Pulsar URI (`non-persistent://tg/request/config`). 
This was an attempt at abstraction but it has problems: + +- QoS in the name was a mistake — it's a property of the queue definition, not something that belongs in the name. A queue is either persistent or it isn't; that's decided once when the queue is defined. +- The tenant/namespace structure mirrors Pulsar's model. RabbitMQ doesn't use this — it has vhosts and exchange/queue names. Pretending the naming isn't TrustGraph-specific just leaks Pulsar concepts. +- The `topic()` helper generates these strings, and the backend parses them apart. This is unnecessary indirection. + +There are two categories of queue in TrustGraph: + +**Infrastructure queues** — defined in code, used for system services. These are fixed and well-known: + +| Queue | Persistent | Purpose | +|-------|------------|---------| +| `config-request` | No | Config queries | +| `config-response` | No | Config query responses | +| `config-push` | Yes | Config state broadcast | +| `flow-request` | No | Flow management queries | +| `flow-response` | No | Flow management responses | +| `librarian-request` | No | Document storage operations | +| `librarian-response` | No | Document storage responses | +| `knowledge-request` | No | Knowledge graph operations | +| `knowledge-response` | No | Knowledge graph responses | +| `document-embeddings-request` | No | Document vector queries | +| `document-embeddings-response` | No | Document vector responses | +| `row-embeddings-request` | No | Row vector queries | +| `row-embeddings-response` | No | Row vector responses | +| `collection-request` | No | Collection management | +| `collection-response` | No | Collection management responses | + +**Flow queues** — defined in configuration, created dynamically per flow. The queue names come from the config service (e.g. `text-completion-request`, `graph-rag-request`, `agent-request`). Each flow instance has its own set of these queues. + +For infrastructure queues, the name is just a string. Persistence is a property of the queue definition, not encoded in the name. The backend maps the name to whatever its native format requires. + +For flow queues, the name comes from configuration. The config service already distributes queue names as strings — the backend just needs to be able to use them. + +#### Proposed scheme: CLASS:TOPICSPACE:TOPIC + +A queue name has three parts separated by colons: + +- **CLASS** — a small enum that defines the queue's operational characteristics. The backend knows what each class means in terms of persistence, TTL, memory limits, etc. There are only four classes: + + | Class | Persistent | TTL | Behaviour | + |-------|------------|-----|-----------| + | `flow` | Yes | Long | Processing pipeline queues. Messages survive broker restart. | + | `request` | No | Short | Transient request-response. Low TTL, no persistence needed — clients retry on failure. | + | `response` | No | Short | Same as request, for the response side. | + | `state` | Yes | Retained | Last-value state broadcast. Consumers need the most recent value on startup, plus any future updates. Config push is the primary example. | + +- **TOPICSPACE** — deployment isolation. Keeps different TrustGraph deployments separate when sharing the same pub/sub infrastructure. Most deployments just use `tg`. Avoids the overloaded terms "tenant" and "namespace". + +- **TOPIC** — the logical queue identity. What the queue is for. 
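+
+A small helper can build these names from the three parts. A minimal sketch, consistent with the `queue()` function in `trustgraph.schema.core.topic` exercised by the unit tests added in this change; the real implementation may differ in detail:
+
+```python
+# Sketch of a CLASS:TOPICSPACE:TOPIC name builder; the validation is illustrative.
+QUEUE_CLASSES = {"flow", "request", "response", "state"}
+
+def queue(topic, cls="flow", topicspace="tg"):
+    if cls not in QUEUE_CLASSES:
+        raise ValueError(f"Invalid queue class: {cls}")
+    return f"{cls}:{topicspace}:{topic}"
+
+# queue("text-completion-request")                   -> "flow:tg:text-completion-request"
+# queue("config", cls="request")                     -> "request:tg:config"
+# queue("config", cls="state")                       -> "state:tg:config"
+# queue("config", cls="request", topicspace="prod")  -> "request:prod:config"
+```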
+
+**Examples:**
+
+```
+flow:tg:text-completion-request
+flow:tg:graph-rag-request
+flow:tg:agent-request
+request:tg:librarian
+response:tg:librarian
+request:tg:config
+response:tg:config
+state:tg:config
+request:tg:flow
+response:tg:flow
+```
+
+**Backend mapping:**
+
+Each backend parses the three parts and maps them to its native concepts:
+
+- **Pulsar**: `flow:tg:text-completion-request` → `persistent://tg/flow/text-completion-request`. Class maps to persistent/non-persistent and namespace. State class uses persistent topic with earliest consumer positioning.
+- **RabbitMQ**: Topicspace maps to vhost. Class determines queue durability and TTL policy. State class maps to a RabbitMQ stream consumed from the `last` offset (see the config distribution section below), with a last-value cache exchange as a possible fallback.
+- **Kafka**: `flow.tg.text-completion-request` as topic name. Class determines retention and compaction policy. State class maps to a compacted topic (last value per key).
+
+**Why this works:**
+
+- The class enum is small and stable — adding a new class is rare and deliberate
+- Queue properties (persistence, TTL) are implied by class, not encoded in the name
+- Dynamic registration works naturally — the config service publishes `flow:tg:text-completion-request` and the backend knows how to declare it from the `flow` class
+- The colon separator is unambiguous, easy to split, doesn't conflict with URIs or path separators that backends use internally
+- No pretence of being generic — this is a TrustGraph convention, and that's fine
+
+### Serialization boundary
+
+**Decision: the backend owns the wire format.**
+
+The contract between the Consumer/Producer layer and the backend is dataclass objects in, dataclass objects out:
+
+- `send()` accepts a dataclass instance and properties dict
+- `receive()` returns a message whose `value()` is a dataclass instance
+
+What happens on the wire is the backend's concern. The Pulsar backend uses JSON (via `dataclass_to_dict` / `dict_to_dataclass`). A RabbitMQ backend would likely also use JSON. A future backend could use Protobuf, MessagePack, or Avro if the broker benefits from it.
+
+The serialization helpers stay inside the backend that uses them — they are not shared infrastructure. Each backend brings its own serialization strategy. The Consumer/Producer layer never thinks about wire format. This supersedes the shared-location option floated in change #3 above: the helpers stay with the backend rather than moving to a shared module.
+
+### Gateway service
+
+**Decision: the gateway uses the backend abstraction like any other component.**
+
+The gateway currently bridges WebSocket/REST to Pulsar directly, bypassing the abstraction layer. It translates incoming API JSON to Pulsar schema objects, sends them, receives responses as Pulsar schema objects, and translates back to API JSON. Since the wire format is JSON in both directions, this is effectively a no-op round trip through the schema machinery.
+
+With the backend abstraction, the gateway follows the same pattern as every other component:
+
+1. Incoming API JSON → translator `decode()` → dataclass
+2. Dataclass → backend `send()` (backend handles wire format)
+3. Backend `receive()` → dataclass
+4. Dataclass → translator `encode()` → API JSON → WebSocket/REST client
+
+This is architecturally simple — one code path, no special cases. The gateway depends on the schema dataclasses and the translator layer, which it already does. The overhead of deserialize-then-reserialize is negligible for the message sizes involved. And it keeps all options open — if a future backend uses a non-JSON wire format, the gateway still works without changes.
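+
+A minimal sketch of that four-step path written against the backend abstraction. The translator `decode()` and `encode_with_completion()` calls match the renamed methods used elsewhere in this change; the producer and consumer calls and the `"id"` correlation handling are illustrative assumptions about the dispatch wiring.
+
+```python
+# Sketch only: gateway round trip through the backend abstraction.
+import uuid
+
+async def handle_api_request(body, request_translator, response_translator,
+                             producer, consumer):
+
+    # 1. Incoming API JSON -> internal dataclass
+    request = request_translator.decode(body)
+
+    # 2. Dataclass -> backend send(); the backend owns the wire format
+    request_id = str(uuid.uuid4())
+    await producer.send(request, properties={"id": request_id})
+
+    # 3. Backend receive() -> dataclass, correlated on the "id" property
+    while True:
+        msg = await consumer.receive()
+        await consumer.acknowledge(msg)
+        if msg.properties().get("id") == request_id:
+            break
+
+    # 4. Dataclass -> API JSON for the WebSocket/REST client
+    reply, is_final = response_translator.encode_with_completion(msg.value())
+    return reply, is_final
+```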
+
+## Implementation Order
+
+### Phase 1: Rename translators
+
+Rename `to_pulsar()` → `decode()`, `from_pulsar()` → `encode()` across all translator classes and call sites. Remove `from pulsar.schema import Record` from the translator base class. Mechanical find-and-replace, no behavioural changes.
+
+### Phase 2: Queue naming
+
+Replace the `topic()` helper with the CLASS:TOPICSPACE:TOPIC scheme. Update all queue definitions in `schema/services/*.py` and `schema/knowledge/*.py`. Update `PulsarBackend.map_topic()` to parse the new format. Verify all existing functionality still works with Pulsar.
+
+### Phase 3: Clean up Pulsar leaks
+
+Work through the concrete cleanups list: remove dead imports, delete the legacy `PulsarClient` class, migrate the client libraries and gateway to use the backend abstraction. After this phase, `pulsar` imports are confined to `pulsar_backend.py` and the factory in `pubsub.py`.
+
+### Phase 4: RabbitMQ backend
+
+Implement `rabbitmq_backend.py` against the existing `PubSubBackend` Protocol. Map queue classes to RabbitMQ concepts: `flow` → durable queues, `request`/`response` → non-durable queues with TTL, `state` → RabbitMQ streams. Add `rabbitmq` as a backend option in the factory. Test end-to-end with `PUBSUB_BACKEND=rabbitmq`.
+
+Phases 1-3 are safe to do on main — they don't change behaviour, just clean up. Phase 4 is additive — it adds a new backend without touching the existing one.
+
+### Config distribution on RabbitMQ
+
+The `state` queue class needs "start from earliest" semantics — a newly started processor must receive the current configuration state.
+
+RabbitMQ Streams (available since 3.9) solve this directly. Streams are persistent, append-only logs that support consumer offset positioning. The RabbitMQ backend maps the `state` class to a stream, and consumers attach with offset `first` to read from the beginning, or `last` to read the most recent entry plus future updates.
+
+Since config pushes are full state snapshots (not deltas), a consumer only needs the most recent entry. The RabbitMQ backend can use `last` offset positioning for `state` class consumers, which delivers the last message in the stream followed by any new messages. This matches the current behaviour where processors read config on startup and then react to updates.
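+
+A minimal sketch of what the state-class mapping could look like in a RabbitMQ backend using the `pika` client. The stream name (`tg.state.config` standing in for `state:tg:config`) and the callback wiring are assumptions for illustration; the declare arguments, the prefetch requirement and the `x-stream-offset` consumer argument follow standard RabbitMQ stream usage.
+
+```python
+# Sketch only: state-class consumer backed by a RabbitMQ stream.
+import pika
+
+def consume_config_state(host, on_config):
+    connection = pika.BlockingConnection(pika.ConnectionParameters(host=host))
+    channel = connection.channel()
+
+    # Declare the stream: durable, x-queue-type=stream
+    channel.queue_declare(
+        queue="tg.state.config",
+        durable=True,
+        arguments={"x-queue-type": "stream"},
+    )
+
+    # Stream consumers must set a prefetch limit
+    channel.basic_qos(prefetch_count=100)
+
+    def callback(ch, method, properties, body):
+        on_config(body)  # each message is a full configuration snapshot
+        ch.basic_ack(delivery_tag=method.delivery_tag)
+
+    # x-stream-offset=last: start at the most recent entries, then follow updates
+    channel.basic_consume(
+        queue="tg.state.config",
+        on_message_callback=callback,
+        arguments={"x-stream-offset": "last"},
+    )
+    channel.start_consuming()
+```
+
+Because `last` positions the consumer at the final chunk of the stream rather than strictly the final message, a processor may replay one or two older snapshots on startup; since every push is a full snapshot, that is harmless.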
+ diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index c7d6369a..b6d14124 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -38,7 +38,7 @@ class TestDocumentEmbeddingsRequestContract: assert request.user == "test_user" assert request.collection == "test_collection" - def test_request_translator_to_pulsar(self): + def test_request_translator_decode(self): """Test request translator converts dict to Pulsar schema""" translator = DocumentEmbeddingsRequestTranslator() @@ -49,7 +49,7 @@ class TestDocumentEmbeddingsRequestContract: "collection": "custom_collection" } - result = translator.to_pulsar(data) + result = translator.decode(data) assert isinstance(result, DocumentEmbeddingsRequest) assert result.vector == [0.1, 0.2, 0.3, 0.4] @@ -57,7 +57,7 @@ class TestDocumentEmbeddingsRequestContract: assert result.user == "custom_user" assert result.collection == "custom_collection" - def test_request_translator_to_pulsar_with_defaults(self): + def test_request_translator_decode_with_defaults(self): """Test request translator uses correct defaults""" translator = DocumentEmbeddingsRequestTranslator() @@ -66,7 +66,7 @@ class TestDocumentEmbeddingsRequestContract: # No limit, user, or collection provided } - result = translator.to_pulsar(data) + result = translator.decode(data) assert isinstance(result, DocumentEmbeddingsRequest) assert result.vector == [0.1, 0.2] @@ -74,7 +74,7 @@ class TestDocumentEmbeddingsRequestContract: assert result.user == "trustgraph" # Default assert result.collection == "default" # Default - def test_request_translator_from_pulsar(self): + def test_request_translator_encode(self): """Test request translator converts Pulsar schema to dict""" translator = DocumentEmbeddingsRequestTranslator() @@ -85,7 +85,7 @@ class TestDocumentEmbeddingsRequestContract: collection="test_collection" ) - result = translator.from_pulsar(request) + result = translator.encode(request) assert isinstance(result, dict) assert result["vector"] == [0.5, 0.6] @@ -134,7 +134,7 @@ class TestDocumentEmbeddingsResponseContract: assert response.error == error assert response.chunks == [] - def test_response_translator_from_pulsar_with_chunks(self): + def test_response_translator_encode_with_chunks(self): """Test response translator converts Pulsar schema with chunks to dict""" translator = DocumentEmbeddingsResponseTranslator() @@ -147,7 +147,7 @@ class TestDocumentEmbeddingsResponseContract: ] ) - result = translator.from_pulsar(response) + result = translator.encode(response) assert isinstance(result, dict) assert "chunks" in result @@ -155,7 +155,7 @@ class TestDocumentEmbeddingsResponseContract: assert result["chunks"][0]["chunk_id"] == "doc1/c1" assert result["chunks"][0]["score"] == 0.95 - def test_response_translator_from_pulsar_with_empty_chunks(self): + def test_response_translator_encode_with_empty_chunks(self): """Test response translator handles empty chunks list""" translator = DocumentEmbeddingsResponseTranslator() @@ -164,25 +164,25 @@ class TestDocumentEmbeddingsResponseContract: chunks=[] ) - result = translator.from_pulsar(response) + result = translator.encode(response) assert isinstance(result, dict) assert "chunks" in result assert result["chunks"] == [] - def test_response_translator_from_pulsar_with_none_chunks(self): + def test_response_translator_encode_with_none_chunks(self): """Test response translator handles None chunks""" translator 
= DocumentEmbeddingsResponseTranslator() response = MagicMock() response.chunks = None - result = translator.from_pulsar(response) + result = translator.encode(response) assert isinstance(result, dict) assert "chunks" not in result or result.get("chunks") is None - def test_response_translator_from_response_with_completion(self): + def test_response_translator_encode_with_completion(self): """Test response translator with completion flag""" translator = DocumentEmbeddingsResponseTranslator() @@ -194,7 +194,7 @@ class TestDocumentEmbeddingsResponseContract: ] ) - result, is_final = translator.from_response_with_completion(response) + result, is_final = translator.encode_with_completion(response) assert isinstance(result, dict) assert "chunks" in result @@ -202,12 +202,12 @@ class TestDocumentEmbeddingsResponseContract: assert result["chunks"][0]["chunk_id"] == "chunk1" assert is_final is True # Document embeddings responses are always final - def test_response_translator_to_pulsar_not_implemented(self): - """Test that to_pulsar raises NotImplementedError for responses""" + def test_response_translator_decode_not_implemented(self): + """Test that decode raises NotImplementedError for responses""" translator = DocumentEmbeddingsResponseTranslator() with pytest.raises(NotImplementedError): - translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]}) + translator.decode({"chunks": [{"chunk_id": "test", "score": 0.9}]}) class TestDocumentEmbeddingsMessageCompatibility: @@ -225,7 +225,7 @@ class TestDocumentEmbeddingsMessageCompatibility: # Convert to Pulsar request req_translator = DocumentEmbeddingsRequestTranslator() - pulsar_request = req_translator.to_pulsar(request_data) + pulsar_request = req_translator.decode(request_data) # Simulate service processing and creating response response = DocumentEmbeddingsResponse( @@ -238,7 +238,7 @@ class TestDocumentEmbeddingsMessageCompatibility: # Convert response back to dict resp_translator = DocumentEmbeddingsResponseTranslator() - response_data = resp_translator.from_pulsar(response) + response_data = resp_translator.encode(response) # Verify data integrity assert isinstance(pulsar_request, DocumentEmbeddingsRequest) @@ -261,7 +261,7 @@ class TestDocumentEmbeddingsMessageCompatibility: # Convert response to dict translator = DocumentEmbeddingsResponseTranslator() - response_data = translator.from_pulsar(response) + response_data = translator.encode(response) # Verify error handling assert isinstance(response_data, dict) diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index a22e1c41..91ce1b77 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -33,7 +33,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is True, "is_final must be True when end_of_session=True" @@ -57,7 +57,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is False, "is_final must be False when end_of_session=False" @@ -80,7 +80,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final 
= translator.encode_with_completion(response) # Assert assert is_final is False @@ -103,7 +103,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is False, "end_of_stream=True should NOT make is_final=True" @@ -125,7 +125,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is True, "is_final must be True when end_of_session=True" @@ -147,7 +147,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is False, "end_of_stream=True should NOT make is_final=True" @@ -168,7 +168,7 @@ class TestRAGTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is False, "is_final must be False when end_of_stream=False" @@ -195,7 +195,7 @@ class TestAgentTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is True, "is_final must be True when end_of_dialog=True" @@ -217,7 +217,7 @@ class TestAgentTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is False, "is_final must be False when end_of_dialog=False" @@ -240,7 +240,7 @@ class TestAgentTranslatorCompletionFlags: ) # Act - thought_dict, thought_is_final = translator.from_response_with_completion(thought_response) + thought_dict, thought_is_final = translator.encode_with_completion(thought_response) # Assert assert thought_is_final is False, "Thought message must not be final" @@ -254,7 +254,7 @@ class TestAgentTranslatorCompletionFlags: ) # Act - obs_dict, obs_is_final = translator.from_response_with_completion(observation_response) + obs_dict, obs_is_final = translator.encode_with_completion(observation_response) # Assert assert obs_is_final is False, "Observation message must not be final" @@ -275,7 +275,7 @@ class TestAgentTranslatorCompletionFlags: ) # Act - response_dict, is_final = translator.from_response_with_completion(response) + response_dict, is_final = translator.encode_with_completion(response) # Assert assert is_final is True, "Streaming format must use end_of_dialog for is_final" diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index ce758f66..edf4ac81 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -21,17 +21,15 @@ class TestSyncDocumentEmbeddingsClient: # Act client = DocumentEmbeddingsClient( - log_level=1, subscriber="test-subscriber", input_queue="test-input", output_queue="test-output", pulsar_host="pulsar://test:6650", pulsar_api_key="test-key" ) - + # Assert mock_base_init.assert_called_once_with( - log_level=1, subscriber="test-subscriber", 
input_queue="test-input", output_queue="test-output", diff --git a/tests/unit/test_gateway/test_dispatch_config.py b/tests/unit/test_gateway/test_dispatch_config.py index 4fbd8484..11eb7484 100644 --- a/tests/unit/test_gateway/test_dispatch_config.py +++ b/tests/unit/test_gateway/test_dispatch_config.py @@ -49,7 +49,7 @@ class TestConfigRequestor: mock_translator_registry.get_response_translator.return_value = Mock() # Setup translator response - mock_request_translator.to_pulsar.return_value = "translated_request" + mock_request_translator.decode.return_value = "translated_request" # Patch ServiceRequestor async methods with regular mocks (not AsyncMock) with patch.object(ServiceRequestor, 'start', return_value=None), \ @@ -64,7 +64,7 @@ class TestConfigRequestor: result = requestor.to_request({"test": "body"}) # Verify translator was called correctly - mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"}) + mock_request_translator.decode.assert_called_once_with({"test": "body"}) assert result == "translated_request" @patch('trustgraph.gateway.dispatch.config.TranslatorRegistry') @@ -76,7 +76,7 @@ class TestConfigRequestor: mock_translator_registry.get_response_translator.return_value = mock_response_translator # Setup translator response - mock_response_translator.from_response_with_completion.return_value = "translated_response" + mock_response_translator.encode_with_completion.return_value = "translated_response" requestor = ConfigRequestor( backend=Mock(), @@ -89,5 +89,5 @@ class TestConfigRequestor: result = requestor.from_response(mock_message) # Verify translator was called correctly - mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message) + mock_response_translator.encode_with_completion.assert_called_once_with(mock_message) assert result == "translated_response" \ No newline at end of file diff --git a/tests/unit/test_gateway/test_streaming_translators.py b/tests/unit/test_gateway/test_streaming_translators.py index e190fe68..31912688 100644 --- a/tests/unit/test_gateway/test_streaming_translators.py +++ b/tests/unit/test_gateway/test_streaming_translators.py @@ -25,7 +25,7 @@ from trustgraph.schema import ( class TestGraphRagResponseTranslator: """Test GraphRagResponseTranslator streaming behavior""" - def test_from_pulsar_with_empty_response(self): + def test_encode_with_empty_response(self): """Test that empty response strings are preserved""" # Arrange translator = GraphRagResponseTranslator() @@ -36,14 +36,14 @@ class TestGraphRagResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert - Empty string should be included in result assert "response" in result assert result["response"] == "" assert result["end_of_stream"] is True - def test_from_pulsar_with_non_empty_response(self): + def test_encode_with_non_empty_response(self): """Test that non-empty responses work correctly""" # Arrange translator = GraphRagResponseTranslator() @@ -54,13 +54,13 @@ class TestGraphRagResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert result["response"] == "Some text" assert result["end_of_stream"] is False - def test_from_pulsar_with_none_response(self): + def test_encode_with_none_response(self): """Test that None response is handled correctly""" # Arrange translator = GraphRagResponseTranslator() @@ -71,14 +71,14 @@ class TestGraphRagResponseTranslator: ) # Act - result = 
translator.from_pulsar(response) + result = translator.encode(response) # Assert - None should not be included assert "response" not in result assert result["end_of_stream"] is True - def test_from_response_with_completion_returns_correct_flag(self): - """Test that from_response_with_completion returns correct is_final flag""" + def test_encode_with_completion_returns_correct_flag(self): + """Test that encode_with_completion returns correct is_final flag""" # Arrange translator = GraphRagResponseTranslator() @@ -90,7 +90,7 @@ class TestGraphRagResponseTranslator: ) # Act - result, is_final = translator.from_response_with_completion(response_chunk) + result, is_final = translator.encode_with_completion(response_chunk) # Assert assert is_final is False @@ -105,7 +105,7 @@ class TestGraphRagResponseTranslator: ) # Act - result, is_final = translator.from_response_with_completion(final_response) + result, is_final = translator.encode_with_completion(final_response) # Assert - is_final is based on end_of_session, not end_of_stream assert is_final is True @@ -116,7 +116,7 @@ class TestGraphRagResponseTranslator: class TestDocumentRagResponseTranslator: """Test DocumentRagResponseTranslator streaming behavior""" - def test_from_pulsar_with_empty_response(self): + def test_encode_with_empty_response(self): """Test that empty response strings are preserved""" # Arrange translator = DocumentRagResponseTranslator() @@ -127,14 +127,14 @@ class TestDocumentRagResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert "response" in result assert result["response"] == "" assert result["end_of_stream"] is True - def test_from_pulsar_with_non_empty_response(self): + def test_encode_with_non_empty_response(self): """Test that non-empty responses work correctly""" # Arrange translator = DocumentRagResponseTranslator() @@ -145,7 +145,7 @@ class TestDocumentRagResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert result["response"] == "Document content" @@ -155,7 +155,7 @@ class TestDocumentRagResponseTranslator: class TestPromptResponseTranslator: """Test PromptResponseTranslator streaming behavior""" - def test_from_pulsar_with_empty_text(self): + def test_encode_with_empty_text(self): """Test that empty text strings are preserved""" # Arrange translator = PromptResponseTranslator() @@ -167,14 +167,14 @@ class TestPromptResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert "text" in result assert result["text"] == "" assert result["end_of_stream"] is True - def test_from_pulsar_with_non_empty_text(self): + def test_encode_with_non_empty_text(self): """Test that non-empty text works correctly""" # Arrange translator = PromptResponseTranslator() @@ -186,13 +186,13 @@ class TestPromptResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert result["text"] == "Some prompt response" assert result["end_of_stream"] is False - def test_from_pulsar_with_none_text(self): + def test_encode_with_none_text(self): """Test that None text is handled correctly""" # Arrange translator = PromptResponseTranslator() @@ -204,14 +204,14 @@ class TestPromptResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert "text" not in result assert "object" in result assert 
result["end_of_stream"] is True - def test_from_pulsar_includes_end_of_stream(self): + def test_encode_includes_end_of_stream(self): """Test that end_of_stream flag is always included""" # Arrange translator = PromptResponseTranslator() @@ -225,7 +225,7 @@ class TestPromptResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert "end_of_stream" in result @@ -235,7 +235,7 @@ class TestPromptResponseTranslator: class TestTextCompletionResponseTranslator: """Test TextCompletionResponseTranslator streaming behavior""" - def test_from_pulsar_always_includes_response(self): + def test_encode_always_includes_response(self): """Test that response field is always included, even if empty""" # Arrange translator = TextCompletionResponseTranslator() @@ -249,13 +249,13 @@ class TestTextCompletionResponseTranslator: ) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert - Response should always be present assert "response" in result assert result["response"] == "" - def test_from_response_with_completion_with_empty_final(self): + def test_encode_with_completion_with_empty_final(self): """Test that empty final response is handled correctly""" # Arrange translator = TextCompletionResponseTranslator() @@ -269,7 +269,7 @@ class TestTextCompletionResponseTranslator: ) # Act - result, is_final = translator.from_response_with_completion(response) + result, is_final = translator.encode_with_completion(response) # Assert assert is_final is True @@ -297,7 +297,7 @@ class TestStreamingProtocolCompliance: response = response_class(**kwargs) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty" @@ -320,7 +320,7 @@ class TestStreamingProtocolCompliance: response = response_class(**kwargs) # Act - result = translator.from_pulsar(response) + result = translator.encode(response) # Assert assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag" diff --git a/tests/unit/test_gateway/test_text_document_translator.py b/tests/unit/test_gateway/test_text_document_translator.py index f836eb2b..84eedefc 100644 --- a/tests/unit/test_gateway/test_text_document_translator.py +++ b/tests/unit/test_gateway/test_text_document_translator.py @@ -8,11 +8,11 @@ from trustgraph.messaging.translators.document_loading import TextDocumentTransl class TestTextDocumentTranslator: - def test_to_pulsar_decodes_base64_text(self): + def test_decode_decodes_base64_text(self): translator = TextDocumentTranslator() payload = "Cancer survival: 2.74× higher hazard ratio" - msg = translator.to_pulsar( + msg = translator.decode( { "id": "doc-1", "user": "alice", @@ -27,11 +27,11 @@ class TestTextDocumentTranslator: assert msg.metadata.collection == "research" assert msg.text == payload.encode("utf-8") - def test_to_pulsar_accepts_raw_utf8_text(self): + def test_decode_accepts_raw_utf8_text(self): translator = TextDocumentTranslator() payload = "Cancer survival: 2.74× higher hazard ratio" - msg = translator.to_pulsar( + msg = translator.decode( { "charset": "utf-8", "text": payload, @@ -40,11 +40,11 @@ class TestTextDocumentTranslator: assert msg.text == payload.encode("utf-8") - def test_to_pulsar_falls_back_to_raw_non_base64_ascii(self): + def test_decode_falls_back_to_raw_non_base64_ascii(self): translator = TextDocumentTranslator() payload = 
"plain-text payload" - msg = translator.to_pulsar( + msg = translator.decode( { "charset": "utf-8", "text": payload, diff --git a/tests/unit/test_pubsub/test_queue_naming.py b/tests/unit/test_pubsub/test_queue_naming.py new file mode 100644 index 00000000..1ee781d9 --- /dev/null +++ b/tests/unit/test_pubsub/test_queue_naming.py @@ -0,0 +1,133 @@ +""" +Tests for queue naming and topic mapping. +""" + +import pytest +import argparse + +from trustgraph.schema.core.topic import queue +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args +from trustgraph.base.pulsar_backend import PulsarBackend + + +class TestQueueFunction: + + def test_flow_default(self): + assert queue('text-completion-request') == 'flow:tg:text-completion-request' + + def test_request_class(self): + assert queue('config', cls='request') == 'request:tg:config' + + def test_response_class(self): + assert queue('config', cls='response') == 'response:tg:config' + + def test_state_class(self): + assert queue('config', cls='state') == 'state:tg:config' + + def test_custom_topicspace(self): + assert queue('config', cls='request', topicspace='prod') == 'request:prod:config' + + def test_default_class_is_flow(self): + result = queue('something') + assert result.startswith('flow:') + + +class TestPulsarMapTopic: + + @pytest.fixture + def backend(self): + """Create a PulsarBackend without connecting.""" + b = object.__new__(PulsarBackend) + return b + + def test_flow_maps_to_persistent(self, backend): + assert backend.map_topic('flow:tg:text-completion-request') == \ + 'persistent://tg/flow/text-completion-request' + + def test_state_maps_to_persistent(self, backend): + assert backend.map_topic('state:tg:config') == \ + 'persistent://tg/state/config' + + def test_request_maps_to_non_persistent(self, backend): + assert backend.map_topic('request:tg:config') == \ + 'non-persistent://tg/request/config' + + def test_response_maps_to_non_persistent(self, backend): + assert backend.map_topic('response:tg:librarian') == \ + 'non-persistent://tg/response/librarian' + + def test_passthrough_pulsar_uri(self, backend): + uri = 'persistent://tg/flow/something' + assert backend.map_topic(uri) == uri + + def test_invalid_format_raises(self, backend): + with pytest.raises(ValueError, match="Invalid queue format"): + backend.map_topic('bad-format') + + def test_invalid_class_raises(self, backend): + with pytest.raises(ValueError, match="Invalid queue class"): + backend.map_topic('unknown:tg:topic') + + def test_custom_topicspace(self, backend): + assert backend.map_topic('flow:prod:my-queue') == \ + 'persistent://prod/flow/my-queue' + + +class TestGetPubsubDispatch: + + def test_unknown_backend_raises(self): + with pytest.raises(ValueError, match="Unknown pub/sub backend"): + get_pubsub(pubsub_backend='redis') + + +class TestAddPubsubArgs: + + def test_standalone_defaults_to_localhost(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser, standalone=True) + args = parser.parse_args([]) + assert args.pulsar_host == 'pulsar://localhost:6650' + assert args.pulsar_listener == 'localhost' + + def test_non_standalone_defaults_to_container(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser, standalone=False) + args = parser.parse_args([]) + assert 'pulsar:6650' in args.pulsar_host + assert args.pulsar_listener is None + + def test_cli_override_respected(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser, standalone=True) + args = parser.parse_args(['--pulsar-host', 'pulsar://custom:6650']) + assert 
args.pulsar_host == 'pulsar://custom:6650' + + def test_pubsub_backend_default(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([]) + assert args.pubsub_backend == 'pulsar' + + +class TestQueueDefinitions: + """Verify the actual queue constants produce correct names.""" + + def test_config_request(self): + from trustgraph.schema.services.config import config_request_queue + assert config_request_queue == 'request:tg:config' + + def test_config_response(self): + from trustgraph.schema.services.config import config_response_queue + assert config_response_queue == 'response:tg:config' + + def test_config_push(self): + from trustgraph.schema.services.config import config_push_queue + assert config_push_queue == 'state:tg:config' + + def test_librarian_request_is_persistent(self): + from trustgraph.schema.services.library import librarian_request_queue + assert librarian_request_queue.startswith('flow:') + + def test_knowledge_request(self): + from trustgraph.schema.knowledge.knowledge import knowledge_request_queue + assert knowledge_request_queue == 'request:tg:knowledge' diff --git a/tests/unit/test_rdf/test_rdf_wire_format.py b/tests/unit/test_rdf/test_rdf_wire_format.py index a0bbd27a..d4375462 100644 --- a/tests/unit/test_rdf/test_rdf_wire_format.py +++ b/tests/unit/test_rdf/test_rdf_wire_format.py @@ -28,21 +28,21 @@ def triple_tx(): class TestTermTranslatorIri: - def test_iri_to_pulsar(self, term_tx): + def test_iri_decode(self, term_tx): data = {"t": "i", "i": "http://example.org/Alice"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.type == IRI assert term.iri == "http://example.org/Alice" - def test_iri_from_pulsar(self, term_tx): + def test_iri_encode(self, term_tx): term = Term(type=IRI, iri="http://example.org/Bob") - wire = term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert wire == {"t": "i", "i": "http://example.org/Bob"} def test_iri_round_trip(self, term_tx): original = Term(type=IRI, iri="http://example.org/round") - wire = term_tx.from_pulsar(original) - restored = term_tx.to_pulsar(wire) + wire = term_tx.encode(original) + restored = term_tx.decode(wire) assert restored == original @@ -52,21 +52,21 @@ class TestTermTranslatorIri: class TestTermTranslatorBlank: - def test_blank_to_pulsar(self, term_tx): + def test_blank_decode(self, term_tx): data = {"t": "b", "d": "_:b42"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.type == BLANK assert term.id == "_:b42" - def test_blank_from_pulsar(self, term_tx): + def test_blank_encode(self, term_tx): term = Term(type=BLANK, id="_:node1") - wire = term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert wire == {"t": "b", "d": "_:node1"} def test_blank_round_trip(self, term_tx): original = Term(type=BLANK, id="_:x") - wire = term_tx.from_pulsar(original) - restored = term_tx.to_pulsar(wire) + wire = term_tx.encode(original) + restored = term_tx.decode(wire) assert restored == original @@ -76,29 +76,29 @@ class TestTermTranslatorBlank: class TestTermTranslatorTypedLiteral: - def test_plain_literal_to_pulsar(self, term_tx): + def test_plain_literal_decode(self, term_tx): data = {"t": "l", "v": "hello"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.type == LITERAL assert term.value == "hello" assert term.datatype == "" assert term.language == "" - def test_xsd_integer_to_pulsar(self, term_tx): + def test_xsd_integer_decode(self, term_tx): data = { "t": "l", "v": "42", "dt": 
"http://www.w3.org/2001/XMLSchema#integer", } - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.value == "42" assert term.datatype.endswith("#integer") - def test_typed_literal_from_pulsar(self, term_tx): + def test_typed_literal_encode(self, term_tx): term = Term( type=LITERAL, value="3.14", datatype="http://www.w3.org/2001/XMLSchema#double", ) - wire = term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert wire["t"] == "l" assert wire["v"] == "3.14" assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double" @@ -109,13 +109,13 @@ class TestTermTranslatorTypedLiteral: type=LITERAL, value="true", datatype="http://www.w3.org/2001/XMLSchema#boolean", ) - wire = term_tx.from_pulsar(original) - restored = term_tx.to_pulsar(wire) + wire = term_tx.encode(original) + restored = term_tx.decode(wire) assert restored == original def test_plain_literal_omits_dt_and_ln(self, term_tx): term = Term(type=LITERAL, value="x") - wire = term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert "dt" not in wire assert "ln" not in wire @@ -126,22 +126,22 @@ class TestTermTranslatorTypedLiteral: class TestTermTranslatorLangLiteral: - def test_language_tag_to_pulsar(self, term_tx): + def test_language_tag_decode(self, term_tx): data = {"t": "l", "v": "bonjour", "ln": "fr"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.value == "bonjour" assert term.language == "fr" - def test_language_tag_from_pulsar(self, term_tx): + def test_language_tag_encode(self, term_tx): term = Term(type=LITERAL, value="colour", language="en-GB") - wire = term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert wire["ln"] == "en-GB" assert "dt" not in wire # No datatype def test_language_tag_round_trip(self, term_tx): original = Term(type=LITERAL, value="hola", language="es") - wire = term_tx.from_pulsar(original) - restored = term_tx.to_pulsar(wire) + wire = term_tx.encode(original) + restored = term_tx.decode(wire) assert restored == original @@ -151,7 +151,7 @@ class TestTermTranslatorLangLiteral: class TestTermTranslatorQuotedTriple: - def test_quoted_triple_to_pulsar(self, term_tx): + def test_quoted_triple_decode(self, term_tx): data = { "t": "t", "tr": { @@ -160,20 +160,20 @@ class TestTermTranslatorQuotedTriple: "o": {"t": "i", "i": "http://example.org/Bob"}, }, } - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.type == TRIPLE assert term.triple is not None assert term.triple.s.iri == "http://example.org/Alice" assert term.triple.o.iri == "http://example.org/Bob" - def test_quoted_triple_from_pulsar(self, term_tx): + def test_quoted_triple_encode(self, term_tx): inner = Triple( s=Term(type=IRI, iri="http://example.org/s"), p=Term(type=IRI, iri="http://example.org/p"), o=Term(type=LITERAL, value="val"), ) term = Term(type=TRIPLE, triple=inner) - wire = term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert wire["t"] == "t" assert "tr" in wire assert wire["tr"]["s"]["i"] == "http://example.org/s" @@ -186,18 +186,18 @@ class TestTermTranslatorQuotedTriple: o=Term(type=LITERAL, value="C", language="en"), ) original = Term(type=TRIPLE, triple=inner) - wire = term_tx.from_pulsar(original) - restored = term_tx.to_pulsar(wire) + wire = term_tx.encode(original) + restored = term_tx.decode(wire) assert restored.type == TRIPLE assert restored.triple.s == original.triple.s assert restored.triple.o == original.triple.o def test_quoted_triple_none_triple(self, term_tx): term = Term(type=TRIPLE, triple=None) - wire = 
term_tx.from_pulsar(term) + wire = term_tx.encode(term) assert wire == {"t": "t"} # And back - restored = term_tx.to_pulsar(wire) + restored = term_tx.decode(wire) assert restored.type == TRIPLE assert restored.triple is None @@ -210,7 +210,7 @@ class TestTermTranslatorQuotedTriple: "o": {"t": "l", "v": "A feeling of expectation"}, }, } - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.triple.o.type == LITERAL assert term.triple.o.value == "A feeling of expectation" @@ -223,22 +223,22 @@ class TestTermTranslatorEdgeCases: def test_unknown_type(self, term_tx): data = {"t": "z"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.type == "z" def test_empty_type(self, term_tx): data = {} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.type == "" def test_missing_iri_field(self, term_tx): data = {"t": "i"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.iri == "" def test_missing_literal_fields(self, term_tx): data = {"t": "l"} - term = term_tx.to_pulsar(data) + term = term_tx.decode(data) assert term.value == "" assert term.datatype == "" assert term.language == "" @@ -250,24 +250,24 @@ class TestTermTranslatorEdgeCases: class TestTripleTranslator: - def test_triple_to_pulsar(self, triple_tx): + def test_triple_decode(self, triple_tx): data = { "s": {"t": "i", "i": "http://example.org/s"}, "p": {"t": "i", "i": "http://example.org/p"}, "o": {"t": "l", "v": "object"}, } - triple = triple_tx.to_pulsar(data) + triple = triple_tx.decode(data) assert triple.s.iri == "http://example.org/s" assert triple.o.value == "object" assert triple.g is None - def test_triple_from_pulsar(self, triple_tx): + def test_triple_encode(self, triple_tx): triple = Triple( s=Term(type=IRI, iri="http://example.org/A"), p=Term(type=IRI, iri="http://example.org/B"), o=Term(type=LITERAL, value="C"), ) - wire = triple_tx.from_pulsar(triple) + wire = triple_tx.encode(triple) assert wire["s"]["t"] == "i" assert wire["o"]["v"] == "C" assert "g" not in wire @@ -279,17 +279,17 @@ class TestTripleTranslator: "o": {"t": "l", "v": "val"}, "g": "urn:graph:source", } - quad = triple_tx.to_pulsar(data) + quad = triple_tx.decode(data) assert quad.g == "urn:graph:source" - def test_quad_from_pulsar_includes_graph(self, triple_tx): + def test_quad_encode_includes_graph(self, triple_tx): quad = Triple( s=Term(type=IRI, iri="http://example.org/s"), p=Term(type=IRI, iri="http://example.org/p"), o=Term(type=LITERAL, value="v"), g="urn:graph:retrieval", ) - wire = triple_tx.from_pulsar(quad) + wire = triple_tx.encode(quad) assert wire["g"] == "urn:graph:retrieval" def test_quad_round_trip(self, triple_tx): @@ -299,8 +299,8 @@ class TestTripleTranslator: o=Term(type=LITERAL, value="v"), g="urn:graph:source", ) - wire = triple_tx.from_pulsar(original) - restored = triple_tx.to_pulsar(wire) + wire = triple_tx.encode(original) + restored = triple_tx.decode(wire) assert restored == original def test_none_graph_omitted_from_wire(self, triple_tx): @@ -310,12 +310,12 @@ class TestTripleTranslator: o=Term(type=LITERAL, value="v"), g=None, ) - wire = triple_tx.from_pulsar(triple) + wire = triple_tx.encode(triple) assert "g" not in wire def test_missing_terms_handled(self, triple_tx): data = {} - triple = triple_tx.to_pulsar(data) + triple = triple_tx.decode(data) assert triple.s is None assert triple.p is None assert triple.o is None @@ -342,16 +342,16 @@ class TestSubgraphTranslator: g="urn:graph:source", ), ] - wire_list = tx.from_pulsar(triples) + 
wire_list = tx.encode(triples) assert len(wire_list) == 2 assert wire_list[1]["g"] == "urn:graph:source" - restored = tx.to_pulsar(wire_list) + restored = tx.decode(wire_list) assert len(restored) == 2 assert restored[0] == triples[0] assert restored[1] == triples[1] def test_empty_subgraph(self): tx = SubgraphTranslator() - assert tx.to_pulsar([]) == [] - assert tx.from_pulsar([]) == [] + assert tx.decode([]) == [] + assert tx.encode([]) == [] diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py index 2fabed58..aded7253 100644 --- a/tests/unit/test_reliability/test_metadata_preservation.py +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -35,7 +35,7 @@ class TestDocumentMetadataTranslator: "parent-id": "doc-100", "document-type": "page", } - obj = self.tx.to_pulsar(data) + obj = self.tx.decode(data) assert obj.id == "doc-123" assert obj.time == 1710000000 assert obj.kind == "application/pdf" @@ -45,14 +45,14 @@ class TestDocumentMetadataTranslator: assert obj.parent_id == "doc-100" assert obj.document_type == "page" - wire = self.tx.from_pulsar(obj) + wire = self.tx.encode(obj) assert wire["id"] == "doc-123" assert wire["user"] == "alice" assert wire["parent-id"] == "doc-100" assert wire["document-type"] == "page" def test_defaults_for_missing_fields(self): - obj = self.tx.to_pulsar({}) + obj = self.tx.decode({}) assert obj.parent_id == "" assert obj.document_type == "source" @@ -63,25 +63,25 @@ class TestDocumentMetadataTranslator: "o": {"t": "i", "i": "http://example.org/o"}, }] data = {"metadata": triple_wire} - obj = self.tx.to_pulsar(data) + obj = self.tx.decode(data) assert len(obj.metadata) == 1 assert obj.metadata[0].s.iri == "http://example.org/s" def test_none_metadata_handled(self): data = {"metadata": None} - obj = self.tx.to_pulsar(data) + obj = self.tx.decode(data) assert obj.metadata == [] def test_empty_tags_preserved(self): data = {"tags": []} - obj = self.tx.to_pulsar(data) - wire = self.tx.from_pulsar(obj) + obj = self.tx.decode(data) + wire = self.tx.encode(obj) assert wire["tags"] == [] def test_falsy_fields_omitted_from_wire(self): """Empty string fields should be omitted from wire format.""" obj = DocumentMetadata(id="", time=0, user="") - wire = self.tx.from_pulsar(obj) + wire = self.tx.encode(obj) assert "id" not in wire assert "user" not in wire @@ -105,7 +105,7 @@ class TestProcessingMetadataTranslator: "collection": "my-collection", "tags": ["tag1"], } - obj = self.tx.to_pulsar(data) + obj = self.tx.decode(data) assert obj.id == "proc-1" assert obj.document_id == "doc-123" assert obj.flow == "default" @@ -113,32 +113,32 @@ class TestProcessingMetadataTranslator: assert obj.collection == "my-collection" assert obj.tags == ["tag1"] - wire = self.tx.from_pulsar(obj) + wire = self.tx.encode(obj) assert wire["id"] == "proc-1" assert wire["document-id"] == "doc-123" assert wire["user"] == "alice" assert wire["collection"] == "my-collection" def test_missing_fields_use_defaults(self): - obj = self.tx.to_pulsar({}) + obj = self.tx.decode({}) assert obj.id is None assert obj.user is None assert obj.collection is None def test_tags_none_omitted(self): obj = ProcessingMetadata(tags=None) - wire = self.tx.from_pulsar(obj) + wire = self.tx.encode(obj) assert "tags" not in wire def test_tags_empty_list_preserved(self): obj = ProcessingMetadata(tags=[]) - wire = self.tx.from_pulsar(obj) + wire = self.tx.encode(obj) assert wire["tags"] == [] def 
test_user_and_collection_preserved(self): """Core pipeline routing fields must survive round-trip.""" data = {"user": "bob", "collection": "research"} - obj = self.tx.to_pulsar(data) - wire = self.tx.from_pulsar(obj) + obj = self.tx.decode(data) + wire = self.tx.encode(obj) assert wire["user"] == "bob" assert wire["collection"] == "research" diff --git a/tests/unit/test_retrieval/test_structured_diag/test_message_translation.py b/tests/unit/test_retrieval/test_structured_diag/test_message_translation.py index 7a113250..4c6d6803 100644 --- a/tests/unit/test_retrieval/test_structured_diag/test_message_translation.py +++ b/tests/unit/test_retrieval/test_structured_diag/test_message_translation.py @@ -28,7 +28,7 @@ class TestRequestTranslation: } # Translate to Pulsar - pulsar_msg = translator.to_pulsar(api_data) + pulsar_msg = translator.decode(api_data) assert pulsar_msg.operation == "schema-selection" assert pulsar_msg.sample == "test data sample" @@ -46,7 +46,7 @@ class TestRequestTranslation: "options": {"delimiter": ","} } - pulsar_msg = translator.to_pulsar(api_data) + pulsar_msg = translator.decode(api_data) assert pulsar_msg.operation == "generate-descriptor" assert pulsar_msg.sample == "csv data" @@ -70,7 +70,7 @@ class TestResponseTranslation: ) # Translate to API format - api_data = translator.from_pulsar(pulsar_response) + api_data = translator.encode(pulsar_response) assert api_data["operation"] == "schema-selection" assert api_data["schema-matches"] == ["products", "inventory", "catalog"] @@ -86,7 +86,7 @@ class TestResponseTranslation: error=None ) - api_data = translator.from_pulsar(pulsar_response) + api_data = translator.encode(pulsar_response) assert api_data["operation"] == "schema-selection" assert api_data["schema-matches"] == [] @@ -103,7 +103,7 @@ class TestResponseTranslation: error=None ) - api_data = translator.from_pulsar(pulsar_response) + api_data = translator.encode(pulsar_response) assert api_data["operation"] == "detect-type" assert api_data["detected-type"] == "xml" @@ -123,7 +123,7 @@ class TestResponseTranslation: ) ) - api_data = translator.from_pulsar(pulsar_response) + api_data = translator.encode(pulsar_response) assert api_data["operation"] == "schema-selection" # Error objects are typically handled separately by the gateway @@ -146,7 +146,7 @@ class TestResponseTranslation: error=None ) - api_data = translator.from_pulsar(pulsar_response) + api_data = translator.encode(pulsar_response) assert api_data["operation"] == "diagnose" assert api_data["detected-type"] == "csv" @@ -165,7 +165,7 @@ class TestResponseTranslation: error=None ) - api_data, is_final = translator.from_response_with_completion(pulsar_response) + api_data, is_final = translator.encode_with_completion(pulsar_response) assert is_final is True # Structured-diag responses are always final assert api_data["operation"] == "schema-selection" diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index f9f38060..5a454279 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -1,5 +1,5 @@ -from . pubsub import PulsarClient, get_pubsub +from . pubsub import get_pubsub, add_pubsub_args from . async_processor import AsyncProcessor from . consumer import Consumer from . 
producer import Producer diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 8068c67d..94bab278 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -6,7 +6,6 @@ import asyncio import argparse -import _pulsar import time import uuid import logging @@ -15,7 +14,7 @@ from prometheus_client import start_http_server, Info from .. schema import ConfigPush, config_push_queue from .. log_level import LogLevel -from . pubsub import PulsarClient, get_pubsub +from . pubsub import get_pubsub, add_pubsub_args from . producer import Producer from . consumer import Consumer from . metrics import ProcessorMetrics, ConsumerMetrics @@ -223,8 +222,8 @@ class AsyncProcessor: logger.info("Keyboard interrupt.") return - except _pulsar.Interrupted: - logger.info("Pulsar Interrupted.") + except KeyboardInterrupt: + logger.info("Interrupted.") return # Exceptions from a taskgroup come in as an exception group @@ -250,15 +249,7 @@ class AsyncProcessor: @staticmethod def add_args(parser): - # Pub/sub backend selection - parser.add_argument( - '--pubsub-backend', - default=os.getenv('PUBSUB_BACKEND', 'pulsar'), - choices=['pulsar', 'mqtt'], - help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', - ) - - PulsarClient.add_args(parser) + add_pubsub_args(parser) add_logging_args(parser) parser.add_argument( diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py index 0f170030..1caeaec0 100644 --- a/trustgraph-base/trustgraph/base/flow_processor.py +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -6,8 +6,6 @@ import json import logging -from pulsar.schema import JsonSchema - from .. schema import Error from .. schema import config_request_queue, config_response_queue from .. schema import config_push_queue diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py index a7772b67..04734f28 100644 --- a/trustgraph-base/trustgraph/base/pubsub.py +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -1,110 +1,72 @@ import os -import pulsar -import _pulsar -import uuid -from pulsar.schema import JsonSchema import logging -from .. log_level import LogLevel -from .pulsar_backend import PulsarBackend - logger = logging.getLogger(__name__) +# Default connection settings from environment +DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') +DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None) + def get_pubsub(**config): """ Factory function to create a pub/sub backend based on configuration. Args: - config: Configuration dictionary from command-line args - Must include 'pubsub_backend' key + config: Configuration dictionary from command-line args. + Key 'pubsub_backend' selects the backend (default: 'pulsar'). Returns: - Backend instance (PulsarBackend, MQTTBackend, etc.) - - Example: - backend = get_pubsub( - pubsub_backend='pulsar', - pulsar_host='pulsar://localhost:6650' - ) + Backend instance implementing the PubSubBackend protocol. 
""" backend_type = config.get('pubsub_backend', 'pulsar') if backend_type == 'pulsar': + from .pulsar_backend import PulsarBackend return PulsarBackend( - host=config.get('pulsar_host', PulsarClient.default_pulsar_host), - api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key), + host=config.get('pulsar_host', DEFAULT_PULSAR_HOST), + api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY), listener=config.get('pulsar_listener'), ) - elif backend_type == 'mqtt': - # TODO: Implement MQTT backend - raise NotImplementedError("MQTT backend not yet implemented") else: raise ValueError(f"Unknown pub/sub backend: {backend_type}") -class PulsarClient: +STANDALONE_PULSAR_HOST = 'pulsar://localhost:6650' - default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') - default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) - def __init__(self, **params): +def add_pubsub_args(parser, standalone=False): + """Add pub/sub CLI arguments to an argument parser. - self.client = None + Args: + parser: argparse.ArgumentParser + standalone: If True, default host is localhost (for CLI tools + that run outside containers) + """ + host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST + listener_default = 'localhost' if standalone else None - pulsar_host = params.get("pulsar_host", self.default_pulsar_host) - pulsar_listener = params.get("pulsar_listener", None) - pulsar_api_key = params.get( - "pulsar_api_key", - self.default_pulsar_api_key - ) - # Hard-code Pulsar logging to ERROR level to minimize noise + parser.add_argument( + '--pubsub-backend', + default=os.getenv('PUBSUB_BACKEND', 'pulsar'), + help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', + ) - self.pulsar_host = pulsar_host - self.pulsar_api_key = pulsar_api_key + parser.add_argument( + '-p', '--pulsar-host', + default=host, + help=f'Pulsar host (default: {host})', + ) - if pulsar_api_key: - auth = pulsar.AuthenticationToken(pulsar_api_key) - self.client = pulsar.Client( - pulsar_host, - authentication=auth, - logger=pulsar.ConsoleLogger(_pulsar.LoggerLevel.Error) - ) - else: - self.client = pulsar.Client( - pulsar_host, - listener_name=pulsar_listener, - logger=pulsar.ConsoleLogger(_pulsar.LoggerLevel.Error) - ) + parser.add_argument( + '--pulsar-api-key', + default=DEFAULT_PULSAR_API_KEY, + help='Pulsar API key', + ) - self.pulsar_listener = pulsar_listener - - def close(self): - self.client.close() - - def __del__(self): - - if hasattr(self, "client"): - if self.client: - self.client.close() - - @staticmethod - def add_args(parser): - - parser.add_argument( - '-p', '--pulsar-host', - default=__class__.default_pulsar_host, - help=f'Pulsar host (default: {__class__.default_pulsar_host})', - ) - - parser.add_argument( - '--pulsar-api-key', - default=__class__.default_pulsar_api_key, - help=f'Pulsar API key', - ) - - parser.add_argument( - '--pulsar-listener', - help=f'Pulsar listener (default: none)', - ) + parser.add_argument( + '--pulsar-listener', + default=listener_default, + help=f'Pulsar listener (default: {listener_default or "none"})', + ) diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index a3c3debd..677f2527 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -181,8 +181,11 @@ class PulsarBackendConsumer: self._schema_cls = schema_cls def receive(self, timeout_millis: int = 2000) -> Message: - """Receive a message.""" - pulsar_msg = 
self._consumer.receive(timeout_millis=timeout_millis) + """Receive a message. Raises TimeoutError if no message available.""" + try: + pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis) + except _pulsar.Timeout: + raise TimeoutError("No message received within timeout") return PulsarMessage(pulsar_msg, self._schema_cls) def acknowledge(self, message: Message) -> None: @@ -237,38 +240,44 @@ class PulsarBackend: self.client = pulsar.Client(**client_args) logger.info(f"Pulsar client connected to {host}") - def map_topic(self, generic_topic: str) -> str: + def map_topic(self, queue_id: str) -> str: """ - Map generic topic format to Pulsar URI. + Map queue identifier to Pulsar URI. - Format: qos/tenant/namespace/queue - Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue + Format: class:topicspace:topic + Example: flow:tg:text-completion-request -> persistent://tg/flow/text-completion-request Args: - generic_topic: Generic topic string or already-formatted Pulsar URI + queue_id: Queue identifier string or already-formatted Pulsar URI Returns: Pulsar topic URI """ # If already a Pulsar URI, return as-is - if '://' in generic_topic: - return generic_topic + if '://' in queue_id: + return queue_id - parts = generic_topic.split('/', 3) - if len(parts) != 4: - raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue") + parts = queue_id.split(':', 2) + if len(parts) != 3: + raise ValueError( + f"Invalid queue format: {queue_id}, " + f"expected class:topicspace:topic" + ) - qos, tenant, namespace, queue = parts + cls, topicspace, topic = parts - # Map QoS to persistence - if qos == 'q0': - persistence = 'non-persistent' - elif qos in ['q1', 'q2']: + # Map class to Pulsar persistence and namespace + if cls in ('flow', 'state'): persistence = 'persistent' + elif cls in ('request', 'response'): + persistence = 'non-persistent' else: - raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2") + raise ValueError( + f"Invalid queue class: {cls}, " + f"expected flow, request, response, or state" + ) - return f"{persistence}://{tenant}/{namespace}/{queue}" + return f"{persistence}://{topicspace}/{cls}/{topic}" def create_producer(self, topic: str, schema: type, **options) -> BackendProducer: """ diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py index 17ff5a09..1cadbdd5 100644 --- a/trustgraph-base/trustgraph/clients/agent_client.py +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -1,5 +1,4 @@ -import _pulsar from .. schema import AgentRequest, AgentResponse from .. schema import agent_request_queue @@ -7,15 +6,11 @@ from .. schema import agent_response_queue from . 
base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class AgentClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -27,7 +22,6 @@ class AgentClient(BaseClient): if output_queue is None: output_queue = agent_response_queue super(AgentClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/base.py b/trustgraph-base/trustgraph/clients/base.py index 3a4da6ec..a71ba84e 100644 --- a/trustgraph-base/trustgraph/clients/base.py +++ b/trustgraph-base/trustgraph/clients/base.py @@ -1,10 +1,6 @@ -import pulsar -import _pulsar -import hashlib import uuid import time -from pulsar.schema import JsonSchema from .. exceptions import * from ..base.pubsub import get_pubsub @@ -12,16 +8,11 @@ from ..base.pubsub import get_pubsub # Default timeout for a request/response. In seconds. DEFAULT_TIMEOUT=300 -# Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class BaseClient: def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -87,7 +78,7 @@ class BaseClient: try: msg = self.consumer.receive(timeout_millis=2500) - except pulsar.exceptions.Timeout: + except TimeoutError: continue mid = msg.properties()["id"] @@ -139,4 +130,3 @@ class BaseClient: if hasattr(self, "backend"): self.backend.close() - diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index be2bf5b9..daadf652 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -1,5 +1,4 @@ -import _pulsar import json import dataclasses @@ -9,10 +8,6 @@ from .. schema import config_response_queue from . base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug @dataclasses.dataclass class Definition: @@ -34,7 +29,7 @@ class Topic: class ConfigClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -50,7 +45,6 @@ class ConfigClient(BaseClient): output_queue = config_response_queue super(ConfigClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index 1ab47aab..ebbad397 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -1,5 +1,4 @@ -import _pulsar from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse from .. schema import document_embeddings_request_queue @@ -7,15 +6,11 @@ from .. schema import document_embeddings_response_queue from . 
base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class DocumentEmbeddingsClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -30,7 +25,6 @@ class DocumentEmbeddingsClient(BaseClient): output_queue = document_embeddings_response_queue super(DocumentEmbeddingsClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 946b1a6c..057376fb 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -1,21 +1,15 @@ -import _pulsar from .. schema import DocumentRagQuery, DocumentRagResponse from .. schema import document_rag_request_queue, document_rag_response_queue from . base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class DocumentRagClient(BaseClient): def __init__( self, - log_level=ERROR, subscriber=None, input_queue=None, output_queue=None, @@ -30,7 +24,6 @@ class DocumentRagClient(BaseClient): output_queue = document_rag_response_queue super(DocumentRagClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/embeddings_client.py b/trustgraph-base/trustgraph/clients/embeddings_client.py index 1b1c0dc8..7d9e6d8e 100644 --- a/trustgraph-base/trustgraph/clients/embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/embeddings_client.py @@ -1,20 +1,14 @@ -from pulsar.schema import JsonSchema from .. schema import EmbeddingsRequest, EmbeddingsResponse from . base import BaseClient -import _pulsar # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class EmbeddingsClient(BaseClient): def __init__( - self, log_level=ERROR, + self, input_queue=None, output_queue=None, subscriber=None, @@ -23,7 +17,6 @@ class EmbeddingsClient(BaseClient): ): super(EmbeddingsClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index f85c91ee..62a55609 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -1,5 +1,4 @@ -import _pulsar from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse from .. schema import graph_embeddings_request_queue @@ -7,15 +6,11 @@ from .. schema import graph_embeddings_response_queue from . 
base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class GraphEmbeddingsClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -30,7 +25,6 @@ class GraphEmbeddingsClient(BaseClient): output_queue = graph_embeddings_response_queue super(GraphEmbeddingsClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 42ffce0c..17d7b0f0 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -1,21 +1,15 @@ -import _pulsar from .. schema import GraphRagQuery, GraphRagResponse from .. schema import graph_rag_request_queue, graph_rag_response_queue from . base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class GraphRagClient(BaseClient): def __init__( self, - log_level=ERROR, subscriber=None, input_queue=None, output_queue=None, @@ -30,7 +24,6 @@ class GraphRagClient(BaseClient): output_queue = graph_rag_response_queue super(GraphRagClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/llm_client.py b/trustgraph-base/trustgraph/clients/llm_client.py index 3c629e7d..bfb4096c 100644 --- a/trustgraph-base/trustgraph/clients/llm_client.py +++ b/trustgraph-base/trustgraph/clients/llm_client.py @@ -1,5 +1,4 @@ -import _pulsar from .. schema import TextCompletionRequest, TextCompletionResponse from .. schema import text_completion_request_queue @@ -8,15 +7,11 @@ from . base import BaseClient from .. exceptions import LlmError # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class LlmClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -28,7 +23,6 @@ class LlmClient(BaseClient): if output_queue is None: output_queue = text_completion_response_queue super(LlmClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/prompt_client.py b/trustgraph-base/trustgraph/clients/prompt_client.py index 91707670..12c9c194 100644 --- a/trustgraph-base/trustgraph/clients/prompt_client.py +++ b/trustgraph-base/trustgraph/clients/prompt_client.py @@ -1,5 +1,4 @@ -import _pulsar import json import dataclasses @@ -9,10 +8,6 @@ from .. schema import prompt_response_queue from . 
base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug @dataclasses.dataclass class Definition: @@ -34,7 +29,7 @@ class Topic: class PromptClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -49,7 +44,6 @@ class PromptClient(BaseClient): output_queue = prompt_response_queue super(PromptClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py index 19d4b338..6e10de29 100644 --- a/trustgraph-base/trustgraph/clients/row_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -1,5 +1,4 @@ -import _pulsar from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse from .. schema import row_embeddings_request_queue @@ -7,15 +6,11 @@ from .. schema import row_embeddings_response_queue from . base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class RowEmbeddingsClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -30,7 +25,6 @@ class RowEmbeddingsClient(BaseClient): output_queue = row_embeddings_response_queue super(RowEmbeddingsClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index 401aaf0b..403d02ea 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import _pulsar from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL from .. schema import triples_request_queue @@ -8,15 +7,11 @@ from .. schema import triples_response_queue from . 
base import BaseClient # Ugly -ERROR=_pulsar.LoggerLevel.Error -WARN=_pulsar.LoggerLevel.Warn -INFO=_pulsar.LoggerLevel.Info -DEBUG=_pulsar.LoggerLevel.Debug class TriplesQueryClient(BaseClient): def __init__( - self, log_level=ERROR, + self, subscriber=None, input_queue=None, output_queue=None, @@ -31,7 +26,6 @@ class TriplesQueryClient(BaseClient): output_queue = triples_response_queue super(TriplesQueryClient, self).__init__( - log_level=log_level, subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, diff --git a/trustgraph-base/trustgraph/log_level.py b/trustgraph-base/trustgraph/log_level.py index 65486b29..5b6f9e0c 100644 --- a/trustgraph-base/trustgraph/log_level.py +++ b/trustgraph-base/trustgraph/log_level.py @@ -1,6 +1,6 @@ from enum import Enum -import _pulsar + class LogLevel(Enum): DEBUG = 'debug' @@ -10,11 +10,3 @@ class LogLevel(Enum): def __str__(self): return self.value - - def to_pulsar(self): - if self == LogLevel.DEBUG: return _pulsar.LoggerLevel.Debug - if self == LogLevel.INFO: return _pulsar.LoggerLevel.Info - if self == LogLevel.WARN: return _pulsar.LoggerLevel.Warn - if self == LogLevel.ERROR: return _pulsar.LoggerLevel.Error - raise RuntimeError("Log level mismatch") - diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index b245a83e..c2c00ac2 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -6,7 +6,7 @@ from .base import MessageTranslator class AgentRequestTranslator(MessageTranslator): """Translator for AgentRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest: + def decode(self, data: Dict[str, Any]) -> AgentRequest: return AgentRequest( question=data["question"], state=data.get("state", None), @@ -26,7 +26,7 @@ class AgentRequestTranslator(MessageTranslator): expected_siblings=data.get("expected_siblings", 0), ) - def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]: + def encode(self, obj: AgentRequest) -> Dict[str, Any]: return { "question": obj.question, "state": obj.state, @@ -50,10 +50,10 @@ class AgentRequestTranslator(MessageTranslator): class AgentResponseTranslator(MessageTranslator): """Translator for AgentResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> AgentResponse: + def decode(self, data: Dict[str, Any]) -> AgentResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]: + def encode(self, obj: AgentResponse) -> Dict[str, Any]: result = {} if obj.chunk_type: @@ -81,7 +81,7 @@ class AgentResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" is_final = getattr(obj, 'end_of_dialog', False) - return self.from_pulsar(obj), is_final \ No newline at end of file + return self.encode(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/base.py b/trustgraph-base/trustgraph/messaging/translators/base.py index 64e2b635..74658082 100644 --- a/trustgraph-base/trustgraph/messaging/translators/base.py +++ b/trustgraph-base/trustgraph/messaging/translators/base.py @@ -1,43 +1,46 @@ from abc import ABC, abstractmethod from typing import Dict, Any, Tuple 
-from pulsar.schema import Record class Translator(ABC): - """Base class for bidirectional Pulsar ↔ dict translation""" - + """Base class for bidirectional schema ↔ dict translation. + + Translates between external API dicts (JSON from HTTP/WebSocket) + and internal schema objects (dataclasses). + """ + @abstractmethod - def to_pulsar(self, data: Dict[str, Any]) -> Record: - """Convert dict to Pulsar schema object""" + def decode(self, data: Dict[str, Any]) -> Any: + """Convert external dict to schema object.""" pass - - @abstractmethod - def from_pulsar(self, obj: Record) -> Dict[str, Any]: - """Convert Pulsar schema object to dict""" + + @abstractmethod + def encode(self, obj: Any) -> Dict[str, Any]: + """Convert schema object to external dict.""" pass class MessageTranslator(Translator): - """For complete request/response message translation""" - - def from_response_with_completion(self, obj: Record) -> Tuple[Dict[str, Any], bool]: - """Returns (response_dict, is_final) - for streaming responses""" - return self.from_pulsar(obj), True + """For complete request/response message translation.""" + + def encode_with_completion(self, obj: Any) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final) — for streaming responses.""" + return self.encode(obj), True class SendTranslator(Translator): - """For fire-and-forget send operations (like ServiceSender)""" - - def from_pulsar(self, obj: Record) -> Dict[str, Any]: - """Usually not needed for send-only operations""" - raise NotImplementedError("Send translators typically don't need from_pulsar") + """For fire-and-forget send operations.""" + + def encode(self, obj: Any) -> Dict[str, Any]: + """Usually not needed for send-only operations.""" + raise NotImplementedError("Send translators don't need encode") -def handle_optional_fields(obj: Record, fields: list) -> Dict[str, Any]: - """Helper to extract optional fields from Pulsar object""" +def handle_optional_fields(obj: Any, fields: list) -> Dict[str, Any]: + """Helper to extract optional fields from a schema object.""" result = {} for field in fields: value = getattr(obj, field, None) if value is not None: result[field] = value - return result \ No newline at end of file + return result diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index 22c82828..c6fd1500 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -6,7 +6,7 @@ from .base import MessageTranslator class CollectionManagementRequestTranslator(MessageTranslator): """Translator for CollectionManagementRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementRequest: + def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest: return CollectionManagementRequest( operation=data.get("operation"), user=data.get("user"), @@ -19,7 +19,7 @@ class CollectionManagementRequestTranslator(MessageTranslator): limit=data.get("limit") ) - def from_pulsar(self, obj: CollectionManagementRequest) -> Dict[str, Any]: + def encode(self, obj: CollectionManagementRequest) -> Dict[str, Any]: result = {} if obj.operation is not None: @@ -47,7 +47,7 @@ class CollectionManagementRequestTranslator(MessageTranslator): class CollectionManagementResponseTranslator(MessageTranslator): """Translator for CollectionManagementResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> 
CollectionManagementResponse: + def decode(self, data: Dict[str, Any]) -> CollectionManagementResponse: # Handle error error = None @@ -76,7 +76,7 @@ class CollectionManagementResponseTranslator(MessageTranslator): collections=collections ) - def from_pulsar(self, obj: CollectionManagementResponse) -> Dict[str, Any]: + def encode(self, obj: CollectionManagementResponse) -> Dict[str, Any]: result = {} print("COLLECTIONMGMT", obj, flush=True) diff --git a/trustgraph-base/trustgraph/messaging/translators/config.py b/trustgraph-base/trustgraph/messaging/translators/config.py index 299c5438..e166362a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/config.py +++ b/trustgraph-base/trustgraph/messaging/translators/config.py @@ -6,7 +6,7 @@ from .base import MessageTranslator class ConfigRequestTranslator(MessageTranslator): """Translator for ConfigRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> ConfigRequest: + def decode(self, data: Dict[str, Any]) -> ConfigRequest: keys = None if "keys" in data: keys = [ @@ -35,7 +35,7 @@ class ConfigRequestTranslator(MessageTranslator): values=values ) - def from_pulsar(self, obj: ConfigRequest) -> Dict[str, Any]: + def encode(self, obj: ConfigRequest) -> Dict[str, Any]: result = {} if obj.operation is not None: @@ -69,10 +69,10 @@ class ConfigRequestTranslator(MessageTranslator): class ConfigResponseTranslator(MessageTranslator): """Translator for ConfigResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> ConfigResponse: + def decode(self, data: Dict[str, Any]) -> ConfigResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: ConfigResponse) -> Dict[str, Any]: + def encode(self, obj: ConfigResponse) -> Dict[str, Any]: result = {} if obj.version is not None: @@ -96,6 +96,6 @@ class ConfigResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: ConfigResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: ConfigResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/diagnosis.py b/trustgraph-base/trustgraph/messaging/translators/diagnosis.py index e0cb6a89..2a4b811d 100644 --- a/trustgraph-base/trustgraph/messaging/translators/diagnosis.py +++ b/trustgraph-base/trustgraph/messaging/translators/diagnosis.py @@ -7,7 +7,7 @@ from .base import MessageTranslator class StructuredDataDiagnosisRequestTranslator(MessageTranslator): """Translator for StructuredDataDiagnosisRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> StructuredDataDiagnosisRequest: + def decode(self, data: Dict[str, Any]) -> StructuredDataDiagnosisRequest: return StructuredDataDiagnosisRequest( operation=data["operation"], sample=data["sample"], @@ -16,7 +16,7 @@ class StructuredDataDiagnosisRequestTranslator(MessageTranslator): options=data.get("options", {}) ) - def from_pulsar(self, obj: StructuredDataDiagnosisRequest) -> Dict[str, Any]: + def encode(self, obj: StructuredDataDiagnosisRequest) -> Dict[str, Any]: result = { "operation": obj.operation, "sample": obj.sample, @@ -36,10 +36,10 @@ class StructuredDataDiagnosisRequestTranslator(MessageTranslator): class StructuredDataDiagnosisResponseTranslator(MessageTranslator): """Translator for StructuredDataDiagnosisResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) 
-> StructuredDataDiagnosisResponse: + def decode(self, data: Dict[str, Any]) -> StructuredDataDiagnosisResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: StructuredDataDiagnosisResponse) -> Dict[str, Any]: + def encode(self, obj: StructuredDataDiagnosisResponse) -> Dict[str, Any]: result = { "operation": obj.operation } @@ -64,6 +64,6 @@ class StructuredDataDiagnosisResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: StructuredDataDiagnosisResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: StructuredDataDiagnosisResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.encode(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index 51cda697..3e7062e2 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -30,7 +30,7 @@ def _decode_text_payload(payload: str | bytes, charset: str) -> str: class DocumentTranslator(SendTranslator): """Translator for Document schema objects (PDF docs etc.)""" - def to_pulsar(self, data: Dict[str, Any]) -> Document: + def decode(self, data: Dict[str, Any]) -> Document: # Handle base64 content validation doc = base64.b64decode(data["data"]) @@ -45,7 +45,7 @@ class DocumentTranslator(SendTranslator): data=base64.b64encode(doc).decode("utf-8") ) - def from_pulsar(self, obj: Document) -> Dict[str, Any]: + def encode(self, obj: Document) -> Dict[str, Any]: result = { "data": obj.data } @@ -69,7 +69,7 @@ class DocumentTranslator(SendTranslator): class TextDocumentTranslator(SendTranslator): """Translator for TextDocument schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> TextDocument: + def decode(self, data: Dict[str, Any]) -> TextDocument: charset = data.get("charset", "utf-8") text = _decode_text_payload(data["text"], charset) @@ -85,7 +85,7 @@ class TextDocumentTranslator(SendTranslator): text=text.encode("utf-8") ) - def from_pulsar(self, obj: TextDocument) -> Dict[str, Any]: + def encode(self, obj: TextDocument) -> Dict[str, Any]: result = { "text": obj.text.decode("utf-8") if isinstance(obj.text, bytes) else obj.text } @@ -109,7 +109,7 @@ class TextDocumentTranslator(SendTranslator): class ChunkTranslator(SendTranslator): """Translator for Chunk schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> Chunk: + def decode(self, data: Dict[str, Any]) -> Chunk: from ...schema import Metadata return Chunk( metadata=Metadata( @@ -121,7 +121,7 @@ class ChunkTranslator(SendTranslator): chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"] ) - def from_pulsar(self, obj: Chunk) -> Dict[str, Any]: + def encode(self, obj: Chunk) -> Dict[str, Any]: result = { "chunk": obj.chunk.decode("utf-8") if isinstance(obj.chunk, bytes) else obj.chunk } @@ -145,7 +145,7 @@ class ChunkTranslator(SendTranslator): class DocumentEmbeddingsTranslator(SendTranslator): """Translator for DocumentEmbeddings schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings: + def decode(self, data: Dict[str, Any]) -> DocumentEmbeddings: metadata = data.get("metadata", {}) chunks = [ @@ -167,7 +167,7 @@ class 
DocumentEmbeddingsTranslator(SendTranslator): chunks=chunks ) - def from_pulsar(self, obj: DocumentEmbeddings) -> Dict[str, Any]: + def encode(self, obj: DocumentEmbeddings) -> Dict[str, Any]: result = { "chunks": [ { diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings.py b/trustgraph-base/trustgraph/messaging/translators/embeddings.py index 454ce733..c3c1548a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings.py @@ -6,12 +6,12 @@ from .base import MessageTranslator class EmbeddingsRequestTranslator(MessageTranslator): """Translator for EmbeddingsRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest: + def decode(self, data: Dict[str, Any]) -> EmbeddingsRequest: return EmbeddingsRequest( texts=data["texts"] ) - def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]: + def encode(self, obj: EmbeddingsRequest) -> Dict[str, Any]: return { "texts": obj.texts } @@ -20,14 +20,14 @@ class EmbeddingsRequestTranslator(MessageTranslator): class EmbeddingsResponseTranslator(MessageTranslator): """Translator for EmbeddingsResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsResponse: + def decode(self, data: Dict[str, Any]) -> EmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: EmbeddingsResponse) -> Dict[str, Any]: + def encode(self, obj: EmbeddingsResponse) -> Dict[str, Any]: return { "vectors": obj.vectors } - def from_response_with_completion(self, obj: EmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: EmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.encode(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index f10ca4c6..fce1625e 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -11,7 +11,7 @@ from .primitives import ValueTranslator class DocumentEmbeddingsRequestTranslator(MessageTranslator): """Translator for DocumentEmbeddingsRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: + def decode(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: return DocumentEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), @@ -19,7 +19,7 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): collection=data.get("collection", "default") ) - def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: + def encode(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: return { "vector": obj.vector, "limit": obj.limit, @@ -31,10 +31,10 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): class DocumentEmbeddingsResponseTranslator(MessageTranslator): """Translator for DocumentEmbeddingsResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: + def decode(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: + def encode(self, 
obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: result = {} if obj.chunks is not None: @@ -48,15 +48,15 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True class GraphEmbeddingsRequestTranslator(MessageTranslator): """Translator for GraphEmbeddingsRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: + def decode(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: return GraphEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), @@ -64,7 +64,7 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): collection=data.get("collection", "default") ) - def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: + def encode(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: return { "vector": obj.vector, "limit": obj.limit, @@ -79,16 +79,16 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator): def __init__(self): self.value_translator = ValueTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: + def decode(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: + def encode(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: result = {} if obj.entities is not None: result["entities"] = [ { - "entity": self.value_translator.from_pulsar(match.entity), + "entity": self.value_translator.encode(match.entity), "score": match.score } for match in obj.entities @@ -96,15 +96,15 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True class RowEmbeddingsRequestTranslator(MessageTranslator): """Translator for RowEmbeddingsRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: + def decode(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: return RowEmbeddingsRequest( vector=data["vector"], limit=int(data.get("limit", 10)), @@ -114,7 +114,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): index_name=data.get("index_name") ) - def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: + def encode(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: result = { "vector": obj.vector, "limit": obj.limit, @@ -130,10 +130,10 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): class RowEmbeddingsResponseTranslator(MessageTranslator): """Translator for RowEmbeddingsResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse: + def decode(self, data: Dict[str, Any]) -> RowEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]: + def encode(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]: result = {} if obj.error is not None: @@ -155,6 +155,6 @@ class 
RowEmbeddingsResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/flow.py b/trustgraph-base/trustgraph/messaging/translators/flow.py index 542b65ec..2047475e 100644 --- a/trustgraph-base/trustgraph/messaging/translators/flow.py +++ b/trustgraph-base/trustgraph/messaging/translators/flow.py @@ -6,7 +6,7 @@ from .base import MessageTranslator class FlowRequestTranslator(MessageTranslator): """Translator for FlowRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> FlowRequest: + def decode(self, data: Dict[str, Any]) -> FlowRequest: return FlowRequest( operation=data.get("operation"), blueprint_name=data.get("blueprint-name"), @@ -16,7 +16,7 @@ class FlowRequestTranslator(MessageTranslator): parameters=data.get("parameters") ) - def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]: + def encode(self, obj: FlowRequest) -> Dict[str, Any]: result = {} if obj.operation is not None: @@ -38,10 +38,10 @@ class FlowRequestTranslator(MessageTranslator): class FlowResponseTranslator(MessageTranslator): """Translator for FlowResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> FlowResponse: + def decode(self, data: Dict[str, Any]) -> FlowResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]: + def encode(self, obj: FlowResponse) -> Dict[str, Any]: result = {} if obj.blueprint_names is not None: @@ -59,6 +59,6 @@ class FlowResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index 0043d1e4..2f11d75a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -14,7 +14,7 @@ class KnowledgeRequestTranslator(MessageTranslator): self.value_translator = ValueTranslator() self.subgraph_translator = SubgraphTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeRequest: + def decode(self, data: Dict[str, Any]) -> KnowledgeRequest: triples = None if "triples" in data: triples = Triples( @@ -24,7 +24,7 @@ class KnowledgeRequestTranslator(MessageTranslator): user=data["triples"]["metadata"]["user"], collection=data["triples"]["metadata"]["collection"] ), - triples=self.subgraph_translator.to_pulsar(data["triples"]["triples"]), + triples=self.subgraph_translator.decode(data["triples"]["triples"]), ) graph_embeddings = None @@ -38,7 +38,7 @@ class KnowledgeRequestTranslator(MessageTranslator): ), entities=[ EntityEmbeddings( - entity=self.value_translator.to_pulsar(ent["entity"]), + entity=self.value_translator.decode(ent["entity"]), vectors=ent["vectors"], ) for ent in data["graph-embeddings"]["entities"] @@ -55,7 +55,7 @@ class KnowledgeRequestTranslator(MessageTranslator): 
graph_embeddings=graph_embeddings, ) - def from_pulsar(self, obj: KnowledgeRequest) -> Dict[str, Any]: + def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]: result = {} if obj.operation: @@ -77,7 +77,7 @@ class KnowledgeRequestTranslator(MessageTranslator): "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, - "triples": self.subgraph_translator.from_pulsar(obj.triples.triples), + "triples": self.subgraph_translator.encode(obj.triples.triples), } if obj.graph_embeddings: @@ -91,7 +91,7 @@ class KnowledgeRequestTranslator(MessageTranslator): "entities": [ { "vector": entity.vector, - "entity": self.value_translator.from_pulsar(entity.entity), + "entity": self.value_translator.encode(entity.entity), } for entity in obj.graph_embeddings.entities ], @@ -107,10 +107,10 @@ class KnowledgeResponseTranslator(MessageTranslator): self.value_translator = ValueTranslator() self.subgraph_translator = SubgraphTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeResponse: + def decode(self, data: Dict[str, Any]) -> KnowledgeResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: KnowledgeResponse) -> Dict[str, Any]: + def encode(self, obj: KnowledgeResponse) -> Dict[str, Any]: # Response to list operation if obj.ids is not None: return {"ids": obj.ids} @@ -125,7 +125,7 @@ class KnowledgeResponseTranslator(MessageTranslator): "user": obj.triples.metadata.user, "collection": obj.triples.metadata.collection, }, - "triples": self.subgraph_translator.from_pulsar(obj.triples.triples), + "triples": self.subgraph_translator.encode(obj.triples.triples), } } @@ -142,7 +142,7 @@ class KnowledgeResponseTranslator(MessageTranslator): "entities": [ { "vector": entity.vector, - "entity": self.value_translator.from_pulsar(entity.entity), + "entity": self.value_translator.encode(entity.entity), } for entity in obj.graph_embeddings.entities ], @@ -156,9 +156,9 @@ class KnowledgeResponseTranslator(MessageTranslator): # Empty response (successful delete) return {} - def from_response_with_completion(self, obj: KnowledgeResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: KnowledgeResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - response = self.from_pulsar(obj) + response = self.encode(obj) # Check if this is a final response is_final = ( diff --git a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py index c7e849aa..7c77c39c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/library.py +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -11,16 +11,16 @@ class LibraryRequestTranslator(MessageTranslator): self.doc_metadata_translator = DocumentMetadataTranslator() self.proc_metadata_translator = ProcessingMetadataTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> LibrarianRequest: + def decode(self, data: Dict[str, Any]) -> LibrarianRequest: # Document metadata doc_metadata = None if "document-metadata" in data: - doc_metadata = self.doc_metadata_translator.to_pulsar(data["document-metadata"]) + doc_metadata = self.doc_metadata_translator.decode(data["document-metadata"]) # Processing metadata proc_metadata = None if "processing-metadata" in data: - proc_metadata = self.proc_metadata_translator.to_pulsar(data["processing-metadata"]) + proc_metadata = self.proc_metadata_translator.decode(data["processing-metadata"]) # Criteria criteria = [] 
@@ -61,7 +61,7 @@ class LibraryRequestTranslator(MessageTranslator): include_children=data.get("include-children", False), ) - def from_pulsar(self, obj: LibrarianRequest) -> Dict[str, Any]: + def encode(self, obj: LibrarianRequest) -> Dict[str, Any]: result = {} if obj.operation: @@ -71,9 +71,9 @@ class LibraryRequestTranslator(MessageTranslator): if obj.processing_id: result["processing-id"] = obj.processing_id if obj.document_metadata: - result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata) + result["document-metadata"] = self.doc_metadata_translator.encode(obj.document_metadata) if obj.processing_metadata: - result["processing-metadata"] = self.proc_metadata_translator.from_pulsar(obj.processing_metadata) + result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata) if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content if obj.user: @@ -100,10 +100,10 @@ class LibraryResponseTranslator(MessageTranslator): self.doc_metadata_translator = DocumentMetadataTranslator() self.proc_metadata_translator = ProcessingMetadataTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> LibrarianResponse: + def decode(self, data: Dict[str, Any]) -> LibrarianResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: LibrarianResponse) -> Dict[str, Any]: + def encode(self, obj: LibrarianResponse) -> Dict[str, Any]: result = {} if obj.error: @@ -113,20 +113,20 @@ class LibraryResponseTranslator(MessageTranslator): } if obj.document_metadata: - result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata) + result["document-metadata"] = self.doc_metadata_translator.encode(obj.document_metadata) if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content if obj.document_metadatas is not None: result["document-metadatas"] = [ - self.doc_metadata_translator.from_pulsar(dm) + self.doc_metadata_translator.encode(dm) for dm in obj.document_metadatas ] if obj.processing_metadatas is not None: result["processing-metadatas"] = [ - self.proc_metadata_translator.from_pulsar(pm) + self.proc_metadata_translator.encode(pm) for pm in obj.processing_metadatas ] @@ -172,6 +172,6 @@ class LibraryResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), obj.is_final + return self.encode(obj), obj.is_final diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py index 46a28d0a..3e141c19 100644 --- a/trustgraph-base/trustgraph/messaging/translators/metadata.py +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -10,7 +10,7 @@ class DocumentMetadataTranslator(Translator): def __init__(self): self.subgraph_translator = SubgraphTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> DocumentMetadata: + def decode(self, data: Dict[str, Any]) -> DocumentMetadata: metadata = data.get("metadata", []) return DocumentMetadata( id=data.get("id"), @@ -18,14 +18,14 @@ class DocumentMetadataTranslator(Translator): kind=data.get("kind"), title=data.get("title"), comments=data.get("comments"), - 
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata is not None else [], + metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [], user=data.get("user"), tags=data.get("tags"), parent_id=data.get("parent-id", ""), document_type=data.get("document-type", "source"), ) - def from_pulsar(self, obj: DocumentMetadata) -> Dict[str, Any]: + def encode(self, obj: DocumentMetadata) -> Dict[str, Any]: result = {} if obj.id: @@ -39,7 +39,7 @@ class DocumentMetadataTranslator(Translator): if obj.comments: result["comments"] = obj.comments if obj.metadata is not None: - result["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata) + result["metadata"] = self.subgraph_translator.encode(obj.metadata) if obj.user: result["user"] = obj.user if obj.tags is not None: @@ -55,7 +55,7 @@ class DocumentMetadataTranslator(Translator): class ProcessingMetadataTranslator(Translator): """Translator for ProcessingMetadata schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> ProcessingMetadata: + def decode(self, data: Dict[str, Any]) -> ProcessingMetadata: return ProcessingMetadata( id=data.get("id"), document_id=data.get("document-id"), @@ -66,7 +66,7 @@ class ProcessingMetadataTranslator(Translator): tags=data.get("tags") ) - def from_pulsar(self, obj: ProcessingMetadata) -> Dict[str, Any]: + def encode(self, obj: ProcessingMetadata) -> Dict[str, Any]: result = {} if obj.id: diff --git a/trustgraph-base/trustgraph/messaging/translators/nlp_query.py b/trustgraph-base/trustgraph/messaging/translators/nlp_query.py index 2c445579..c1f016b8 100644 --- a/trustgraph-base/trustgraph/messaging/translators/nlp_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/nlp_query.py @@ -6,13 +6,13 @@ from .base import MessageTranslator class QuestionToStructuredQueryRequestTranslator(MessageTranslator): """Translator for QuestionToStructuredQueryRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> QuestionToStructuredQueryRequest: + def decode(self, data: Dict[str, Any]) -> QuestionToStructuredQueryRequest: return QuestionToStructuredQueryRequest( question=data.get("question", ""), max_results=data.get("max_results", 100) ) - def from_pulsar(self, obj: QuestionToStructuredQueryRequest) -> Dict[str, Any]: + def encode(self, obj: QuestionToStructuredQueryRequest) -> Dict[str, Any]: return { "question": obj.question, "max_results": obj.max_results @@ -22,10 +22,10 @@ class QuestionToStructuredQueryRequestTranslator(MessageTranslator): class QuestionToStructuredQueryResponseTranslator(MessageTranslator): """Translator for QuestionToStructuredQueryResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> QuestionToStructuredQueryResponse: + def decode(self, data: Dict[str, Any]) -> QuestionToStructuredQueryResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: QuestionToStructuredQueryResponse) -> Dict[str, Any]: + def encode(self, obj: QuestionToStructuredQueryResponse) -> Dict[str, Any]: result = { "graphql_query": obj.graphql_query, "variables": dict(obj.variables) if obj.variables else {}, @@ -42,6 +42,6 @@ class QuestionToStructuredQueryResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: QuestionToStructuredQueryResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: QuestionToStructuredQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return 
self.from_pulsar(obj), True \ No newline at end of file + return self.encode(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/primitives.py b/trustgraph-base/trustgraph/messaging/translators/primitives.py index d54efc49..7eb3d285 100644 --- a/trustgraph-base/trustgraph/messaging/translators/primitives.py +++ b/trustgraph-base/trustgraph/messaging/translators/primitives.py @@ -17,7 +17,7 @@ class TermTranslator(Translator): - "tr": triple (for TRIPLE type, nested) """ - def to_pulsar(self, data: Dict[str, Any]) -> Term: + def decode(self, data: Dict[str, Any]) -> Term: term_type = data.get("t", "") if term_type == IRI: @@ -38,7 +38,7 @@ class TermTranslator(Translator): # Nested triple - use TripleTranslator triple_data = data.get("tr") if triple_data: - triple = _triple_translator_to_pulsar(triple_data) + triple = _triple_translator_decode(triple_data) else: triple = None return Term(type=TRIPLE, triple=triple) @@ -47,7 +47,7 @@ class TermTranslator(Translator): # Unknown or empty type return Term(type=term_type) - def from_pulsar(self, obj: Term) -> Dict[str, Any]: + def encode(self, obj: Term) -> Dict[str, Any]: result: Dict[str, Any] = {"t": obj.type} if obj.type == IRI: @@ -65,33 +65,33 @@ class TermTranslator(Translator): elif obj.type == TRIPLE: if obj.triple: - result["tr"] = _triple_translator_from_pulsar(obj.triple) + result["tr"] = _triple_translator_encode(obj.triple) return result # Module-level helper functions to avoid circular instantiation -def _triple_translator_to_pulsar(data: Dict[str, Any]) -> Triple: +def _triple_translator_decode(data: Dict[str, Any]) -> Triple: term_translator = TermTranslator() return Triple( - s=term_translator.to_pulsar(data["s"]) if data.get("s") else None, - p=term_translator.to_pulsar(data["p"]) if data.get("p") else None, - o=term_translator.to_pulsar(data["o"]) if data.get("o") else None, + s=term_translator.decode(data["s"]) if data.get("s") else None, + p=term_translator.decode(data["p"]) if data.get("p") else None, + o=term_translator.decode(data["o"]) if data.get("o") else None, g=data.get("g"), ) -def _triple_translator_from_pulsar(obj: Triple) -> Dict[str, Any]: +def _triple_translator_encode(obj: Triple) -> Dict[str, Any]: """Convert Triple object to wire format dict.""" term_translator = TermTranslator() result: Dict[str, Any] = {} if obj.s: - result["s"] = term_translator.from_pulsar(obj.s) + result["s"] = term_translator.encode(obj.s) if obj.p: - result["p"] = term_translator.from_pulsar(obj.p) + result["p"] = term_translator.encode(obj.p) if obj.o: - result["o"] = term_translator.from_pulsar(obj.o) + result["o"] = term_translator.encode(obj.o) if obj.g: result["g"] = obj.g @@ -104,23 +104,23 @@ class TripleTranslator(Translator): def __init__(self): self.term_translator = TermTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> Triple: + def decode(self, data: Dict[str, Any]) -> Triple: return Triple( - s=self.term_translator.to_pulsar(data["s"]) if data.get("s") else None, - p=self.term_translator.to_pulsar(data["p"]) if data.get("p") else None, - o=self.term_translator.to_pulsar(data["o"]) if data.get("o") else None, + s=self.term_translator.decode(data["s"]) if data.get("s") else None, + p=self.term_translator.decode(data["p"]) if data.get("p") else None, + o=self.term_translator.decode(data["o"]) if data.get("o") else None, g=data.get("g"), ) - def from_pulsar(self, obj: Triple) -> Dict[str, Any]: + def encode(self, obj: Triple) -> Dict[str, Any]: result: Dict[str, Any] 
= {} if obj.s: - result["s"] = self.term_translator.from_pulsar(obj.s) + result["s"] = self.term_translator.encode(obj.s) if obj.p: - result["p"] = self.term_translator.from_pulsar(obj.p) + result["p"] = self.term_translator.encode(obj.p) if obj.o: - result["o"] = self.term_translator.from_pulsar(obj.o) + result["o"] = self.term_translator.encode(obj.o) if obj.g: result["g"] = obj.g @@ -137,17 +137,17 @@ class SubgraphTranslator(Translator): def __init__(self): self.triple_translator = TripleTranslator() - def to_pulsar(self, data: List[Dict[str, Any]]) -> List[Triple]: - return [self.triple_translator.to_pulsar(t) for t in data] + def decode(self, data: List[Dict[str, Any]]) -> List[Triple]: + return [self.triple_translator.decode(t) for t in data] - def from_pulsar(self, obj: List[Triple]) -> List[Dict[str, Any]]: - return [self.triple_translator.from_pulsar(t) for t in obj] + def encode(self, obj: List[Triple]) -> List[Dict[str, Any]]: + return [self.triple_translator.encode(t) for t in obj] class RowSchemaTranslator(Translator): """Translator for RowSchema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> RowSchema: + def decode(self, data: Dict[str, Any]) -> RowSchema: """Convert dict to RowSchema Pulsar object""" fields = [] for field_data in data.get("fields", []): @@ -169,7 +169,7 @@ class RowSchemaTranslator(Translator): fields=fields ) - def from_pulsar(self, obj: RowSchema) -> Dict[str, Any]: + def encode(self, obj: RowSchema) -> Dict[str, Any]: """Convert RowSchema Pulsar object to JSON-serializable dictionary""" result = { "name": obj.name, @@ -200,7 +200,7 @@ class RowSchemaTranslator(Translator): class FieldTranslator(Translator): """Translator for Field objects""" - def to_pulsar(self, data: Dict[str, Any]) -> Field: + def decode(self, data: Dict[str, Any]) -> Field: """Convert dict to Field Pulsar object""" return Field( name=data.get("name", ""), @@ -213,7 +213,7 @@ class FieldTranslator(Translator): enum_values=data.get("enum_values", []) ) - def from_pulsar(self, obj: Field) -> Dict[str, Any]: + def encode(self, obj: Field) -> Dict[str, Any]: """Convert Field Pulsar object to JSON-serializable dictionary""" result = { "name": obj.name, diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py index 5ff99fdc..4345e6fd 100644 --- a/trustgraph-base/trustgraph/messaging/translators/prompt.py +++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py @@ -7,7 +7,7 @@ from .base import MessageTranslator class PromptRequestTranslator(MessageTranslator): """Translator for PromptRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> PromptRequest: + def decode(self, data: Dict[str, Any]) -> PromptRequest: # Handle both "terms" and "variables" input keys terms = data.get("terms", {}) if "variables" in data: @@ -23,7 +23,7 @@ class PromptRequestTranslator(MessageTranslator): streaming=data.get("streaming", False) ) - def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]: + def encode(self, obj: PromptRequest) -> Dict[str, Any]: result = {} if obj.id: @@ -37,10 +37,10 @@ class PromptRequestTranslator(MessageTranslator): class PromptResponseTranslator(MessageTranslator): """Translator for PromptResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> PromptResponse: + def decode(self, data: Dict[str, Any]) -> PromptResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: PromptResponse) -> 
Dict[str, Any]: + def encode(self, obj: PromptResponse) -> Dict[str, Any]: result = {} # Include text field if present (even if empty string) @@ -55,8 +55,8 @@ class PromptResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" # Check end_of_stream field to determine if this is the final message is_final = getattr(obj, 'end_of_stream', True) - return self.from_pulsar(obj), is_final \ No newline at end of file + return self.encode(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 98473db2..7e2abfa1 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -6,7 +6,7 @@ from .base import MessageTranslator class DocumentRagRequestTranslator(MessageTranslator): """Translator for DocumentRagQuery schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery: + def decode(self, data: Dict[str, Any]) -> DocumentRagQuery: return DocumentRagQuery( query=data["query"], user=data.get("user", "trustgraph"), @@ -15,7 +15,7 @@ class DocumentRagRequestTranslator(MessageTranslator): streaming=data.get("streaming", False) ) - def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]: + def encode(self, obj: DocumentRagQuery) -> Dict[str, Any]: return { "query": obj.query, "user": obj.user, @@ -28,10 +28,10 @@ class DocumentRagRequestTranslator(MessageTranslator): class DocumentRagResponseTranslator(MessageTranslator): """Translator for DocumentRagResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse: + def decode(self, data: Dict[str, Any]) -> DocumentRagResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: + def encode(self, obj: DocumentRagResponse) -> Dict[str, Any]: result = {} # Include message_type for distinguishing chunk vs explain messages @@ -65,17 +65,17 @@ class DocumentRagResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" # Session is complete when end_of_session is True is_final = getattr(obj, 'end_of_session', False) - return self.from_pulsar(obj), is_final + return self.encode(obj), is_final class GraphRagRequestTranslator(MessageTranslator): """Translator for GraphRagQuery schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery: + def decode(self, data: Dict[str, Any]) -> GraphRagQuery: return GraphRagQuery( query=data["query"], user=data.get("user", "trustgraph"), @@ -89,7 +89,7 @@ class GraphRagRequestTranslator(MessageTranslator): streaming=data.get("streaming", False) ) - def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]: + def encode(self, obj: GraphRagQuery) -> Dict[str, Any]: return { "query": obj.query, "user": obj.user, @@ -107,10 +107,10 @@ class GraphRagRequestTranslator(MessageTranslator): class GraphRagResponseTranslator(MessageTranslator): """Translator for GraphRagResponse schema objects""" - def 
to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse: + def decode(self, data: Dict[str, Any]) -> GraphRagResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: + def encode(self, obj: GraphRagResponse) -> Dict[str, Any]: result = {} # Include message_type @@ -144,8 +144,8 @@ class GraphRagResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" # Session is complete when end_of_session is True is_final = getattr(obj, 'end_of_session', False) - return self.from_pulsar(obj), is_final \ No newline at end of file + return self.encode(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/rows_query.py b/trustgraph-base/trustgraph/messaging/translators/rows_query.py index 6feb75a3..6153901c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/rows_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/rows_query.py @@ -7,7 +7,7 @@ import json class RowsQueryRequestTranslator(MessageTranslator): """Translator for RowsQueryRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryRequest: + def decode(self, data: Dict[str, Any]) -> RowsQueryRequest: return RowsQueryRequest( user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), @@ -16,7 +16,7 @@ class RowsQueryRequestTranslator(MessageTranslator): operation_name=data.get("operation_name", None) ) - def from_pulsar(self, obj: RowsQueryRequest) -> Dict[str, Any]: + def encode(self, obj: RowsQueryRequest) -> Dict[str, Any]: result = { "user": obj.user, "collection": obj.collection, @@ -33,10 +33,10 @@ class RowsQueryRequestTranslator(MessageTranslator): class RowsQueryResponseTranslator(MessageTranslator): """Translator for RowsQueryResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryResponse: + def decode(self, data: Dict[str, Any]) -> RowsQueryResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: RowsQueryResponse) -> Dict[str, Any]: + def encode(self, obj: RowsQueryResponse) -> Dict[str, Any]: result = {} # Handle GraphQL response data @@ -74,6 +74,6 @@ class RowsQueryResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: RowsQueryResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: RowsQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/structured_query.py b/trustgraph-base/trustgraph/messaging/translators/structured_query.py index cc3ae80c..6b0b38a1 100644 --- a/trustgraph-base/trustgraph/messaging/translators/structured_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/structured_query.py @@ -7,14 +7,14 @@ import json class StructuredQueryRequestTranslator(MessageTranslator): """Translator for StructuredQueryRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryRequest: + def decode(self, data: Dict[str, Any]) -> StructuredQueryRequest: return StructuredQueryRequest( question=data.get("question", ""), user=data.get("user", 
"trustgraph"), # Default fallback collection=data.get("collection", "default") # Default fallback ) - def from_pulsar(self, obj: StructuredQueryRequest) -> Dict[str, Any]: + def encode(self, obj: StructuredQueryRequest) -> Dict[str, Any]: return { "question": obj.question, "user": obj.user, @@ -25,10 +25,10 @@ class StructuredQueryRequestTranslator(MessageTranslator): class StructuredQueryResponseTranslator(MessageTranslator): """Translator for StructuredQueryResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> StructuredQueryResponse: + def decode(self, data: Dict[str, Any]) -> StructuredQueryResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: StructuredQueryResponse) -> Dict[str, Any]: + def encode(self, obj: StructuredQueryResponse) -> Dict[str, Any]: result = {} # Handle structured query response data @@ -55,6 +55,6 @@ class StructuredQueryResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: StructuredQueryResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: StructuredQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.encode(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py index fa3749b5..596ff744 100644 --- a/trustgraph-base/trustgraph/messaging/translators/text_completion.py +++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py @@ -6,14 +6,14 @@ from .base import MessageTranslator class TextCompletionRequestTranslator(MessageTranslator): """Translator for TextCompletionRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest: + def decode(self, data: Dict[str, Any]) -> TextCompletionRequest: return TextCompletionRequest( system=data["system"], prompt=data["prompt"], streaming=data.get("streaming", False) ) - def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]: + def encode(self, obj: TextCompletionRequest) -> Dict[str, Any]: return { "system": obj.system, "prompt": obj.prompt @@ -23,10 +23,10 @@ class TextCompletionRequestTranslator(MessageTranslator): class TextCompletionResponseTranslator(MessageTranslator): """Translator for TextCompletionResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionResponse: + def decode(self, data: Dict[str, Any]) -> TextCompletionResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]: + def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]: result = {"response": obj.response} if obj.in_token: @@ -41,8 +41,8 @@ class TextCompletionResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" # Check end_of_stream field to determine if this is the final message is_final = getattr(obj, 'end_of_stream', True) - return self.from_pulsar(obj), is_final \ No newline at end of file + return self.encode(obj), is_final \ No newline at end of file diff --git 
a/trustgraph-base/trustgraph/messaging/translators/tool.py b/trustgraph-base/trustgraph/messaging/translators/tool.py index 9f4d05cc..0651ca03 100644 --- a/trustgraph-base/trustgraph/messaging/translators/tool.py +++ b/trustgraph-base/trustgraph/messaging/translators/tool.py @@ -6,7 +6,7 @@ from .base import MessageTranslator class ToolRequestTranslator(MessageTranslator): """Translator for ToolRequest schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> ToolRequest: + def decode(self, data: Dict[str, Any]) -> ToolRequest: # Handle both "name" and "parameters" input keys name = data.get("name", "") if "parameters" in data: @@ -19,7 +19,7 @@ class ToolRequestTranslator(MessageTranslator): parameters = parameters, ) - def from_pulsar(self, obj: ToolRequest) -> Dict[str, Any]: + def encode(self, obj: ToolRequest) -> Dict[str, Any]: result = {} if obj.name: @@ -32,10 +32,10 @@ class ToolRequestTranslator(MessageTranslator): class ToolResponseTranslator(MessageTranslator): """Translator for ToolResponse schema objects""" - def to_pulsar(self, data: Dict[str, Any]) -> ToolResponse: + def decode(self, data: Dict[str, Any]) -> ToolResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: ToolResponse) -> Dict[str, Any]: + def encode(self, obj: ToolResponse) -> Dict[str, Any]: result = {} @@ -46,6 +46,6 @@ class ToolResponseTranslator(MessageTranslator): return result - def from_response_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: ToolResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py index 2f29aa56..21d2698f 100644 --- a/trustgraph-base/trustgraph/messaging/translators/triples.py +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -10,10 +10,10 @@ class TriplesQueryRequestTranslator(MessageTranslator): def __init__(self): self.value_translator = ValueTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryRequest: - s = self.value_translator.to_pulsar(data["s"]) if "s" in data else None - p = self.value_translator.to_pulsar(data["p"]) if "p" in data else None - o = self.value_translator.to_pulsar(data["o"]) if "o" in data else None + def decode(self, data: Dict[str, Any]) -> TriplesQueryRequest: + s = self.value_translator.decode(data["s"]) if "s" in data else None + p = self.value_translator.decode(data["p"]) if "p" in data else None + o = self.value_translator.decode(data["o"]) if "o" in data else None g = data.get("g") # None=default graph, "*"=all graphs return TriplesQueryRequest( @@ -28,7 +28,7 @@ class TriplesQueryRequestTranslator(MessageTranslator): batch_size=int(data.get("batch-size", 20)), ) - def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]: + def encode(self, obj: TriplesQueryRequest) -> Dict[str, Any]: result = { "limit": obj.limit, "user": obj.user, @@ -38,11 +38,11 @@ class TriplesQueryRequestTranslator(MessageTranslator): } if obj.s: - result["s"] = self.value_translator.from_pulsar(obj.s) + result["s"] = self.value_translator.encode(obj.s) if obj.p: - result["p"] = self.value_translator.from_pulsar(obj.p) + result["p"] = self.value_translator.encode(obj.p) if obj.o: - result["o"] = self.value_translator.from_pulsar(obj.o) + result["o"] = 
self.value_translator.encode(obj.o) if obj.g is not None: result["g"] = obj.g @@ -55,14 +55,14 @@ class TriplesQueryResponseTranslator(MessageTranslator): def __init__(self): self.subgraph_translator = SubgraphTranslator() - def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryResponse: + def decode(self, data: Dict[str, Any]) -> TriplesQueryResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - def from_pulsar(self, obj: TriplesQueryResponse) -> Dict[str, Any]: + def encode(self, obj: TriplesQueryResponse) -> Dict[str, Any]: return { - "response": self.subgraph_translator.from_pulsar(obj.triples) + "response": self.subgraph_translator.encode(obj.triples) } - def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]: + def encode_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), obj.is_final \ No newline at end of file + return self.encode(obj), obj.is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/core/topic.py b/trustgraph-base/trustgraph/schema/core/topic.py index 09c633e4..036ea142 100644 --- a/trustgraph-base/trustgraph/schema/core/topic.py +++ b/trustgraph-base/trustgraph/schema/core/topic.py @@ -1,23 +1,26 @@ -def topic(queue_name, qos='q1', tenant='tg', namespace='flow'): +def queue(topic, cls='flow', topicspace='tg'): """ - Create a generic topic identifier that can be mapped by backends. + Create a queue identifier in CLASS:TOPICSPACE:TOPIC format. Args: - queue_name: The queue/topic name - qos: Quality of service - - 'q0' = best-effort (no ack) - - 'q1' = at-least-once (ack required) - - 'q2' = exactly-once (two-phase ack) - tenant: Tenant identifier for multi-tenancy - namespace: Namespace within tenant + topic: The logical queue name (e.g. 
'config', 'librarian') + cls: Queue class determining operational characteristics: + - 'flow' = persistent processing pipeline queue + - 'request' = non-persistent, short TTL request queue + - 'response' = non-persistent, short TTL response queue + - 'state' = persistent, last-value state broadcast + topicspace: Deployment isolation prefix (default: 'tg') Returns: - Generic topic string: qos/tenant/namespace/queue_name + Queue identifier string: cls:topicspace:topic Examples: - topic('my-queue') # q1/tg/flow/my-queue - topic('config', qos='q2', namespace='config') # q2/tg/config/config + queue('text-completion-request') + # flow:tg:text-completion-request + queue('config', cls='request') + # request:tg:config + queue('config', cls='state') + # state:tg:config """ - return f"{qos}/{tenant}/{namespace}/{queue_name}" - + return f"{cls}:{topicspace}:{topic}" diff --git a/trustgraph-base/trustgraph/schema/knowledge/document.py b/trustgraph-base/trustgraph/schema/knowledge/document.py index c75a1227..fc7273ef 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/document.py +++ b/trustgraph-base/trustgraph/schema/knowledge/document.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from ..core.metadata import Metadata -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index a8bae35c..2cfd08cf 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from ..core.metadata import Metadata from ..core.primitives import Term, RowSchema -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/graph.py b/trustgraph-base/trustgraph/schema/knowledge/graph.py index b4a05084..a15676ab 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/graph.py +++ b/trustgraph-base/trustgraph/schema/knowledge/graph.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from ..core.primitives import Term, Triple from ..core.metadata import Metadata -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index cffcbac7..0c4a9f7c 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from ..core.primitives import Triple, Error -from ..core.topic import topic +from ..core.topic import queue from ..core.metadata import Metadata from .document import Document, TextDocument from .graph import Triples @@ -52,9 +52,5 @@ class KnowledgeResponse: triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None -knowledge_request_queue = topic( - 'knowledge', qos='q0', namespace='request' -) -knowledge_response_queue = topic( - 'knowledge', qos='q0', namespace='response', -) +knowledge_request_queue = queue('knowledge', cls='request') +knowledge_response_queue = queue('knowledge', cls='response') diff --git a/trustgraph-base/trustgraph/schema/knowledge/nlp.py b/trustgraph-base/trustgraph/schema/knowledge/nlp.py index 10b5f215..84e2f080 100644 --- 
a/trustgraph-base/trustgraph/schema/knowledge/nlp.py +++ b/trustgraph-base/trustgraph/schema/knowledge/nlp.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/object.py b/trustgraph-base/trustgraph/schema/knowledge/object.py index 39b0095f..4b51bbe1 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/object.py +++ b/trustgraph-base/trustgraph/schema/knowledge/object.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from ..core.metadata import Metadata -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/rows.py b/trustgraph-base/trustgraph/schema/knowledge/rows.py index ca2131df..015affe1 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/rows.py +++ b/trustgraph-base/trustgraph/schema/knowledge/rows.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from ..core.metadata import Metadata from ..core.primitives import RowSchema -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/structured.py b/trustgraph-base/trustgraph/schema/knowledge/structured.py index c227d767..52bfec27 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/structured.py +++ b/trustgraph-base/trustgraph/schema/knowledge/structured.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from ..core.metadata import Metadata -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index fdb9e391..2a966dd4 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import Optional -from ..core.topic import topic from ..core.primitives import Error ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index 74381abb..f4b5fc6e 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime from ..core.primitives import Error -from ..core.topic import topic +from ..core.topic import queue ############################################################################ @@ -50,10 +50,6 @@ class CollectionManagementResponse: # Topics -collection_request_queue = topic( - 'collection', qos='q0', namespace='request' -) -collection_response_queue = topic( - 'collection', qos='q0', namespace='response' -) +collection_request_queue = queue('collection', cls='request') +collection_response_queue = queue('collection', cls='response') diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index 38bd1cbf..36e55674 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field -from ..core.topic import topic +from ..core.topic import queue 
from ..core.primitives import Error ############################################################################ @@ -60,15 +60,9 @@ class ConfigPush: version: int = 0 config: dict[str, dict[str, str]] = field(default_factory=dict) -config_request_queue = topic( - 'config', qos='q0', namespace='request' -) -config_response_queue = topic( - 'config', qos='q0', namespace='response' -) -config_push_queue = topic( - 'config', qos='q2', namespace='config' -) +config_request_queue = queue('config', cls='request') +config_response_queue = queue('config', cls='response') +config_push_queue = queue('config', cls='state') ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/flow.py b/trustgraph-base/trustgraph/schema/services/flow.py index cf62c84d..0d497dd7 100644 --- a/trustgraph-base/trustgraph/schema/services/flow.py +++ b/trustgraph-base/trustgraph/schema/services/flow.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field -from ..core.topic import topic +from ..core.topic import queue from ..core.primitives import Error ############################################################################ @@ -61,12 +61,8 @@ class FlowResponse: # Everything error: Error | None = None -flow_request_queue = topic( - 'flow', qos='q0', namespace='request' -) -flow_response_queue = topic( - 'flow', qos='q0', namespace='response' -) +flow_request_queue = queue('flow', cls='request') +flow_response_queue = queue('flow', cls='response') ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index f1ab360f..51d0d5a5 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from ..core.primitives import Triple, Error -from ..core.topic import topic +from ..core.topic import queue from ..core.metadata import Metadata # Note: Document imports will be updated after knowledge schemas are converted @@ -220,9 +220,5 @@ class LibrarianResponse: # FIXME: Is this right? 
Using persistence on librarian so that # message chunking works -librarian_request_queue = topic( - 'librarian', qos='q1', namespace='request' -) -librarian_response_queue = topic( - 'librarian', qos='q1', namespace='response', -) +librarian_request_queue = queue('librarian-request', cls='flow') +librarian_response_queue = queue('librarian-response', cls='flow') diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index 681638c3..0fd6ab90 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field -from ..core.topic import topic from ..core.primitives import Error ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/lookup.py b/trustgraph-base/trustgraph/schema/services/lookup.py index d944fb89..3c661e4c 100644 --- a/trustgraph-base/trustgraph/schema/services/lookup.py +++ b/trustgraph-base/trustgraph/schema/services/lookup.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from ..core.primitives import Error, Term, Triple -from ..core.topic import topic from ..core.metadata import Metadata ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/nlp_query.py b/trustgraph-base/trustgraph/schema/services/nlp_query.py index 6cd65f0e..73780567 100644 --- a/trustgraph-base/trustgraph/schema/services/nlp_query.py +++ b/trustgraph-base/trustgraph/schema/services/nlp_query.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from ..core.primitives import Error -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index f7a31c14..f7388102 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from ..core.primitives import Error -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 7a65f775..f9f08658 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from ..core.primitives import Error, Term, Triple -from ..core.topic import topic +from ..core.topic import queue ############################################################################ @@ -69,12 +69,8 @@ class DocumentEmbeddingsResponse: error: Error | None = None chunks: list[ChunkMatch] = field(default_factory=list) -document_embeddings_request_queue = topic( - "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' -) -document_embeddings_response_queue = topic( - "document-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' -) +document_embeddings_request_queue = queue('document-embeddings', cls='request') +document_embeddings_response_queue = queue('document-embeddings', cls='response') ############################################################################ @@ -104,9 +100,5 @@ class RowEmbeddingsResponse: error: Error | None = None matches: list[RowIndexMatch] 
= field(default_factory=list) -row_embeddings_request_queue = topic( - "row-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' -) -row_embeddings_response_queue = topic( - "row-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' -) \ No newline at end of file +row_embeddings_request_queue = queue('row-embeddings', cls='request') +row_embeddings_response_queue = queue('row-embeddings', cls='response') \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index f5ac73d3..a4621549 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from ..core.topic import topic from ..core.primitives import Error, Term ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/rows_query.py b/trustgraph-base/trustgraph/schema/services/rows_query.py index a4818329..e3c4f14c 100644 --- a/trustgraph-base/trustgraph/schema/services/rows_query.py +++ b/trustgraph-base/trustgraph/schema/services/rows_query.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import Optional from ..core.primitives import Error -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/structured_query.py b/trustgraph-base/trustgraph/schema/services/structured_query.py index ae1eaa5f..5f54ac16 100644 --- a/trustgraph-base/trustgraph/schema/services/structured_query.py +++ b/trustgraph-base/trustgraph/schema/services/structured_query.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from ..core.primitives import Error -from ..core.topic import topic ############################################################################ diff --git a/trustgraph-cli/trustgraph/cli/dump_queues.py b/trustgraph-cli/trustgraph/cli/dump_queues.py index 4df61cc3..eb7898c2 100644 --- a/trustgraph-cli/trustgraph/cli/dump_queues.py +++ b/trustgraph-cli/trustgraph/cli/dump_queues.py @@ -8,8 +8,6 @@ message flows, diagnosing stuck services, and understanding system behavior. Uses TrustGraph's Subscriber abstraction for future-proof pub/sub compatibility. """ -import pulsar -from pulsar.schema import BytesSchema import sys import json import asyncio @@ -17,7 +15,7 @@ from datetime import datetime import argparse from trustgraph.base.subscriber import Subscriber -from trustgraph.base.pubsub import get_pubsub +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args def decode_json_strings(obj): """Recursively decode JSON-encoded string values within a dict/list.""" @@ -172,15 +170,13 @@ async def log_writer(central_queue, file_handle, shutdown_event, console_output= break -async def async_main(queues, output_file, pulsar_host, listener_name, subscriber_name, append_mode): +async def async_main(queues, output_file, subscriber_name, append_mode, **pubsub_config): """ Main async function to monitor multiple queues concurrently. 
Args: queues: List of queue names to monitor output_file: Path to output file - pulsar_host: Pulsar connection URL - listener_name: Pulsar listener name subscriber_name: Base name for subscribers append_mode: Whether to append to existing file """ @@ -194,9 +190,9 @@ async def async_main(queues, output_file, pulsar_host, listener_name, subscriber # Create backend connection try: - backend = get_pubsub(pulsar_host=pulsar_host, pulsar_listener=listener_name, pubsub_backend='pulsar') + backend = get_pubsub(**pubsub_config) except Exception as e: - print(f"Error connecting to backend at {pulsar_host}: {e}", file=sys.stderr) + print(f"Error connecting to backend: {e}", file=sys.stderr) sys.exit(1) # Create Subscribers and central queue @@ -291,25 +287,20 @@ def main(): description='Monitor and dump messages from multiple Pulsar queues', epilog=""" Examples: - # Monitor agent and prompt queues - tg-dump-queues non-persistent://tg/request/agent:default \\ - non-persistent://tg/request/prompt:default + # Monitor agent and prompt flow queues + tg-dump-queues flow:tg:agent-request:default \\ + flow:tg:prompt-request:default # Monitor with custom output file - tg-dump-queues non-persistent://tg/request/agent:default \\ + tg-dump-queues flow:tg:agent-request:default \\ --output debug.log # Append to existing log file - tg-dump-queues non-persistent://tg/request/agent:default \\ + tg-dump-queues flow:tg:agent-request:default \\ --output queue.log --append -Common queue patterns: - - Agent requests: non-persistent://tg/request/agent:default - - Agent responses: non-persistent://tg/response/agent:default - - Prompt requests: non-persistent://tg/request/prompt:default - - Prompt responses: non-persistent://tg/response/prompt:default - - LLM requests: non-persistent://tg/request/text-completion:default - - LLM responses: non-persistent://tg/response/text-completion:default + # Raw Pulsar URIs also accepted + tg-dump-queues persistent://tg/flow/agent-request:default IMPORTANT: This tool subscribes to queues without a schema (schema-less mode). 
To avoid @@ -340,17 +331,7 @@ IMPORTANT: help='Append to output file instead of overwriting' ) - parser.add_argument( - '--pulsar-host', - default='pulsar://localhost:6650', - help='Pulsar host URL (default: pulsar://localhost:6650)' - ) - - parser.add_argument( - '--listener-name', - default='localhost', - help='Pulsar listener name (default: localhost)' - ) + add_pubsub_args(parser, standalone=True) parser.add_argument( '--subscriber', @@ -371,10 +352,12 @@ IMPORTANT: asyncio.run(async_main( queues=queues, output_file=args.output, - pulsar_host=args.pulsar_host, - listener_name=args.listener_name, subscriber_name=args.subscriber, - append_mode=args.append + append_mode=args.append, + pubsub_backend=args.pubsub_backend, + pulsar_host=args.pulsar_host, + pulsar_api_key=args.pulsar_api_key, + pulsar_listener=args.pulsar_listener, )) except KeyboardInterrupt: # Already handled in async_main diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py index bed56a73..02456b1c 100644 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ b/trustgraph-cli/trustgraph/cli/init_trustgraph.py @@ -137,7 +137,7 @@ def init( } }) - ensure_namespace(pulsar_admin_url, tenant, "config", { + ensure_namespace(pulsar_admin_url, tenant, "state", { "retention_policies": { "retentionSizeInMB": 10, "retentionTimeInMinutes": -1, diff --git a/trustgraph-cli/trustgraph/cli/monitor_prompts.py b/trustgraph-cli/trustgraph/cli/monitor_prompts.py index 974cfbcd..c3b71afb 100644 --- a/trustgraph-cli/trustgraph/cli/monitor_prompts.py +++ b/trustgraph-cli/trustgraph/cli/monitor_prompts.py @@ -1,7 +1,7 @@ """ Monitor prompt request/response queues and log activity with timing. -Subscribes to prompt request and response Pulsar queues, correlates +Subscribes to prompt request and response queues, correlates them by message ID, and logs a summary of each request/response with elapsed time. Streaming responses are accumulated and shown once at completion. 
@@ -19,8 +19,7 @@ import argparse from datetime import datetime from collections import OrderedDict -import pulsar -from pulsar.schema import BytesSchema +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args default_flow = "default" @@ -85,7 +84,7 @@ def format_terms(terms, max_lines, max_width): def parse_raw_message(msg): - """Parse a raw Pulsar message into (correlation_id, body_dict).""" + """Parse a raw message into (correlation_id, body_dict).""" try: props = msg.properties() corr_id = props.get("id", "") @@ -94,53 +93,46 @@ def parse_raw_message(msg): try: value = msg.value() - if isinstance(value, bytes): - value = value.decode("utf-8") - body = json.loads(value) if isinstance(value, str) else {} + if isinstance(value, dict): + body = value + elif isinstance(value, bytes): + body = json.loads(value.decode("utf-8")) + elif isinstance(value, str): + body = json.loads(value) + else: + body = {} except Exception: body = {} return corr_id, body -def receive_with_timeout(consumer, timeout_ms=500): - """Receive a message with timeout, returning None on timeout.""" - try: - return consumer.receive(timeout_millis=timeout_ms) - except Exception: - return None +async def monitor(flow, queue_type, max_lines, max_width, **config): - -async def monitor(flow, queue_type, max_lines, max_width, - pulsar_host, listener_name): - - request_queue = f"non-persistent://tg/request/{queue_type}:{flow}" - response_queue = f"non-persistent://tg/response/{queue_type}:{flow}" + request_queue = f"request:tg:{queue_type}:{flow}" + response_queue = f"response:tg:{queue_type}:{flow}" print(f"Monitoring prompt queues:") print(f" Request: {request_queue}") print(f" Response: {response_queue}") print(f"Press Ctrl+C to stop\n") - client = pulsar.Client( - pulsar_host, - listener_name=listener_name, + backend = get_pubsub(**config) + + req_consumer = backend.create_consumer( + topic=request_queue, + subscription="prompt-monitor-req", + schema=None, + consumer_type='shared', + initial_position='latest', ) - req_consumer = client.subscribe( - request_queue, - subscription_name="prompt-monitor-req", - consumer_type=pulsar.ConsumerType.Shared, - schema=BytesSchema(), - initial_position=pulsar.InitialPosition.Latest, - ) - - resp_consumer = client.subscribe( - response_queue, - subscription_name="prompt-monitor-resp", - consumer_type=pulsar.ConsumerType.Shared, - schema=BytesSchema(), - initial_position=pulsar.InitialPosition.Latest, + resp_consumer = backend.create_consumer( + topic=response_queue, + subscription="prompt-monitor-resp", + schema=None, + consumer_type='shared', + initial_position='latest', ) # Track in-flight requests: corr_id -> (timestamp, template_id) @@ -156,8 +148,8 @@ async def monitor(flow, queue_type, max_lines, max_width, got_message = False # Poll request queue - msg = receive_with_timeout(req_consumer, 100) - if msg: + try: + msg = req_consumer.receive(timeout_millis=100) got_message = True timestamp = datetime.now() corr_id, body = parse_raw_message(msg) @@ -182,10 +174,12 @@ async def monitor(flow, queue_type, max_lines, max_width, print(format_terms(terms, max_lines, max_width)) req_consumer.acknowledge(msg) + except TimeoutError: + pass # Poll response queue - msg = receive_with_timeout(resp_consumer, 100) - if msg: + try: + msg = resp_consumer.receive(timeout_millis=100) got_message = True timestamp = datetime.now() corr_id, body = parse_raw_message(msg) @@ -265,6 +259,8 @@ async def monitor(flow, queue_type, max_lines, max_width, print(f" {truncated}") 
resp_consumer.acknowledge(msg) + except TimeoutError: + pass if not got_message: await asyncio.sleep(0.05) @@ -274,7 +270,7 @@ async def monitor(flow, queue_type, max_lines, max_width, finally: req_consumer.close() resp_consumer.close() - client.close() + backend.close() def main(): @@ -310,17 +306,7 @@ def main(): help=f"Max width per line (default: {default_max_width})", ) - parser.add_argument( - "--pulsar-host", - default="pulsar://localhost:6650", - help="Pulsar host URL (default: pulsar://localhost:6650)", - ) - - parser.add_argument( - "--listener-name", - default="localhost", - help="Pulsar listener name (default: localhost)", - ) + add_pubsub_args(parser, standalone=True) args = parser.parse_args() @@ -331,7 +317,9 @@ def main(): max_lines=args.max_lines, max_width=args.max_width, pulsar_host=args.pulsar_host, - listener_name=args.listener_name, + pulsar_api_key=args.pulsar_api_key, + pulsar_listener=args.pulsar_listener, + pubsub_backend=args.pubsub_backend, )) except KeyboardInterrupt: pass diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 88e29116..02aa7d78 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -145,7 +145,7 @@ class Processor(FlowProcessor): try: # Convert Pulsar RowSchema to JSON-serializable dict - schema_dict = row_schema_translator.from_pulsar(schema) + schema_dict = row_schema_translator.encode(schema) # Use prompt client to extract rows based on schema objects = await flow("prompt-request").extract_objects( diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index 4bf39ccd..d956c7c6 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -23,7 +23,6 @@ import uuid logger = logging.getLogger(__name__) import json -import pulsar from prometheus_client import start_http_server from ... 
schema import ConfigPush, config_push_queue diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/agent.py b/trustgraph-flow/trustgraph/gateway/dispatch/agent.py index 8867956d..5b97b297 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/agent.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/agent.py @@ -25,8 +25,8 @@ class AgentRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("agent") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py index 2fa3759d..544a412d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py @@ -29,8 +29,8 @@ class CollectionManagementRequestor(ServiceRequestor): def to_request(self, body): print("REQUEST", body, flush=True) - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): print("RESPONSE", message, flush=True) - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/config.py b/trustgraph-flow/trustgraph/gateway/dispatch/config.py index 9d40e8cc..3c591a81 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/config.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/config.py @@ -30,8 +30,8 @@ class ConfigRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("config") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py index bd5f9666..199b5a42 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py @@ -44,7 +44,7 @@ class DocumentEmbeddingsImport: async def receive(self, msg): data = msg.json() - elt = self.translator.to_pulsar(data) + elt = self.translator.decode(data) await self.publisher.send(None, elt) async def run(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py index 650d4f40..80a0935f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py @@ -25,7 +25,7 @@ class DocumentEmbeddingsQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("document-embeddings-query") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return 
self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py index eb68b0b1..67800f21 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py @@ -23,5 +23,5 @@ class DocumentLoad(ServiceSender): def to_request(self, body): logger.info("Document received") - return self.translator.to_pulsar(body) + return self.translator.decode(body) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py index 83b3cb9a..55e20bfc 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py @@ -25,8 +25,8 @@ class DocumentRagRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("document-rag") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py b/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py index 6c1b55ba..99994f2a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py @@ -25,8 +25,8 @@ class EmbeddingsRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("embeddings") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/flow.py b/trustgraph-flow/trustgraph/gateway/dispatch/flow.py index be91995d..6f901c1f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/flow.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/flow.py @@ -30,8 +30,8 @@ class FlowRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("flow") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py index a7bb1bd8..3081ad59 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py @@ -25,8 +25,8 @@ class GraphEmbeddingsQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("graph-embeddings-query") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return 
self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py index a0299a43..9b8feea4 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py @@ -25,8 +25,8 @@ class GraphRagRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("graph-rag") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py b/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py index 83aefbd0..90f7f89c 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py @@ -33,8 +33,8 @@ class KnowledgeRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("knowledge") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py b/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py index bbf7190e..8f33f9c1 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py @@ -40,8 +40,8 @@ class LibrarianRequestor(ServiceRequestor): body = body.copy() body["content"] = content - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py b/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py index a5f9398e..9be08aff 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py @@ -25,8 +25,8 @@ class McpToolRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("tool") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py index 3a6314f2..7152d6c0 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py @@ -24,7 +24,7 @@ class NLPQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("nlp-query") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) \ 
No newline at end of file + return self.response_translator.encode_with_completion(message) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py b/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py index 23017733..2304cbdb 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py @@ -27,8 +27,8 @@ class PromptRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("prompt") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py index 8b139fc2..9a0704ca 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py @@ -25,7 +25,7 @@ class RowEmbeddingsQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("row-embeddings-query") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py index 57435be8..e285d9c8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py @@ -24,7 +24,7 @@ class RowsQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("rows-query") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index f42eee02..7267e320 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -11,22 +11,22 @@ _triple_translator = TripleTranslator() def to_value(x): """Convert dict to Term. Delegates to TermTranslator.""" - return _term_translator.to_pulsar(x) + return _term_translator.decode(x) def to_subgraph(x): """Convert list of dicts to list of Triples. Delegates to TripleTranslator.""" - return [_triple_translator.to_pulsar(t) for t in x] + return [_triple_translator.decode(t) for t in x] def serialize_value(v): """Convert Term to dict. Delegates to TermTranslator.""" - return _term_translator.from_pulsar(v) + return _term_translator.encode(v) def serialize_triple(t): """Convert Triple to dict. 
Delegates to TripleTranslator.""" - return _triple_translator.from_pulsar(t) + return _triple_translator.encode(t) def serialize_subgraph(sg): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py b/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py index 895b55be..5bf8f3e5 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py @@ -24,7 +24,7 @@ class StructuredDiagRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("structured-diag") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) \ No newline at end of file + return self.response_translator.encode_with_completion(message) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py index 9a9fbb6a..19508f97 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py @@ -24,7 +24,7 @@ class StructuredQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("structured-query") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) \ No newline at end of file + return self.response_translator.encode_with_completion(message) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py b/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py index 0e77584e..a7c9f6e6 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py @@ -25,8 +25,8 @@ class TextCompletionRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("text-completion") def to_request(self, body): - return self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py b/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py index b2562938..c21f8261 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py @@ -23,5 +23,5 @@ class TextLoad(ServiceSender): def to_request(self, body): logger.info("Text document received") - return self.translator.to_pulsar(body) + return self.translator.decode(body) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py index 6b306139..c9c66705 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py @@ -25,8 +25,8 @@ class TriplesQueryRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("triples-query") def to_request(self, body): - return 
self.request_translator.to_pulsar(body) + return self.request_translator.decode(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index aaa6f725..cdf5daba 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -18,7 +18,6 @@ from . dispatch.manager import DispatcherManager from . endpoint.manager import EndpointManager -import pulsar from prometheus_client import start_http_server # Import default queue names diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index ab13ccbc..d31d6223 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -3,7 +3,6 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph. """ -import pulsar import base64 import os import argparse diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index 210ea53d..ac8d05c4 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -3,7 +3,6 @@ Graph writer. Input is graph edge. Writes edges to FalkorDB graph. """ -import pulsar import base64 import os import argparse diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 55d4dee1..7864ac80 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -3,7 +3,6 @@ Graph writer. Input is graph edge. Writes edges to Memgraph. """ -import pulsar import base64 import os import argparse diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index 4a85a273..3db712fb 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -3,7 +3,6 @@ Graph writer. Input is graph edge. Writes edges to Neo4j graph. """ -import pulsar import base64 import os import argparse From 24f0190ce79e703d2b5585304484747fc23cb57e Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Apr 2026 12:47:16 +0100 Subject: [PATCH 26/37] RabbitMQ pub/sub backend with topic exchange architecture (#752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a RabbitMQ backend as an alternative to Pulsar, selectable via PUBSUB_BACKEND=rabbitmq. Both backends implement the same PubSubBackend protocol — no application code changes needed to switch. RabbitMQ topology: - Single topic exchange per topicspace (e.g. 'tg') - Routing key derived from queue class and topic name - Shared consumers: named queue bound to exchange (competing, round-robin) - Exclusive consumers: anonymous auto-delete queue (broadcast, each gets every message). Used by Subscriber and config push consumer. 
- Thread-local producer connections (pika is not thread-safe) - Push-based consumption via basic_consume with process_data_events for heartbeat processing Consumer model changes: - Consumer class creates one backend consumer per concurrent task (required for pika thread safety, harmless for Pulsar) - Consumer class accepts consumer_type parameter - Subscriber passes consumer_type='exclusive' for broadcast semantics - Config push consumer uses consumer_type='exclusive' so every processor instance receives config updates - handle_one_from_queue receives consumer as parameter for correct per-connection ack/nack LibrarianClient: - New shared client class replacing duplicated librarian request-response code across 6+ services (chunking, decoders, RAG, etc.) - Uses stream-document instead of get-document-content for fetching document content in 1MB chunks (avoids broker message size limits) - Standalone object (self.librarian = LibrarianClient(...)) not a mixin - get-document-content marked deprecated in schema and OpenAPI spec Serialisation: - Extracted dataclass_to_dict/dict_to_dataclass to shared serialization.py (used by both Pulsar and RabbitMQ backends) Librarian queues: - Changed from flow class (persistent) back to request/response class now that stream-document eliminates large single messages - API upload chunk size reduced from 5MB to 3MB to stay under broker limits after base64 encoding Factory and CLI: - get_pubsub() handles 'rabbitmq' backend with RabbitMQ connection params - add_pubsub_args() includes RabbitMQ options (host, port, credentials) - add_pubsub_args(standalone=True) defaults to localhost for CLI tools - init_trustgraph skips Pulsar admin setup for non-Pulsar backends - tg-dump-queues and tg-monitor-prompts use backend abstraction - BaseClient and ConfigClient accept generic pubsub config --- Makefile | 4 +- .../schemas/librarian/LibrarianRequest.yaml | 29 ++ .../test_chunking/test_recursive_chunker.py | 40 +- .../unit/test_chunking/test_token_chunker.py | 44 +- .../test_consumer_concurrency.py | 18 +- .../test_mistral_ocr_processor.py | 18 +- tests/unit/test_decoding/test_pdf_decoder.py | 36 +- .../test_decoding/test_universal_processor.py | 32 +- tests/unit/test_pubsub/test_queue_naming.py | 35 +- .../unit/test_pubsub/test_rabbitmq_backend.py | 107 +++++ trustgraph-base/pyproject.toml | 1 + trustgraph-base/trustgraph/api/library.py | 5 +- trustgraph-base/trustgraph/base/__init__.py | 1 + .../trustgraph/base/async_processor.py | 10 +- .../trustgraph/base/chunking_service.py | 214 ++-------- trustgraph-base/trustgraph/base/consumer.py | 99 ++--- .../trustgraph/base/librarian_client.py | 246 +++++++++++ trustgraph-base/trustgraph/base/pubsub.py | 61 ++- .../trustgraph/base/pulsar_backend.py | 112 +---- .../trustgraph/base/rabbitmq_backend.py | 390 ++++++++++++++++++ .../trustgraph/base/serialization.py | 115 ++++++ trustgraph-base/trustgraph/base/subscriber.py | 2 +- trustgraph-base/trustgraph/clients/base.py | 11 +- .../trustgraph/clients/config_client.py | 8 +- .../trustgraph/schema/services/library.py | 9 +- trustgraph-cli/trustgraph/cli/dump_queues.py | 6 +- .../trustgraph/cli/init_trustgraph.py | 86 ++-- .../trustgraph/cli/monitor_prompts.py | 6 +- .../trustgraph/chunking/recursive/chunker.py | 2 +- .../trustgraph/chunking/token/chunker.py | 2 +- .../decoding/mistral_ocr/processor.py | 177 +------- .../trustgraph/decoding/pdf/pdf_decoder.py | 191 +-------- trustgraph-flow/trustgraph/gateway/service.py | 27 +- .../trustgraph/retrieval/document_rag/rag.py | 135 +----- 
.../trustgraph/decoding/ocr/pdf_decoder.py | 177 +------- .../decoding/universal/processor.py | 134 +----- 36 files changed, 1277 insertions(+), 1313 deletions(-) create mode 100644 tests/unit/test_pubsub/test_rabbitmq_backend.py create mode 100644 trustgraph-base/trustgraph/base/librarian_client.py create mode 100644 trustgraph-base/trustgraph/base/rabbitmq_backend.py create mode 100644 trustgraph-base/trustgraph/base/serialization.py diff --git a/Makefile b/Makefile index 197a6c63..4d79f554 100644 --- a/Makefile +++ b/Makefile @@ -77,8 +77,8 @@ some-containers: -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.unstructured \ -# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . + ${DOCKER} build -f containers/Containerfile.unstructured \ + -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . # ${DOCKER} build -f containers/Containerfile.vertexai \ # -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . # ${DOCKER} build -f containers/Containerfile.mcp \ diff --git a/specs/api/components/schemas/librarian/LibrarianRequest.yaml b/specs/api/components/schemas/librarian/LibrarianRequest.yaml index 18aa94b1..eed999f0 100644 --- a/specs/api/components/schemas/librarian/LibrarianRequest.yaml +++ b/specs/api/components/schemas/librarian/LibrarianRequest.yaml @@ -3,6 +3,9 @@ description: | Librarian service request for document library management. Operations: add-document, remove-document, list-documents, + get-document-metadata, stream-document, add-child-document, + list-children, begin-upload, upload-chunk, complete-upload, + abort-upload, get-upload-status, list-uploads, start-processing, stop-processing, list-processing required: - operation @@ -13,6 +16,17 @@ properties: - add-document - remove-document - list-documents + - get-document-metadata + - get-document-content + - stream-document + - add-child-document + - list-children + - begin-upload + - upload-chunk + - complete-upload + - abort-upload + - get-upload-status + - list-uploads - start-processing - stop-processing - list-processing @@ -21,6 +35,21 @@ properties: - `add-document`: Add document to library - `remove-document`: Remove document from library - `list-documents`: List documents in library + - `get-document-metadata`: Get document metadata + - `get-document-content`: Get full document content in a single response. + **Deprecated** — use `stream-document` instead. Fails for documents + exceeding the broker's max message size. + - `stream-document`: Stream document content in chunks. Each response + includes `chunk_index` and `is_final`. Preferred over `get-document-content` + for all document sizes. + - `add-child-document`: Add a child document (e.g. 
page, chunk) + - `list-children`: List child documents of a parent + - `begin-upload`: Start a chunked upload session + - `upload-chunk`: Upload a chunk of data + - `complete-upload`: Finalize a chunked upload + - `abort-upload`: Cancel a chunked upload + - `get-upload-status`: Check upload progress + - `list-uploads`: List active upload sessions - `start-processing`: Start processing library documents - `stop-processing`: Stop library processing - `list-processing`: List processing status diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index ae05d22c..a5ec59c8 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -24,8 +24,8 @@ class MockAsyncProcessor: class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): """Test Recursive chunker functionality""" - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) def test_processor_initialization_basic(self, mock_producer, mock_consumer): """Test basic processor initialization""" @@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] assert len(param_specs) == 2 - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-size parameter override""" @@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 2000, # Override chunk size "chunk-overlap": None # Use default chunk overlap }.get(param) @@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 2000 # Should use overridden value assert chunk_overlap == 100 # Should use default value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-overlap parameter override""" @@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": None, # Use default chunk size "chunk-overlap": 200 # Override chunk overlap }.get(param) @@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 1000 # Should use default value assert chunk_overlap == 200 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - 
@patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer): """Test chunk_document with both chunk-size and chunk-overlap overrides""" @@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 1500, # Override chunk size "chunk-overlap": 150 # Override chunk overlap }.get(param) @@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 1500 # Should use overridden value assert chunk_overlap == 150 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer): @@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid waiting for librarian response - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") # Mock message with TextDocument mock_message = MagicMock() @@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_producer = AsyncMock() mock_triples_producer = AsyncMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 1500, "chunk-overlap": 150, + }.get(param) + mock_flow.side_effect = lambda name: { "output": mock_producer, "triples": mock_triples_producer, - }.get(param) + }.get(name) # Act await processor.on_message(mock_message, mock_consumer, mock_flow) @@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): sent_chunk = mock_producer.send.call_args[0][0] assert isinstance(sent_chunk, Chunk) - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer): """Test chunk_document when no parameters are overridden (flow returns None)""" @@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.return_value = None # No overrides + mock_flow.parameters.get.return_value = None # No overrides # Act chunk_size, chunk_overlap = await processor.chunk_document( diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py index 2ed37391..f3f83904 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ 
b/tests/unit/test_chunking/test_token_chunker.py @@ -24,8 +24,8 @@ class MockAsyncProcessor: class TestTokenChunkerSimple(IsolatedAsyncioTestCase): """Test Token chunker functionality""" - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) def test_processor_initialization_basic(self, mock_producer, mock_consumer): """Test basic processor initialization""" @@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] assert len(param_specs) == 2 - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-size parameter override""" @@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 400, # Override chunk size "chunk-overlap": None # Use default chunk overlap }.get(param) @@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 400 # Should use overridden value assert chunk_overlap == 15 # Should use default value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-overlap parameter override""" @@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": None, # Use default chunk size "chunk-overlap": 25 # Override chunk overlap }.get(param) @@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 250 # Should use default value assert chunk_overlap == 25 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer): """Test chunk_document with both chunk-size and chunk-overlap overrides""" @@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 350, # Override chunk size "chunk-overlap": 30 # Override chunk overlap 
}.get(param) @@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 350 # Should use overridden value assert chunk_overlap == 30 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.chunking.token.chunker.TokenTextSplitter') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer): @@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid librarian producer interactions - processor.save_child_document = AsyncMock(return_value="chunk-id") + processor.librarian.save_child_document = AsyncMock(return_value="chunk-id") # Mock message with TextDocument mock_message = MagicMock() @@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_producer = AsyncMock() mock_triples_producer = AsyncMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 400, "chunk-overlap": 40, + }.get(param) + mock_flow.side_effect = lambda name: { "output": mock_producer, "triples": mock_triples_producer, - }.get(param) + }.get(name) # Act await processor.on_message(mock_message, mock_consumer, mock_flow) @@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): sent_chunk = mock_producer.send.call_args[0][0] assert isinstance(sent_chunk, Chunk) - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer): """Test chunk_document when no parameters are overridden (flow returns None)""" @@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.return_value = None # No overrides + mock_flow.parameters.get.return_value = None # No overrides # Act chunk_size, chunk_overlap = await processor.chunk_document( @@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 250 # Should use default value assert chunk_overlap == 15 # Should use default value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer): """Test that token chunker has different defaults than recursive chunker""" diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 32a6559b..3869aaf3 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -83,7 +83,7 @@ class TestTaskGroupConcurrency: call_count = 0 original_running = True - async def mock_consume(): + 
async def mock_consume(backend_consumer): nonlocal call_count call_count += 1 # Wait a bit to let all tasks start, then signal stop @@ -107,7 +107,7 @@ class TestTaskGroupConcurrency: consumer = _make_consumer(concurrency=1) call_count = 0 - async def mock_consume(): + async def mock_consume(backend_consumer): nonlocal call_count call_count += 1 await asyncio.sleep(0.01) @@ -147,7 +147,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) assert call_count == 2 consumer.consumer.acknowledge.assert_called_once_with(mock_msg) @@ -166,7 +166,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) consumer.consumer.negative_acknowledge.assert_called_with(mock_msg) consumer.consumer.acknowledge.assert_not_called() @@ -185,7 +185,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) assert call_count == 1 consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg) @@ -197,7 +197,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) consumer.consumer.acknowledge.assert_called_once_with(mock_msg) @@ -219,7 +219,7 @@ class TestMetricsIntegration: mock_metrics.record_time.return_value.__exit__ = MagicMock() consumer.metrics = mock_metrics - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) mock_metrics.process.assert_called_once_with("success") @@ -235,7 +235,7 @@ class TestMetricsIntegration: mock_metrics = MagicMock() consumer.metrics = mock_metrics - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) mock_metrics.process.assert_called_once_with("error") @@ -261,7 +261,7 @@ class TestMetricsIntegration: mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False) consumer.metrics = mock_metrics - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) mock_metrics.rate_limit.assert_called_once() diff --git a/tests/unit/test_decoding/test_mistral_ocr_processor.py b/tests/unit/test_decoding/test_mistral_ocr_processor.py index 3243666c..2b8c25e2 100644 --- a/tests/unit/test_decoding/test_mistral_ocr_processor.py +++ b/tests/unit/test_decoding/test_mistral_ocr_processor.py @@ -25,8 +25,8 @@ class MockAsyncProcessor: class TestMistralOcrProcessor(IsolatedAsyncioTestCase): """Test Mistral OCR processor functionality""" - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_initialization_with_api_key( @@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): assert consumer_specs[0].name == "input" assert consumer_specs[0].schema == Document - 
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_initialization_without_api_key( self, mock_producer, mock_consumer @@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): with pytest.raises(RuntimeError, match="Mistral API key not specified"): Processor(**config) - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_ocr_single_chunk( @@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): ) mock_mistral.ocr.process.assert_called_once() - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_success( @@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): ] # Mock save_child_document - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") with patch.object(processor, 'ocr', return_value=ocr_result): await processor.on_message(mock_msg, None, mock_flow) diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py index 22659479..d2183c0c 100644 --- a/tests/unit/test_decoding/test_pdf_decoder.py +++ b/tests/unit/test_decoding/test_pdf_decoder.py @@ -24,12 +24,10 @@ class MockAsyncProcessor: class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): """Test PDF decoder processor functionality""" - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_processor_initialization(self, mock_producer, mock_consumer): """Test PDF decoder processor initialization""" config = { 'id': 'test-pdf-decoder', @@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): assert consumer_specs[0].name == "input" assert consumer_specs[0].schema == Document - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') 
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test successful PDF processing""" # Mock PDF content pdf_content = b"fake pdf content" @@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid waiting for librarian response - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") await processor.on_message(mock_msg, None, mock_flow) @@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): # Verify triples were sent for each page (provenance) assert mock_triples_flow.send.call_count == 2 - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test handling of empty PDF""" pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') @@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): mock_output_flow.send.assert_not_called() - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test handling of unicode content in PDF""" pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') @@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid waiting for librarian response - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") await processor.on_message(mock_msg, None, mock_flow) diff --git a/tests/unit/test_decoding/test_universal_processor.py b/tests/unit/test_decoding/test_universal_processor.py index 8d2e116e..4daa9b68 100644 --- a/tests/unit/test_decoding/test_universal_processor.py +++ b/tests/unit/test_decoding/test_universal_processor.py @@ -142,8 +142,8 @@ class TestPageBasedFormats: class TestUniversalProcessor(IsolatedAsyncioTestCase): """Test universal decoder processor.""" - 
@patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_initialization( self, mock_producer, mock_consumer @@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert consumer_specs[0].name == "input" assert consumer_specs[0].schema == Document - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_custom_strategy( self, mock_producer, mock_consumer @@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert processor.partition_strategy == "hi_res" assert processor.section_strategy_name == "heading" - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_group_by_page(self, mock_producer, mock_consumer): """Test page grouping of elements.""" @@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert result[1][0] == 2 assert len(result[1][1]) == 1 - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_inline_non_page( @@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): }.get(name)) # Mock save_child_document and magic - processor.save_child_document = AsyncMock(return_value="mock-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-id") with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: mock_magic.from_buffer.return_value = "text/markdown" @@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert call_args.document_id.startswith("urn:section:") assert call_args.text == b"" - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_page_based( @@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): "triples": mock_triples_flow, }.get(name)) - processor.save_child_document = AsyncMock(return_value="mock-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-id") with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: mock_magic.from_buffer.return_value = "application/pdf" @@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): call_args = mock_output_flow.send.call_args_list[0][0][0] assert 
call_args.document_id.startswith("urn:page:") - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_images_stored_not_emitted( @@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): "triples": mock_triples_flow, }.get(name)) - processor.save_child_document = AsyncMock(return_value="mock-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-id") with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: mock_magic.from_buffer.return_value = "application/pdf" @@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert mock_triples_flow.send.call_count == 2 # save_child_document called twice (page + image) - assert processor.save_child_document.call_count == 2 + assert processor.librarian.save_child_document.call_count == 2 @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') def test_add_args(self, mock_parent_add_args): diff --git a/tests/unit/test_pubsub/test_queue_naming.py b/tests/unit/test_pubsub/test_queue_naming.py index 1ee781d9..edd3dfca 100644 --- a/tests/unit/test_pubsub/test_queue_naming.py +++ b/tests/unit/test_pubsub/test_queue_naming.py @@ -109,6 +109,37 @@ class TestAddPubsubArgs: assert args.pubsub_backend == 'pulsar' +class TestAddPubsubArgsRabbitMQ: + + def test_rabbitmq_args_present(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([ + '--pubsub-backend', 'rabbitmq', + '--rabbitmq-host', 'myhost', + '--rabbitmq-port', '5673', + ]) + assert args.pubsub_backend == 'rabbitmq' + assert args.rabbitmq_host == 'myhost' + assert args.rabbitmq_port == 5673 + + def test_rabbitmq_defaults_container(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([]) + assert args.rabbitmq_host == 'rabbitmq' + assert args.rabbitmq_port == 5672 + assert args.rabbitmq_username == 'guest' + assert args.rabbitmq_password == 'guest' + assert args.rabbitmq_vhost == '/' + + def test_rabbitmq_standalone_defaults_to_localhost(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser, standalone=True) + args = parser.parse_args([]) + assert args.rabbitmq_host == 'localhost' + + class TestQueueDefinitions: """Verify the actual queue constants produce correct names.""" @@ -124,9 +155,9 @@ class TestQueueDefinitions: from trustgraph.schema.services.config import config_push_queue assert config_push_queue == 'state:tg:config' - def test_librarian_request_is_persistent(self): + def test_librarian_request(self): from trustgraph.schema.services.library import librarian_request_queue - assert librarian_request_queue.startswith('flow:') + assert librarian_request_queue == 'request:tg:librarian' def test_knowledge_request(self): from trustgraph.schema.knowledge.knowledge import knowledge_request_queue diff --git a/tests/unit/test_pubsub/test_rabbitmq_backend.py b/tests/unit/test_pubsub/test_rabbitmq_backend.py new file mode 100644 index 00000000..578db3b6 --- /dev/null +++ b/tests/unit/test_pubsub/test_rabbitmq_backend.py @@ -0,0 +1,107 @@ +""" +Unit tests for RabbitMQ backend — queue name mapping and factory dispatch. +Does not require a running RabbitMQ instance. 
+""" + +import pytest +import argparse + +pika = pytest.importorskip("pika", reason="pika not installed") + +from trustgraph.base.rabbitmq_backend import RabbitMQBackend +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args + + +class TestRabbitMQMapQueueName: + + @pytest.fixture + def backend(self): + b = object.__new__(RabbitMQBackend) + return b + + def test_flow_is_durable(self, backend): + name, durable = backend.map_queue_name('flow:tg:text-completion-request') + assert durable is True + assert name == 'tg.flow.text-completion-request' + + def test_state_is_durable(self, backend): + name, durable = backend.map_queue_name('state:tg:config') + assert durable is True + assert name == 'tg.state.config' + + def test_request_is_not_durable(self, backend): + name, durable = backend.map_queue_name('request:tg:config') + assert durable is False + assert name == 'tg.request.config' + + def test_response_is_not_durable(self, backend): + name, durable = backend.map_queue_name('response:tg:librarian') + assert durable is False + assert name == 'tg.response.librarian' + + def test_custom_topicspace(self, backend): + name, durable = backend.map_queue_name('flow:prod:my-queue') + assert name == 'prod.flow.my-queue' + assert durable is True + + def test_no_colon_defaults_to_flow(self, backend): + name, durable = backend.map_queue_name('simple-queue') + assert name == 'tg.simple-queue' + assert durable is False + + def test_invalid_class_raises(self, backend): + with pytest.raises(ValueError, match="Invalid queue class"): + backend.map_queue_name('unknown:tg:topic') + + def test_flow_with_flow_suffix(self, backend): + """Queue names with flow suffix (e.g. :default) are preserved.""" + name, durable = backend.map_queue_name('request:tg:prompt:default') + assert name == 'tg.request.prompt:default' + + +class TestGetPubsubRabbitMQ: + + def test_factory_creates_rabbitmq_backend(self): + backend = get_pubsub(pubsub_backend='rabbitmq') + assert isinstance(backend, RabbitMQBackend) + + def test_factory_passes_config(self): + backend = get_pubsub( + pubsub_backend='rabbitmq', + rabbitmq_host='myhost', + rabbitmq_port=5673, + rabbitmq_username='user', + rabbitmq_password='pass', + rabbitmq_vhost='/test', + ) + assert isinstance(backend, RabbitMQBackend) + # Verify connection params were set + params = backend._connection_params + assert params.host == 'myhost' + assert params.port == 5673 + assert params.virtual_host == '/test' + + +class TestAddPubsubArgsRabbitMQ: + + def test_rabbitmq_args_present(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([ + '--pubsub-backend', 'rabbitmq', + '--rabbitmq-host', 'myhost', + '--rabbitmq-port', '5673', + ]) + assert args.pubsub_backend == 'rabbitmq' + assert args.rabbitmq_host == 'myhost' + assert args.rabbitmq_port == 5673 + + def test_rabbitmq_defaults_container(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([]) + assert args.rabbitmq_host == 'rabbitmq' + assert args.rabbitmq_port == 5672 + assert args.rabbitmq_username == 'guest' + assert args.rabbitmq_password == 'guest' + assert args.rabbitmq_vhost == '/' diff --git a/trustgraph-base/pyproject.toml b/trustgraph-base/pyproject.toml index 7d9f9219..b7b9757c 100644 --- a/trustgraph-base/pyproject.toml +++ b/trustgraph-base/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "prometheus-client", "requests", "python-logging-loki", + "pika", ] classifiers = [ "Programming Language :: Python :: 3", diff --git 
a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index 396d64e0..c66598aa 100644 --- a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -22,8 +22,9 @@ logger = logging.getLogger(__name__) # Lower threshold provides progress feedback and resumability on slower connections CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024 -# Default chunk size (5MB - S3 multipart minimum) -DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 +# Default chunk size (3MB - stays under broker message size limits +# after base64 encoding ~4MB) +DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024 def to_value(x): diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 5a454279..24b6c1f0 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec from . subscriber_spec import SubscriberSpec from . request_response_spec import RequestResponseSpec from . llm_service import LlmService, LlmResult, LlmChunk +from . librarian_client import LibrarianClient from . chunking_service import ChunkingService from . embeddings_service import EmbeddingsService from . embeddings_client import EmbeddingsClientSpec diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 94bab278..7f7dbdcd 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -68,11 +68,12 @@ class AsyncProcessor: processor = self.id, flow = None, name = "config", ) - # Subscribe to config queue + # Subscribe to config queue — exclusive so every processor + # gets its own copy of config pushes (broadcast pattern) self.config_sub_task = Consumer( taskgroup = self.taskgroup, - backend = self.pubsub_backend, # Changed from client to backend + backend = self.pubsub_backend, subscriber = config_subscriber_id, flow = None, @@ -83,9 +84,8 @@ class AsyncProcessor: metrics = config_consumer_metrics, - # This causes new subscriptions to view the entire history of - # configuration - start_of_messages = True + start_of_messages = True, + consumer_type = 'exclusive', ) self.running = True diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index 753378d4..d4bf4cd4 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -7,23 +7,14 @@ fetching large document content. 
import asyncio import base64 import logging -import uuid from .flow_processor import FlowProcessor from .parameter_spec import ParameterSpec -from .consumer import Consumer -from .producer import Producer -from .metrics import ConsumerMetrics, ProducerMetrics - -from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata -from ..schema import librarian_request_queue, librarian_response_queue +from .librarian_client import LibrarianClient # Module logger logger = logging.getLogger(__name__) -default_librarian_request_queue = librarian_request_queue -default_librarian_response_queue = librarian_response_queue - class ChunkingService(FlowProcessor): """Base service for chunking processors with parameter specification support""" @@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor): ParameterSpec(name="chunk-overlap") ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue - ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor=id, flow=None, name="librarian-request" - ) - - self.librarian_request_producer = Producer( + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, - topic=librarian_request_q, - schema=LibrarianRequest, - metrics=librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor=id, flow=None, name="librarian-response" - ) - - self.librarian_response_consumer = Consumer( taskgroup=self.taskgroup, - backend=self.pubsub, - flow=None, - topic=librarian_response_q, - subscriber=f"{id}-librarian", - schema=LibrarianResponse, - handler=self.on_librarian_response, - metrics=librarian_response_metrics, ) - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} - logger.debug("ChunkingService initialized with parameter specifications") async def start(self): await super(ChunkingService, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. 
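Both the per-service plumbing being removed here and the shared LibrarianClient that replaces it rely on the same correlation idiom: generate a request id, park an asyncio future under that id, carry the id in the message properties, and resolve the future when a response with the same id arrives. A minimal sketch of that idiom (names are illustrative, not taken from the patch):

    import asyncio
    import uuid

    pending = {}

    async def send_request(producer, request, timeout=120):
        # Correlate request and response via an id in message properties
        rid = str(uuid.uuid4())
        fut = asyncio.get_running_loop().create_future()
        pending[rid] = fut
        await producer.send(request, properties={"id": rid})
        try:
            return await asyncio.wait_for(fut, timeout=timeout)
        except asyncio.TimeoutError:
            pending.pop(rid, None)
            raise

    async def on_response(msg, consumer, flow):
        # Consumer handler: hand the response back to the waiting caller
        rid = msg.properties().get("id")
        fut = pending.pop(rid, None)
        if fut is not None:
            fut.set_result(msg.value())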
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="chunk", title=None, timeout=120): - """ - Save a child document (chunk) to the librarian. - - Args: - doc_id: ID for the new child document - parent_id: ID of the parent document - user: User ID - content: Document content (bytes or str) - document_type: Type of document ("chunk", etc.) - title: Optional title - timeout: Request timeout in seconds - - Returns: - The document ID on success - """ - request_id = str(uuid.uuid4()) - - if isinstance(content, str): - content = content.encode("utf-8") - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving chunk: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving chunk {doc_id}") + await self.librarian.start() async def get_document_text(self, doc): """ @@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor): """ if doc.document_id and not doc.text: logger.info(f"Fetching document {doc.document_id} from librarian...") - content = await self.fetch_document_content( + text = await self.librarian.fetch_document_text( document_id=doc.document_id, user=doc.metadata.user, ) - # Content is base64 encoded - if isinstance(content, str): - content = content.encode('utf-8') - text = base64.b64decode(content).decode("utf-8") logger.info(f"Fetched {len(text)} characters from librarian") return text else: @@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor): Extract chunk parameters from flow and return effective values Args: - msg: The message containing the document to chunk - consumer: The consumer spec - flow: The flow context - default_chunk_size: Default chunk size from processor config - default_chunk_overlap: Default chunk overlap from processor config + msg: The message being processed + consumer: The consumer instance + flow: The flow object containing parameters + default_chunk_size: Default chunk size if not configured + default_chunk_overlap: Default chunk overlap if not configured Returns: - tuple: 
(chunk_size, chunk_overlap) - effective values to use + tuple: (chunk_size, chunk_overlap) effective values """ - # Extract parameters from flow (flow-configurable parameters) - chunk_size = flow("chunk-size") - chunk_overlap = flow("chunk-overlap") - # Use provided values or fall back to defaults - effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size - effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap + chunk_size = default_chunk_size + chunk_overlap = default_chunk_overlap - logger.debug(f"Using chunk-size: {effective_chunk_size}") - logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}") + try: + cs = flow.parameters.get("chunk-size") + if cs is not None: + chunk_size = int(cs) + except Exception as e: + logger.warning(f"Could not parse chunk-size parameter: {e}") - return effective_chunk_size, effective_chunk_overlap + try: + co = flow.parameters.get("chunk-overlap") + if co is not None: + chunk_overlap = int(co) + except Exception as e: + logger.warning(f"Could not parse chunk-overlap parameter: {e}") - @staticmethod - def add_args(parser): - """Add chunking service arguments to parser""" - FlowProcessor.add_args(parser) - - parser.add_argument( - '--librarian-request-queue', - default=default_librarian_request_queue, - help=f'Librarian request queue (default: {default_librarian_request_queue})', - ) - - parser.add_argument( - '--librarian-response-queue', - default=default_librarian_response_queue, - help=f'Librarian response queue (default: {default_librarian_response_queue})', - ) \ No newline at end of file + return chunk_size, chunk_overlap diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 2a220312..9ae35d49 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -32,6 +32,7 @@ class Consumer: rate_limit_retry_time = 10, rate_limit_timeout = 7200, reconnect_time = 5, concurrency = 1, # Number of concurrent requests to handle + consumer_type = 'shared', ): self.taskgroup = taskgroup @@ -42,6 +43,8 @@ class Consumer: self.schema = schema self.handler = handler + self.consumer_type = consumer_type + self.rate_limit_retry_time = rate_limit_retry_time self.rate_limit_timeout = rate_limit_timeout @@ -93,33 +96,11 @@ class Consumer: if self.metrics: self.metrics.state("stopped") - try: - - logger.info(f"Subscribing to topic: {self.topic}") - - # Determine initial position - if self.start_of_messages: - initial_pos = 'earliest' - else: - initial_pos = 'latest' - - # Create consumer via backend - self.consumer = await asyncio.to_thread( - self.backend.create_consumer, - topic = self.topic, - subscription = self.subscriber, - schema = self.schema, - initial_position = initial_pos, - consumer_type = 'shared', - ) - - except Exception as e: - - logger.error(f"Consumer subscription exception: {e}", exc_info=True) - await asyncio.sleep(self.reconnect_time) - continue - - logger.info(f"Successfully subscribed to topic: {self.topic}") + # Determine initial position + if self.start_of_messages: + initial_pos = 'earliest' + else: + initial_pos = 'latest' if self.metrics: self.metrics.state("running") @@ -128,14 +109,30 @@ class Consumer: logger.info(f"Starting {self.concurrency} receiver threads") - async with asyncio.TaskGroup() as tg: - - tasks = [] - - for i in range(0, self.concurrency): - tasks.append( - tg.create_task(self.consume_from_queue()) + # Create one backend consumer per concurrent task. 
+ # Each gets its own connection — required for backends + # like RabbitMQ where connections are not thread-safe. + consumers = [] + for i in range(self.concurrency): + try: + logger.info(f"Subscribing to topic: {self.topic} (worker {i})") + c = await asyncio.to_thread( + self.backend.create_consumer, + topic = self.topic, + subscription = self.subscriber, + schema = self.schema, + initial_position = initial_pos, + consumer_type = self.consumer_type, ) + consumers.append(c) + logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})") + except Exception as e: + logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True) + raise + + async with asyncio.TaskGroup() as tg: + for c in consumers: + tg.create_task(self.consume_from_queue(c)) if self.metrics: self.metrics.state("stopped") @@ -143,23 +140,31 @@ class Consumer: except Exception as e: logger.error(f"Consumer loop exception: {e}", exc_info=True) - self.consumer.unsubscribe() - self.consumer.close() - self.consumer = None + for c in consumers: + try: + c.unsubscribe() + c.close() + except Exception: + pass + consumers = [] await asyncio.sleep(self.reconnect_time) continue - if self.consumer: - self.consumer.unsubscribe() - self.consumer.close() + finally: + for c in consumers: + try: + c.unsubscribe() + c.close() + except Exception: + pass - async def consume_from_queue(self): + async def consume_from_queue(self, consumer): while self.running: try: msg = await asyncio.to_thread( - self.consumer.receive, + consumer.receive, timeout_millis=2000 ) except Exception as e: @@ -168,9 +173,9 @@ class Consumer: continue raise e - await self.handle_one_from_queue(msg) + await self.handle_one_from_queue(msg, consumer) - async def handle_one_from_queue(self, msg): + async def handle_one_from_queue(self, msg, consumer): expiry = time.time() + self.rate_limit_timeout @@ -183,7 +188,7 @@ class Consumer: # Message failed to be processed, this causes it to # be retried - self.consumer.negative_acknowledge(msg) + consumer.negative_acknowledge(msg) if self.metrics: self.metrics.process("error") @@ -206,7 +211,7 @@ class Consumer: logger.debug("Message processed successfully") # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + consumer.acknowledge(msg) if self.metrics: self.metrics.process("success") @@ -233,7 +238,7 @@ class Consumer: # Message failed to be processed, this causes it to # be retried - self.consumer.negative_acknowledge(msg) + consumer.negative_acknowledge(msg) if self.metrics: self.metrics.process("error") diff --git a/trustgraph-base/trustgraph/base/librarian_client.py b/trustgraph-base/trustgraph/base/librarian_client.py new file mode 100644 index 00000000..6191cff8 --- /dev/null +++ b/trustgraph-base/trustgraph/base/librarian_client.py @@ -0,0 +1,246 @@ +""" +Shared librarian client for services that need to communicate +with the librarian via pub/sub. + +Provides request-response and streaming operations over the message +broker, with proper support for large documents via stream-document. 
+ +Usage: + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params + ) + await self.librarian.start() + content = await self.librarian.fetch_document_content(doc_id, user) +""" + +import asyncio +import base64 +import logging +import uuid + +from .consumer import Consumer +from .producer import Producer +from .metrics import ConsumerMetrics, ProducerMetrics + +from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ..schema import librarian_request_queue, librarian_response_queue + +logger = logging.getLogger(__name__) + + +class LibrarianClient: + """Client for librarian request-response over the message broker.""" + + def __init__(self, id, backend, taskgroup, **params): + + librarian_request_q = params.get( + "librarian_request_queue", librarian_request_queue, + ) + librarian_response_q = params.get( + "librarian_response_queue", librarian_response_queue, + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request", + ) + + self._producer = Producer( + backend=backend, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response", + ) + + self._consumer = Consumer( + taskgroup=taskgroup, + backend=backend, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self._on_response, + metrics=librarian_response_metrics, + consumer_type='exclusive', + ) + + # Single-response requests: request_id -> asyncio.Future + self._pending = {} + # Streaming requests: request_id -> asyncio.Queue + self._streams = {} + + async def start(self): + """Start the librarian producer and consumer.""" + await self._producer.start() + await self._consumer.start() + + async def _on_response(self, msg, consumer, flow): + """Route librarian responses to the right waiter.""" + response = msg.value() + request_id = msg.properties().get("id") + + if not request_id: + return + + if request_id in self._pending: + future = self._pending.pop(request_id) + future.set_result(response) + elif request_id in self._streams: + await self._streams[request_id].put(response) + + async def request(self, request, timeout=120): + """Send a request to the librarian and wait for a single response.""" + request_id = str(uuid.uuid4()) + + future = asyncio.get_event_loop().create_future() + self._pending[request_id] = future + + try: + await self._producer.send( + request, properties={"id": request_id}, + ) + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: " + f"{response.error.message}" + ) + + return response + + except asyncio.TimeoutError: + self._pending.pop(request_id, None) + raise RuntimeError("Timeout waiting for librarian response") + + async def stream(self, request, timeout=120): + """Send a request and collect streamed response chunks.""" + request_id = str(uuid.uuid4()) + + q = asyncio.Queue() + self._streams[request_id] = q + + try: + await self._producer.send( + request, properties={"id": request_id}, + ) + + chunks = [] + while True: + response = await asyncio.wait_for(q.get(), timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: " + f"{response.error.message}" + ) + + chunks.append(response) + + if response.is_final: + break + + return chunks + + except 
asyncio.TimeoutError: + self._streams.pop(request_id, None) + raise RuntimeError("Timeout waiting for librarian stream") + finally: + self._streams.pop(request_id, None) + + async def fetch_document_content(self, document_id, user, timeout=120): + """Fetch document content using streaming. + + Returns base64-encoded content. Caller is responsible for decoding. + """ + req = LibrarianRequest( + operation="stream-document", + document_id=document_id, + user=user, + ) + chunks = await self.stream(req, timeout=timeout) + + # Decode each chunk's base64 to raw bytes, concatenate, + # re-encode for the caller. + raw = b"" + for chunk in chunks: + if chunk.content: + if isinstance(chunk.content, bytes): + raw += base64.b64decode(chunk.content) + else: + raw += base64.b64decode( + chunk.content.encode("utf-8") + ) + + return base64.b64encode(raw) + + async def fetch_document_text(self, document_id, user, timeout=120): + """Fetch document content and decode as UTF-8 text.""" + content = await self.fetch_document_content( + document_id, user, timeout=timeout, + ) + return base64.b64decode(content).decode("utf-8") + + async def fetch_document_metadata(self, document_id, user, timeout=120): + """Fetch document metadata from the librarian.""" + req = LibrarianRequest( + operation="get-document-metadata", + document_id=document_id, + user=user, + ) + response = await self.request(req, timeout=timeout) + return response.document_metadata + + async def save_child_document(self, doc_id, parent_id, user, content, + document_type="chunk", title=None, + kind="text/plain", timeout=120): + """Save a child document to the librarian.""" + if isinstance(content, str): + content = content.encode("utf-8") + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind=kind, + title=title or doc_id, + parent_id=parent_id, + document_type=document_type, + ) + + req = LibrarianRequest( + operation="add-child-document", + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + ) + + await self.request(req, timeout=timeout) + return doc_id + + async def save_document(self, doc_id, user, content, title=None, + document_type="answer", kind="text/plain", + timeout=120): + """Save a document to the librarian.""" + if isinstance(content, str): + content = content.encode("utf-8") + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind=kind, + title=title or doc_id, + document_type=document_type, + ) + + req = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + user=user, + ) + + await self.request(req, timeout=timeout) + return doc_id diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py index 04734f28..8fe532d8 100644 --- a/trustgraph-base/trustgraph/base/pubsub.py +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -8,6 +8,12 @@ logger = logging.getLogger(__name__) DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None) +DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq') +DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672')) +DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest') +DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest') +DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/') + def get_pubsub(**config): """ @@ -29,6 +35,15 @@ def get_pubsub(**config): api_key=config.get('pulsar_api_key', 
DEFAULT_PULSAR_API_KEY), listener=config.get('pulsar_listener'), ) + elif backend_type == 'rabbitmq': + from .rabbitmq_backend import RabbitMQBackend + return RabbitMQBackend( + host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST), + port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT), + username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME), + password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD), + vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST), + ) else: raise ValueError(f"Unknown pub/sub backend: {backend_type}") @@ -44,8 +59,9 @@ def add_pubsub_args(parser, standalone=False): standalone: If True, default host is localhost (for CLI tools that run outside containers) """ - host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST - listener_default = 'localhost' if standalone else None + pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST + pulsar_listener = 'localhost' if standalone else None + rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST parser.add_argument( '--pubsub-backend', @@ -53,10 +69,11 @@ def add_pubsub_args(parser, standalone=False): help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', ) + # Pulsar options parser.add_argument( '-p', '--pulsar-host', - default=host, - help=f'Pulsar host (default: {host})', + default=pulsar_host, + help=f'Pulsar host (default: {pulsar_host})', ) parser.add_argument( @@ -67,6 +84,38 @@ def add_pubsub_args(parser, standalone=False): parser.add_argument( '--pulsar-listener', - default=listener_default, - help=f'Pulsar listener (default: {listener_default or "none"})', + default=pulsar_listener, + help=f'Pulsar listener (default: {pulsar_listener or "none"})', + ) + + # RabbitMQ options + parser.add_argument( + '--rabbitmq-host', + default=rabbitmq_host, + help=f'RabbitMQ host (default: {rabbitmq_host})', + ) + + parser.add_argument( + '--rabbitmq-port', + type=int, + default=DEFAULT_RABBITMQ_PORT, + help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})', + ) + + parser.add_argument( + '--rabbitmq-username', + default=DEFAULT_RABBITMQ_USERNAME, + help='RabbitMQ username', + ) + + parser.add_argument( + '--rabbitmq-password', + default=DEFAULT_RABBITMQ_PASSWORD, + help='RabbitMQ password', + ) + + parser.add_argument( + '--rabbitmq-vhost', + default=DEFAULT_RABBITMQ_VHOST, + help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})', ) diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index 677f2527..9480243e 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -9,122 +9,14 @@ import pulsar import _pulsar import json import logging -import base64 -import types -from dataclasses import asdict, is_dataclass -from typing import Any, get_type_hints +from typing import Any from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message +from .serialization import dataclass_to_dict, dict_to_dataclass logger = logging.getLogger(__name__) -def dataclass_to_dict(obj: Any) -> dict: - """ - Recursively convert a dataclass to a dictionary, handling None values and bytes. - - None values are excluded from the dictionary (not serialized). - Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior). - Handles nested dataclasses, lists, and dictionaries recursively. 
- """ - if obj is None: - return None - - # Handle bytes - decode to UTF-8 for JSON serialization - if isinstance(obj, bytes): - return obj.decode('utf-8') - - # Handle dataclass - convert to dict then recursively process all values - if is_dataclass(obj): - result = {} - for key, value in asdict(obj).items(): - result[key] = dataclass_to_dict(value) if value is not None else None - return result - - # Handle list - recursively process all items - if isinstance(obj, list): - return [dataclass_to_dict(item) for item in obj] - - # Handle dict - recursively process all values - if isinstance(obj, dict): - return {k: dataclass_to_dict(v) for k, v in obj.items()} - - # Return primitive types as-is - return obj - - -def dict_to_dataclass(data: dict, cls: type) -> Any: - """ - Convert a dictionary back to a dataclass instance. - - Handles nested dataclasses and missing fields. - Uses get_type_hints() to resolve forward references (string annotations). - """ - if data is None: - return None - - if not is_dataclass(cls): - return data - - # Get field types from the dataclass, resolving forward references - # get_type_hints() evaluates string annotations like "Triple | None" - try: - field_types = get_type_hints(cls) - except Exception: - # Fallback if get_type_hints fails (shouldn't happen normally) - field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} - kwargs = {} - - for key, value in data.items(): - if key in field_types: - field_type = field_types[key] - - # Handle modern union types (X | Y) - if isinstance(field_type, types.UnionType): - # Check if it's Optional (X | None) - if type(None) in field_type.__args__: - # Get the non-None type - actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) - if actual_type and is_dataclass(actual_type) and isinstance(value, dict): - kwargs[key] = dict_to_dataclass(value, actual_type) - else: - kwargs[key] = value - else: - kwargs[key] = value - # Check if this is a generic type (list, dict, etc.) 
- elif hasattr(field_type, '__origin__'): - # Handle list[T] - if field_type.__origin__ == list: - item_type = field_type.__args__[0] if field_type.__args__ else None - if item_type and is_dataclass(item_type) and isinstance(value, list): - kwargs[key] = [ - dict_to_dataclass(item, item_type) if isinstance(item, dict) else item - for item in value - ] - else: - kwargs[key] = value - # Handle old-style Optional[T] (which is Union[T, None]) - elif hasattr(field_type, '__args__') and type(None) in field_type.__args__: - # Get the non-None type from Union - actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) - if actual_type and is_dataclass(actual_type) and isinstance(value, dict): - kwargs[key] = dict_to_dataclass(value, actual_type) - else: - kwargs[key] = value - else: - kwargs[key] = value - # Handle direct dataclass fields - elif is_dataclass(field_type) and isinstance(value, dict): - kwargs[key] = dict_to_dataclass(value, field_type) - # Handle bytes fields (UTF-8 encoded strings from JSON) - elif field_type == bytes and isinstance(value, str): - kwargs[key] = value.encode('utf-8') - else: - kwargs[key] = value - - return cls(**kwargs) - - class PulsarMessage: """Wrapper for Pulsar messages to match Message protocol.""" diff --git a/trustgraph-base/trustgraph/base/rabbitmq_backend.py b/trustgraph-base/trustgraph/base/rabbitmq_backend.py new file mode 100644 index 00000000..a80efbaf --- /dev/null +++ b/trustgraph-base/trustgraph/base/rabbitmq_backend.py @@ -0,0 +1,390 @@ +""" +RabbitMQ backend implementation for pub/sub abstraction. + +Uses a single topic exchange per topicspace. The logical queue name +becomes the routing key. Consumer behavior is determined by the +subscription name: + +- Same subscription + same topic = shared queue (competing consumers) +- Different subscriptions = separate queues (broadcast / fan-out) + +This mirrors Pulsar's subscription model using idiomatic RabbitMQ. + +Architecture: + Producer --> [tg exchange] --routing key--> [named queue] --> Consumer + --routing key--> [named queue] --> Consumer + --routing key--> [exclusive q] --> Subscriber + +Uses basic_consume (push) instead of basic_get (polling) for +efficient message delivery. +""" + +import json +import time +import logging +import queue +import threading +import pika +from typing import Any + +from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message +from .serialization import dataclass_to_dict, dict_to_dataclass + +logger = logging.getLogger(__name__) + + +class RabbitMQMessage: + """Wrapper for RabbitMQ messages to match Message protocol.""" + + def __init__(self, method, properties, body, schema_cls): + self._method = method + self._properties = properties + self._body = body + self._schema_cls = schema_cls + self._value = None + + def value(self) -> Any: + """Deserialize and return the message value as a dataclass.""" + if self._value is None: + data_dict = json.loads(self._body.decode('utf-8')) + self._value = dict_to_dataclass(data_dict, self._schema_cls) + return self._value + + def properties(self) -> dict: + """Return message properties from AMQP headers.""" + headers = self._properties.headers or {} + return dict(headers) + + +class RabbitMQBackendProducer: + """Publishes messages to a topic exchange with a routing key. + + Uses thread-local connections so each thread gets its own + connection/channel. This avoids wire corruption from concurrent + threads writing to the same socket (pika is not thread-safe). 
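Concretely, the subscription model described in the module docstring above maps a logical queue id onto one exchange, one routing key, and one queue per subscription. A sketch against the backend as defined later in this patch (the topic name 'chunk' and subscription names are illustrative; schema handling is omitted for brevity):

    from trustgraph.base.rabbitmq_backend import RabbitMQBackend

    backend = RabbitMQBackend(host="rabbitmq")

    # 'flow' class: durable, exchange 'tg', routing key 'flow.chunk'
    name, durable = backend.map_queue_name("flow:tg:chunk")
    # name == 'tg.flow.chunk', durable is True

    # Two consumers on the same subscription bind to the shared queue
    # 'tg.flow.chunk.chunker' and compete for messages (round-robin)
    c1 = backend.create_consumer("flow:tg:chunk", "chunker", schema=None)
    c2 = backend.create_consumer("flow:tg:chunk", "chunker", schema=None)

    # An exclusive consumer gets its own anonymous, auto-deleted queue
    # bound to the same routing key, so it sees every message (broadcast)
    mon = backend.create_consumer(
        "flow:tg:chunk", "monitor", schema=None, consumer_type="exclusive",
    )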
+ """ + + def __init__(self, connection_params, exchange_name, routing_key, + durable): + self._connection_params = connection_params + self._exchange_name = exchange_name + self._routing_key = routing_key + self._durable = durable + self._local = threading.local() + + def _get_channel(self): + """Get or create a thread-local connection and channel.""" + conn = getattr(self._local, 'connection', None) + chan = getattr(self._local, 'channel', None) + + if conn is None or not conn.is_open or chan is None or not chan.is_open: + # Close stale connection if any + if conn is not None: + try: + conn.close() + except Exception: + pass + + conn = pika.BlockingConnection(self._connection_params) + chan = conn.channel() + chan.exchange_declare( + exchange=self._exchange_name, + exchange_type='topic', + durable=True, + ) + self._local.connection = conn + self._local.channel = chan + + return chan + + def send(self, message: Any, properties: dict = {}) -> None: + data_dict = dataclass_to_dict(message) + json_data = json.dumps(data_dict) + + amqp_properties = pika.BasicProperties( + delivery_mode=2 if self._durable else 1, + content_type='application/json', + headers=properties if properties else None, + ) + + for attempt in range(2): + try: + channel = self._get_channel() + channel.basic_publish( + exchange=self._exchange_name, + routing_key=self._routing_key, + body=json_data.encode('utf-8'), + properties=amqp_properties, + ) + return + except Exception as e: + logger.warning( + f"RabbitMQ send failed (attempt {attempt + 1}): {e}" + ) + # Force reconnect on next attempt + self._local.connection = None + self._local.channel = None + if attempt == 1: + raise + + def flush(self) -> None: + pass + + def close(self) -> None: + """Close the thread-local connection if any.""" + conn = getattr(self._local, 'connection', None) + if conn is not None: + try: + conn.close() + except Exception: + pass + self._local.connection = None + self._local.channel = None + + +class RabbitMQBackendConsumer: + """Consumes from a queue bound to a topic exchange. + + Uses basic_consume (push model) with messages delivered to an + internal thread-safe queue. process_data_events() drives both + message delivery and heartbeat processing. 
+ """ + + def __init__(self, connection_params, exchange_name, routing_key, + queue_name, schema_cls, durable, exclusive=False, + auto_delete=False): + self._connection_params = connection_params + self._exchange_name = exchange_name + self._routing_key = routing_key + self._queue_name = queue_name + self._schema_cls = schema_cls + self._durable = durable + self._exclusive = exclusive + self._auto_delete = auto_delete + self._connection = None + self._channel = None + self._consumer_tag = None + self._incoming = queue.Queue() + + def _connect(self): + self._connection = pika.BlockingConnection(self._connection_params) + self._channel = self._connection.channel() + + # Declare the topic exchange + self._channel.exchange_declare( + exchange=self._exchange_name, + exchange_type='topic', + durable=True, + ) + + # Declare the queue — anonymous if exclusive + result = self._channel.queue_declare( + queue=self._queue_name, + durable=self._durable, + exclusive=self._exclusive, + auto_delete=self._auto_delete, + ) + # Capture actual name (important for anonymous queues where name='') + self._queue_name = result.method.queue + + self._channel.queue_bind( + queue=self._queue_name, + exchange=self._exchange_name, + routing_key=self._routing_key, + ) + + self._channel.basic_qos(prefetch_count=1) + + # Register push-based consumer + self._consumer_tag = self._channel.basic_consume( + queue=self._queue_name, + on_message_callback=self._on_message, + auto_ack=False, + ) + + def _on_message(self, channel, method, properties, body): + """Callback invoked by pika when a message arrives.""" + self._incoming.put((method, properties, body)) + + def _is_alive(self): + return ( + self._connection is not None + and self._connection.is_open + and self._channel is not None + and self._channel.is_open + ) + + def receive(self, timeout_millis: int = 2000) -> Message: + """Receive a message. 
Raises TimeoutError if none available.""" + if not self._is_alive(): + self._connect() + + timeout_seconds = timeout_millis / 1000.0 + deadline = time.monotonic() + timeout_seconds + + while time.monotonic() < deadline: + # Check if a message was already delivered + try: + method, properties, body = self._incoming.get_nowait() + return RabbitMQMessage( + method, properties, body, self._schema_cls, + ) + except queue.Empty: + pass + + # Drive pika's I/O — delivers messages and processes heartbeats + remaining = deadline - time.monotonic() + if remaining > 0: + self._connection.process_data_events( + time_limit=min(0.1, remaining), + ) + + raise TimeoutError("No message received within timeout") + + def acknowledge(self, message: Message) -> None: + if isinstance(message, RabbitMQMessage) and message._method: + self._channel.basic_ack( + delivery_tag=message._method.delivery_tag, + ) + + def negative_acknowledge(self, message: Message) -> None: + if isinstance(message, RabbitMQMessage) and message._method: + self._channel.basic_nack( + delivery_tag=message._method.delivery_tag, + requeue=True, + ) + + def unsubscribe(self) -> None: + if self._consumer_tag and self._channel and self._channel.is_open: + try: + self._channel.basic_cancel(self._consumer_tag) + except Exception: + pass + self._consumer_tag = None + + def close(self) -> None: + self.unsubscribe() + try: + if self._channel and self._channel.is_open: + self._channel.close() + except Exception: + pass + try: + if self._connection and self._connection.is_open: + self._connection.close() + except Exception: + pass + self._channel = None + self._connection = None + + +class RabbitMQBackend: + """RabbitMQ pub/sub backend using a topic exchange per topicspace.""" + + def __init__(self, host='localhost', port=5672, username='guest', + password='guest', vhost='/'): + self._connection_params = pika.ConnectionParameters( + host=host, + port=port, + virtual_host=vhost, + credentials=pika.PlainCredentials(username, password), + ) + logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}") + + def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]: + """ + Parse queue identifier into exchange, routing key, and durability. 
+ + Format: class:topicspace:topic + Returns: (exchange_name, routing_key, class, durable) + """ + if ':' not in queue_id: + return 'tg', queue_id, 'flow', False + + parts = queue_id.split(':', 2) + if len(parts) != 3: + raise ValueError( + f"Invalid queue format: {queue_id}, " + f"expected class:topicspace:topic" + ) + + cls, topicspace, topic = parts + + if cls in ('flow', 'state'): + durable = True + elif cls in ('request', 'response'): + durable = False + else: + raise ValueError( + f"Invalid queue class: {cls}, " + f"expected flow, request, response, or state" + ) + + # Exchange per topicspace, routing key includes class + exchange_name = topicspace + routing_key = f"{cls}.{topic}" + + return exchange_name, routing_key, cls, durable + + # Keep map_queue_name for backward compatibility with tests + def map_queue_name(self, queue_id: str) -> tuple[str, bool]: + exchange, routing_key, cls, durable = self._parse_queue_id(queue_id) + return f"{exchange}.{routing_key}", durable + + def create_producer(self, topic: str, schema: type, + **options) -> BackendProducer: + exchange, routing_key, cls, durable = self._parse_queue_id(topic) + logger.debug( + f"Creating producer: exchange={exchange}, " + f"routing_key={routing_key}" + ) + return RabbitMQBackendProducer( + self._connection_params, exchange, routing_key, durable, + ) + + def create_consumer(self, topic: str, subscription: str, schema: type, + initial_position: str = 'latest', + consumer_type: str = 'shared', + **options) -> BackendConsumer: + """Create a consumer with a queue bound to the topic exchange. + + consumer_type='shared': Named durable queue. Multiple consumers + with the same subscription compete (round-robin). + consumer_type='exclusive': Anonymous ephemeral queue. Each + consumer gets its own copy of every message (broadcast). + """ + exchange, routing_key, cls, durable = self._parse_queue_id(topic) + + if consumer_type == 'exclusive' and cls == 'state': + # State broadcast: named durable queue per subscriber. + # Retains messages so late-starting processors see current state. + queue_name = f"{exchange}.{routing_key}.{subscription}" + queue_durable = True + exclusive = False + auto_delete = False + elif consumer_type == 'exclusive': + # Broadcast: anonymous queue, auto-deleted on disconnect + queue_name = '' + queue_durable = False + exclusive = True + auto_delete = True + else: + # Shared: named queue, competing consumers + queue_name = f"{exchange}.{routing_key}.{subscription}" + queue_durable = durable + exclusive = False + auto_delete = False + + logger.debug( + f"Creating consumer: exchange={exchange}, " + f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, " + f"type={consumer_type}" + ) + + return RabbitMQBackendConsumer( + self._connection_params, exchange, routing_key, + queue_name, schema, queue_durable, exclusive, auto_delete, + ) + + def close(self) -> None: + pass diff --git a/trustgraph-base/trustgraph/base/serialization.py b/trustgraph-base/trustgraph/base/serialization.py new file mode 100644 index 00000000..6fd3ca62 --- /dev/null +++ b/trustgraph-base/trustgraph/base/serialization.py @@ -0,0 +1,115 @@ +""" +JSON serialization helpers for dataclass ↔ dict conversion. + +Used by pub/sub backends that use JSON as their wire format. +""" + +import types +from dataclasses import asdict, is_dataclass +from typing import Any, get_type_hints + + +def dataclass_to_dict(obj: Any) -> dict: + """ + Recursively convert a dataclass to a dictionary, handling None values and bytes. 
+ + None values are excluded from the dictionary (not serialized). + Bytes values are decoded as UTF-8 strings for JSON serialization. + Handles nested dataclasses, lists, and dictionaries recursively. + """ + if obj is None: + return None + + # Handle bytes - decode to UTF-8 for JSON serialization + if isinstance(obj, bytes): + return obj.decode('utf-8') + + # Handle dataclass - convert to dict then recursively process all values + if is_dataclass(obj): + result = {} + for key, value in asdict(obj).items(): + result[key] = dataclass_to_dict(value) if value is not None else None + return result + + # Handle list - recursively process all items + if isinstance(obj, list): + return [dataclass_to_dict(item) for item in obj] + + # Handle dict - recursively process all values + if isinstance(obj, dict): + return {k: dataclass_to_dict(v) for k, v in obj.items()} + + # Return primitive types as-is + return obj + + +def dict_to_dataclass(data: dict, cls: type) -> Any: + """ + Convert a dictionary back to a dataclass instance. + + Handles nested dataclasses and missing fields. + Uses get_type_hints() to resolve forward references (string annotations). + """ + if data is None: + return None + + if not is_dataclass(cls): + return data + + # Get field types from the dataclass, resolving forward references + # get_type_hints() evaluates string annotations like "Triple | None" + try: + field_types = get_type_hints(cls) + except Exception: + # Fallback if get_type_hints fails (shouldn't happen normally) + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + + for key, value in data.items(): + if key in field_types: + field_type = field_types[key] + + # Handle modern union types (X | Y) + if isinstance(field_type, types.UnionType): + # Check if it's Optional (X | None) + if type(None) in field_type.__args__: + # Get the non-None type + actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) + if actual_type and is_dataclass(actual_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, actual_type) + else: + kwargs[key] = value + else: + kwargs[key] = value + # Check if this is a generic type (list, dict, etc.) 
+ elif hasattr(field_type, '__origin__'): + # Handle list[T] + if field_type.__origin__ == list: + item_type = field_type.__args__[0] if field_type.__args__ else None + if item_type and is_dataclass(item_type) and isinstance(value, list): + kwargs[key] = [ + dict_to_dataclass(item, item_type) if isinstance(item, dict) else item + for item in value + ] + else: + kwargs[key] = value + # Handle old-style Optional[T] (which is Union[T, None]) + elif hasattr(field_type, '__args__') and type(None) in field_type.__args__: + # Get the non-None type from Union + actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) + if actual_type and is_dataclass(actual_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, actual_type) + else: + kwargs[key] = value + else: + kwargs[key] = value + # Handle direct dataclass fields + elif is_dataclass(field_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, field_type) + # Handle bytes fields (UTF-8 encoded strings from JSON) + elif field_type == bytes and isinstance(value, str): + kwargs[key] = value.encode('utf-8') + else: + kwargs[key] = value + + return cls(**kwargs) diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index b0d90507..36948131 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -51,7 +51,7 @@ class Subscriber: topic=self.topic, subscription=self.subscription, schema=self.schema, - consumer_type='shared', + consumer_type='exclusive', ) self.task = asyncio.create_task(self.run()) diff --git a/trustgraph-base/trustgraph/clients/base.py b/trustgraph-base/trustgraph/clients/base.py index a71ba84e..cd4ad72e 100644 --- a/trustgraph-base/trustgraph/clients/base.py +++ b/trustgraph-base/trustgraph/clients/base.py @@ -18,9 +18,7 @@ class BaseClient: output_queue=None, input_schema=None, output_schema=None, - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key=None, - listener=None, + **pubsub_config, ): if input_queue == None: raise RuntimeError("Need input_queue") @@ -32,12 +30,7 @@ class BaseClient: subscriber = str(uuid.uuid4()) # Create backend using factory - self.backend = get_pubsub( - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, - pulsar_listener=listener, - pubsub_backend='pulsar' - ) + self.backend = get_pubsub(**pubsub_config) self.producer = self.backend.create_producer( topic=input_queue, diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index daadf652..78b62688 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -33,9 +33,7 @@ class ConfigClient(BaseClient): subscriber=None, input_queue=None, output_queue=None, - pulsar_host="pulsar://pulsar:6650", - listener=None, - pulsar_api_key=None, + **pubsub_config, ): if input_queue == None: @@ -48,11 +46,9 @@ class ConfigClient(BaseClient): subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, input_schema=ConfigRequest, output_schema=ConfigResponse, - listener=listener, + **pubsub_config, ) def get(self, keys, timeout=300): diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index 51d0d5a5..f5d4592c 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ 
b/trustgraph-base/trustgraph/schema/services/library.py @@ -24,10 +24,13 @@ from ..core.metadata import Metadata # <- (document_metadata) # <- (error) -# get-document-content +# get-document-content [DEPRECATED — use stream-document instead] # -> (document_id) # <- (content) # <- (error) +# NOTE: Returns entire document in a single message. Fails for documents +# exceeding the broker's max message size. Use stream-document which +# returns content in chunks. # add-processing # -> (processing_id, processing_metadata) @@ -220,5 +223,5 @@ class LibrarianResponse: # FIXME: Is this right? Using persistence on librarian so that # message chunking works -librarian_request_queue = queue('librarian-request', cls='flow') -librarian_response_queue = queue('librarian-response', cls='flow') +librarian_request_queue = queue('librarian', cls='request') +librarian_response_queue = queue('librarian', cls='response') diff --git a/trustgraph-cli/trustgraph/cli/dump_queues.py b/trustgraph-cli/trustgraph/cli/dump_queues.py index eb7898c2..95be8529 100644 --- a/trustgraph-cli/trustgraph/cli/dump_queues.py +++ b/trustgraph-cli/trustgraph/cli/dump_queues.py @@ -354,10 +354,8 @@ IMPORTANT: output_file=args.output, subscriber_name=args.subscriber, append_mode=args.append, - pubsub_backend=args.pubsub_backend, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener, + **{k: v for k, v in vars(args).items() + if k not in ('queues', 'output', 'subscriber', 'append')}, )) except KeyboardInterrupt: # Already handled in async_main diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py index 02456b1c..514dc75b 100644 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ b/trustgraph-cli/trustgraph/cli/init_trustgraph.py @@ -1,5 +1,8 @@ """ -Initialises Pulsar with Trustgraph tenant / namespaces & policy. +Initialises TrustGraph pub/sub infrastructure and pushes initial config. + +For Pulsar: creates tenant, namespaces, and retention policies. +For RabbitMQ: queues are auto-declared, so only config push is needed. 
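As the note above says, the RabbitMQ path skips admin setup entirely; in terms of the functions defined below, initialisation reduces to a single config push. A rough sketch of that call, assuming the keyword names produced by add_pubsub_args (the config file name is illustrative):

    # No tenant/namespace setup needed for RabbitMQ; just push the config
    push_config(
        config_json=None,
        config_file="initial-config.json",
        pubsub_backend="rabbitmq",
        rabbitmq_host="rabbitmq",
        rabbitmq_port=5672,
        rabbitmq_username="guest",
        rabbitmq_password="guest",
        rabbitmq_vhost="/",
    )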
""" import requests @@ -8,10 +11,11 @@ import argparse import json from trustgraph.clients.config_client import ConfigClient +from trustgraph.base.pubsub import add_pubsub_args default_pulsar_admin_url = "http://pulsar:8080" -default_pulsar_host = "pulsar://pulsar:6650" -subscriber = "tg-init-pulsar" +subscriber = "tg-init-pubsub" + def get_clusters(url): @@ -65,12 +69,11 @@ def ensure_namespace(url, tenant, namespace, config): print(f"Namespace {tenant}/{namespace} created.", flush=True) -def ensure_config(config, pulsar_host, pulsar_api_key): +def ensure_config(config, **pubsub_config): cli = ConfigClient( subscriber=subscriber, - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, + **pubsub_config, ) while True: @@ -115,11 +118,9 @@ def ensure_config(config, pulsar_host, pulsar_api_key): time.sleep(2) print("Retrying...", flush=True) continue - -def init( - pulsar_admin_url, pulsar_host, pulsar_api_key, tenant, - config, config_file, -): + +def init_pulsar(pulsar_admin_url, tenant): + """Pulsar-specific setup: create tenant, namespaces, retention policies.""" clusters = get_clusters(pulsar_admin_url) @@ -145,17 +146,21 @@ def init( } }) - if config is not None: + +def push_config(config_json, config_file, **pubsub_config): + """Push initial config if provided.""" + + if config_json is not None: try: print("Decoding config...", flush=True) - dec = json.loads(config) + dec = json.loads(config_json) print("Decoded.", flush=True) except Exception as e: print("Exception:", e, flush=True) raise e - ensure_config(dec, pulsar_host, pulsar_api_key) + ensure_config(dec, **pubsub_config) elif config_file is not None: @@ -167,11 +172,12 @@ def init( print("Exception:", e, flush=True) raise e - ensure_config(dec, pulsar_host, pulsar_api_key) + ensure_config(dec, **pubsub_config) else: print("No config to update.", flush=True) + def main(): parser = argparse.ArgumentParser( @@ -180,22 +186,11 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-admin-url', + '--pulsar-admin-url', default=default_pulsar_admin_url, help=f'Pulsar admin URL (default: {default_pulsar_admin_url})', ) - parser.add_argument( - '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '--pulsar-api-key', - help=f'Pulsar API key', - ) - parser.add_argument( '-c', '--config', help=f'Initial configuration to load', @@ -212,18 +207,43 @@ def main(): help=f'Tenant (default: tg)', ) + add_pubsub_args(parser) + args = parser.parse_args() + backend_type = args.pubsub_backend + + # Extract pubsub config from args + pubsub_config = { + k: v for k, v in vars(args).items() + if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant') + } + while True: try: - print(flush=True) - print( - f"Initialising with Pulsar {args.pulsar_admin_url}...", - flush=True + # Pulsar-specific setup (tenants, namespaces) + if backend_type == 'pulsar': + print(flush=True) + print( + f"Initialising Pulsar at {args.pulsar_admin_url}...", + flush=True, + ) + init_pulsar(args.pulsar_admin_url, args.tenant) + else: + print(flush=True) + print( + f"Using {backend_type} backend (no admin setup needed).", + flush=True, + ) + + # Push config (works with any backend) + push_config( + args.config, args.config_file, + **pubsub_config, ) - init(**vars(args)) + print("Initialisation complete.", flush=True) break @@ -236,4 +256,4 @@ def main(): print("Will retry...", flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git 
a/trustgraph-cli/trustgraph/cli/monitor_prompts.py b/trustgraph-cli/trustgraph/cli/monitor_prompts.py index c3b71afb..0cfe68ac 100644 --- a/trustgraph-cli/trustgraph/cli/monitor_prompts.py +++ b/trustgraph-cli/trustgraph/cli/monitor_prompts.py @@ -316,10 +316,8 @@ def main(): queue_type=args.queue_type, max_lines=args.max_lines, max_width=args.max_width, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener, - pubsub_backend=args.pubsub_backend, + **{k: v for k, v in vars(args).items() + if k not in ('flow', 'queue_type', 'max_lines', 'max_width')}, )) except KeyboardInterrupt: pass diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index 64d58457..df2c58bd 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -133,7 +133,7 @@ class Processor(ChunkingService): chunk_length = len(chunk.page_content) # Save chunk to librarian as child document - await self.save_child_document( + await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 4302250e..3e1161bf 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -131,7 +131,7 @@ class Processor(ChunkingService): chunk_length = len(chunk.page_content) # Save chunk to librarian as child document - await self.save_child_document( + await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 8685aa61..40b8c566 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -9,20 +9,16 @@ for large documents. from pypdf import PdfWriter, PdfReader from io import BytesIO -import asyncio import base64 import uuid import os from mistralai import Mistral -from mistralai.models import OCRResponse from ... schema import Document, TextDocument, Metadata -from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... 
provenance import ( document_uri, page_uri as make_page_uri, derived_entity_triples, @@ -102,42 +98,10 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor = id, flow = None, name = "librarian-request" - ) - - self.librarian_request_producer = Producer( - backend = self.pubsub, - topic = librarian_request_q, - schema = LibrarianRequest, - metrics = librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor = id, flow = None, name = "librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = librarian_response_q, - subscriber = f"{id}-librarian", - schema = LibrarianResponse, - handler = self.on_librarian_response, - metrics = librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} if api_key is None: raise RuntimeError("Mistral API key not specified") @@ -151,132 +115,7 @@ class Processor(FlowProcessor): async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_metadata(self, document_id, user, timeout=120): - """ - Fetch document metadata from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.document_metadata - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching metadata for {document_id}") - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. 
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, timeout=120): - """ - Save a child document to the librarian. - """ - request_id = str(uuid.uuid4()) - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving child document: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving child document {doc_id}") + await self.librarian.start() def ocr(self, blob): """ @@ -359,7 +198,7 @@ class Processor(FlowProcessor): # Check MIME type if fetching from librarian if v.document_id: - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) @@ -374,7 +213,7 @@ class Processor(FlowProcessor): # Get PDF content - fetch from librarian or use inline data if v.document_id: logger.info(f"Fetching document {v.document_id} from librarian...") - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) @@ -401,7 +240,7 @@ class Processor(FlowProcessor): page_content = markdown.encode("utf-8") # Save page as child document in librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 38ca0603..d0061afd 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -7,20 +7,16 @@ Supports both inline document data and fetching from librarian via Pulsar for large documents. """ -import asyncio import os import tempfile import base64 import logging -import uuid from langchain_community.document_loaders import PyPDFLoader from ... schema import Document, TextDocument, Metadata -from ... 
schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... provenance import ( document_uri, page_uri as make_page_uri, derived_entity_triples, @@ -74,187 +70,16 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor = id, flow = None, name = "librarian-request" - ) - - self.librarian_request_producer = Producer( - backend = self.pubsub, - topic = librarian_request_q, - schema = LibrarianRequest, - metrics = librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor = id, flow = None, name = "librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = librarian_response_q, - subscriber = f"{id}-librarian", - schema = LibrarianResponse, - handler = self.on_librarian_response, - metrics = librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} logger.info("PDF decoder initialized") async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_metadata(self, document_id, user, timeout=120): - """ - Fetch document metadata from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.document_metadata - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching metadata for {document_id}") - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. 
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, timeout=120): - """ - Save a child document to the librarian. - - Args: - doc_id: ID for the new child document - parent_id: ID of the parent document - user: User ID - content: Document content (bytes) - document_type: Type of document ("page", "chunk", etc.) - title: Optional title - timeout: Request timeout in seconds - - Returns: - The document ID on success - """ - import base64 - - request_id = str(uuid.uuid4()) - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving child document: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving child document {doc_id}") + await self.librarian.start() async def on_message(self, msg, consumer, flow): @@ -266,7 +91,7 @@ class Processor(FlowProcessor): # Check MIME type if fetching from librarian if v.document_id: - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) @@ -287,7 +112,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") fp.close() - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) @@ -323,7 +148,7 @@ class Processor(FlowProcessor): page_content = page.page_content.encode("utf-8") # Save page as child document in librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index cdf5daba..8d1aca9e 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -10,7 +10,7 @@ import logging import os from trustgraph.base.logging import setup_logging -from 
trustgraph.base.pubsub import get_pubsub +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args from . auth import Authenticator from . config.receiver import ConfigReceiver @@ -167,30 +167,7 @@ def run(): help='Service identifier for logging and metrics (default: api-gateway)', ) - # Pub/sub backend selection - parser.add_argument( - '--pubsub-backend', - default=os.getenv('PUBSUB_BACKEND', 'pulsar'), - choices=['pulsar', 'mqtt'], - help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '--pulsar-api-key', - default=default_pulsar_api_key, - help=f'Pulsar API key', - ) - - parser.add_argument( - '--pulsar-listener', - help=f'Pulsar listener (default: none)', - ) + add_pubsub_args(parser) parser.add_argument( '-m', '--prometheus-url', diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index b81c6321..c0e55d84 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -12,22 +12,18 @@ import uuid from ... schema import DocumentRagQuery, DocumentRagResponse, Error from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata -from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples, Metadata from ... provenance import GRAPH_RETRIEVAL from . document_rag import DocumentRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec -from ... base import Consumer, Producer -from ... base import ConsumerMetrics, ProducerMetrics +from ... 
base import LibrarianClient # Module logger logger = logging.getLogger(__name__) default_ident = "document-rag" -default_librarian_request_queue = librarian_request_queue -default_librarian_response_queue = librarian_response_queue class Processor(FlowProcessor): @@ -89,111 +85,26 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching chunk content from Garage - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue - ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor=id, flow=None, name="librarian-request" - ) - - self.librarian_request_producer = Producer( + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, - topic=librarian_request_q, - schema=LibrarianRequest, - metrics=librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor=id, flow=None, name="librarian-response" - ) - - self.librarian_response_consumer = Consumer( taskgroup=self.taskgroup, - backend=self.pubsub, - flow=None, - topic=librarian_response_q, - subscriber=f"{id}-librarian", - schema=LibrarianResponse, - handler=self.on_librarian_response, - metrics=librarian_response_metrics, ) - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} - async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) + await self.librarian.start() async def fetch_chunk_content(self, chunk_id, user, timeout=120): - """Fetch chunk content from librarian/Garage.""" - import uuid - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=chunk_id, - user=user, + """Fetch chunk content from librarian. Chunks are small so + single request-response is fine.""" + return await self.librarian.fetch_document_text( + document_id=chunk_id, user=user, timeout=timeout, ) - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - # Content is base64 encoded - content = response.content - if isinstance(content, str): - content = content.encode('utf-8') - return base64.b64decode(content).decode("utf-8") - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching chunk {chunk_id}") - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): - """ - Save answer content to the librarian. 
- - Args: - doc_id: ID for the answer document - user: User ID - content: Answer text content - title: Optional title - timeout: Request timeout in seconds - - Returns: - The document ID on success - """ - request_id = str(uuid.uuid4()) + """Save answer content to the librarian.""" doc_metadata = DocumentMetadata( id=doc_id, @@ -211,29 +122,8 @@ class Processor(FlowProcessor): user=user, ) - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving answer: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving answer document {doc_id}") + await self.librarian.request(request, timeout=timeout) + return doc_id async def on_request(self, msg, consumer, flow): @@ -390,4 +280,3 @@ class Processor(FlowProcessor): def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index dd410d90..4844b104 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -7,19 +7,15 @@ Supports both inline document data and fetching from librarian via Pulsar for large documents. """ -import asyncio import base64 import logging -import uuid import pytesseract from pdf2image import convert_from_bytes from ... schema import Document, TextDocument, Metadata -from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... 
provenance import ( document_uri, page_uri as make_page_uri, derived_entity_triples, @@ -72,173 +68,16 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor = id, flow = None, name = "librarian-request" - ) - - self.librarian_request_producer = Producer( - backend = self.pubsub, - topic = librarian_request_q, - schema = LibrarianRequest, - metrics = librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor = id, flow = None, name = "librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = librarian_response_q, - subscriber = f"{id}-librarian", - schema = LibrarianResponse, - handler = self.on_librarian_response, - metrics = librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} logger.info("PDF OCR processor initialized") async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_metadata(self, document_id, user, timeout=120): - """ - Fetch document metadata from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.document_metadata - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching metadata for {document_id}") - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. 
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, timeout=120): - """ - Save a child document to the librarian. - """ - request_id = str(uuid.uuid4()) - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving child document: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving child document {doc_id}") + await self.librarian.start() async def on_message(self, msg, consumer, flow): @@ -250,7 +89,7 @@ class Processor(FlowProcessor): # Check MIME type if fetching from librarian if v.document_id: - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) @@ -265,7 +104,7 @@ class Processor(FlowProcessor): # Get PDF content - fetch from librarian or use inline data if v.document_id: logger.info(f"Fetching document {v.document_id} from librarian...") - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) @@ -299,7 +138,7 @@ class Processor(FlowProcessor): page_content = text.encode("utf-8") # Save page as child document in librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, user=v.metadata.user, diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py index b8d05158..6b7d0246 100644 --- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py +++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py @@ -14,22 +14,18 @@ Tables are preserved as HTML markup for better downstream extraction. Images are stored in the librarian but not sent to the text pipeline. 
""" -import asyncio import base64 import logging import magic import tempfile import os -import uuid from unstructured.partition.auto import partition from ... schema import Document, TextDocument, Metadata -from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... provenance import ( document_uri, page_uri as make_page_uri, @@ -166,128 +162,16 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching/storing document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor=id, flow=None, name="librarian-request" - ) - - self.librarian_request_producer = Producer( - backend=self.pubsub, - topic=librarian_request_q, - schema=LibrarianRequest, - metrics=librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor=id, flow=None, name="librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup=self.taskgroup, - backend=self.pubsub, - flow=None, - topic=librarian_response_q, - subscriber=f"{id}-librarian", - schema=LibrarianResponse, - handler=self.on_librarian_response, - metrics=librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} logger.info("Universal decoder initialized") async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def _librarian_request(self, request, timeout=120): - """Send a request to the librarian and wait for response.""" - request_id = str(uuid.uuid4()) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: " - f"{response.error.message}" - ) - - return response - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError("Timeout waiting for librarian response") - - async def fetch_document_metadata(self, document_id, user): - """Fetch document metadata from the librarian.""" - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - response = await self._librarian_request(request) - return response.document_metadata - - async def fetch_document_content(self, document_id, user): - """Fetch document content from the librarian.""" - request = 
LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - response = await self._librarian_request(request) - return response.content - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, - kind="text/plain"): - """Save a child document to the librarian.""" - if isinstance(content, str): - content = content.encode("utf-8") - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind=kind, - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - await self._librarian_request(request) - return doc_id + await self.librarian.start() def extract_elements(self, blob, mime_type=None): """ @@ -388,7 +272,7 @@ class Processor(FlowProcessor): page_content = text.encode("utf-8") # Save to librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=doc_id, parent_id=parent_doc_id, user=metadata.user, @@ -469,7 +353,7 @@ class Processor(FlowProcessor): # Save to librarian if img_content: - await self.save_child_document( + await self.librarian.save_child_document( doc_id=img_uri, parent_id=parent_doc_id, user=metadata.user, @@ -518,13 +402,13 @@ class Processor(FlowProcessor): f"Fetching document {v.document_id} from librarian..." ) - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) mime_type = doc_meta.kind if doc_meta else None - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) From 62c30a3a50ab77294cecf3a3c91b208e0db711ed Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Apr 2026 13:20:39 +0100 Subject: [PATCH 27/37] Skip Pulsar check in tg-verify-system-status (#753) --- .../trustgraph/cli/verify_system_status.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index 5fea1bb0..fa01898b 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -377,24 +377,14 @@ def main(): print("=" * 60) print("TrustGraph System Status Verification") print("=" * 60) -# print(f"Global timeout: {args.global_timeout}s") -# print(f"Check timeout: {args.check_timeout}s") -# print(f"Retry delay: {args.retry_delay}s") -# print("=" * 60) print() # Phase 1: Infrastructure print("Phase 1: Infrastructure") print("-" * 60) - if not checker.run_check( - "Pulsar", - check_pulsar, - args.pulsar_url, - args.check_timeout - ): - print("\n⚠️ Pulsar is not responding - other checks may fail") - print() + # Pulsar check is skipped — not all deployments use Pulsar. + # The API Gateway check covers broker connectivity indirectly. 
checker.run_check( "API Gateway", From d9dc4cbab5d2e7e32cc20e1bd68ab4179517a628 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Apr 2026 17:21:39 +0100 Subject: [PATCH 28/37] SPARQL query service (#754) SPARQL 1.1 query service wrapping pub/sub triples interface Add a backend-agnostic SPARQL query service that parses SPARQL queries using rdflib, decomposes them into triple pattern lookups via the existing TriplesClient pub/sub interface, and performs in-memory joins, filters, and projections. Includes: - SPARQL parser, algebra evaluator, expression evaluator, solution sequence operations (BGP, JOIN, OPTIONAL, UNION, FILTER, BIND, VALUES, GROUP BY, ORDER BY, LIMIT/OFFSET, DISTINCT, aggregates) - FlowProcessor service with TriplesClientSpec - Gateway dispatcher, request/response translators, API spec - Python SDK method (FlowInstance.sparql_query) - CLI command (tg-invoke-sparql-query) - Tech spec (docs/tech-specs/sparql-query.md) New unit tests for SPARQL query --- Makefile | 4 +- docs/tech-specs/sparql-query.md | 268 +++++++++ specs/api/paths/flow/sparql-query.yaml | 145 +++++ .../test_query/test_sparql_expressions.py | 424 ++++++++++++++ tests/unit/test_query/test_sparql_parser.py | 205 +++++++ .../unit/test_query/test_sparql_solutions.py | 345 +++++++++++ trustgraph-base/trustgraph/api/flow.py | 39 ++ .../trustgraph/messaging/__init__.py | 7 + .../messaging/translators/sparql_query.py | 111 ++++ .../trustgraph/schema/services/__init__.py | 3 +- .../schema/services/sparql_query.py | 40 ++ trustgraph-cli/pyproject.toml | 1 + .../trustgraph/cli/invoke_sparql_query.py | 230 ++++++++ trustgraph-flow/pyproject.toml | 1 + .../trustgraph/gateway/dispatch/manager.py | 2 + .../gateway/dispatch/sparql_query.py | 30 + .../trustgraph/query/sparql/__init__.py | 1 + .../trustgraph/query/sparql/__main__.py | 6 + .../trustgraph/query/sparql/algebra.py | 541 ++++++++++++++++++ .../trustgraph/query/sparql/expressions.py | 481 ++++++++++++++++ .../trustgraph/query/sparql/parser.py | 139 +++++ .../trustgraph/query/sparql/service.py | 230 ++++++++ .../trustgraph/query/sparql/solutions.py | 248 ++++++++ 23 files changed, 3498 insertions(+), 3 deletions(-) create mode 100644 docs/tech-specs/sparql-query.md create mode 100644 specs/api/paths/flow/sparql-query.yaml create mode 100644 tests/unit/test_query/test_sparql_expressions.py create mode 100644 tests/unit/test_query/test_sparql_parser.py create mode 100644 tests/unit/test_query/test_sparql_solutions.py create mode 100644 trustgraph-base/trustgraph/messaging/translators/sparql_query.py create mode 100644 trustgraph-base/trustgraph/schema/services/sparql_query.py create mode 100644 trustgraph-cli/trustgraph/cli/invoke_sparql_query.py create mode 100644 trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/__init__.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/__main__.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/algebra.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/expressions.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/parser.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/service.py create mode 100644 trustgraph-flow/trustgraph/query/sparql/solutions.py diff --git a/Makefile b/Makefile index 4d79f554..197a6c63 100644 --- a/Makefile +++ b/Makefile @@ -77,8 +77,8 @@ some-containers: -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . 
${DOCKER} build -f containers/Containerfile.flow \ -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . - ${DOCKER} build -f containers/Containerfile.unstructured \ - -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . +# ${DOCKER} build -f containers/Containerfile.unstructured \ +# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . # ${DOCKER} build -f containers/Containerfile.vertexai \ # -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . # ${DOCKER} build -f containers/Containerfile.mcp \ diff --git a/docs/tech-specs/sparql-query.md b/docs/tech-specs/sparql-query.md new file mode 100644 index 00000000..97e7e115 --- /dev/null +++ b/docs/tech-specs/sparql-query.md @@ -0,0 +1,268 @@ +# SPARQL Query Service Technical Specification + +## Overview + +A pub/sub-hosted SPARQL query service that accepts SPARQL queries, decomposes +them into triple pattern lookups via the existing triples query pub/sub +interface, performs in-memory joins/filters/projections, and returns SPARQL +result bindings. + +This makes the triple store queryable using a standard graph query language +without coupling to any specific backend (Neo4j, Cassandra, FalkorDB, etc.). + +## Goals + +- **SPARQL 1.1 support**: SELECT, ASK, CONSTRUCT, DESCRIBE queries +- **Backend-agnostic**: query via the pub/sub triples interface, not direct + database access +- **Standard service pattern**: FlowProcessor with ConsumerSpec/ProducerSpec, + using TriplesClientSpec to call the triples query service +- **Correct SPARQL semantics**: proper BGP evaluation, joins, OPTIONAL, UNION, + FILTER, BIND, aggregation, solution modifiers (ORDER BY, LIMIT, OFFSET, + DISTINCT) + +## Background + +The triples query service provides a single-pattern lookup: given optional +(s, p, o) values, return matching triples. This is the equivalent of one +triple pattern in a SPARQL Basic Graph Pattern. + +To evaluate a full SPARQL query, we need to: +1. Parse the SPARQL string into an algebra tree +2. Walk the algebra tree, issuing triple pattern lookups for each BGP pattern +3. Join results across patterns (nested-loop or hash join) +4. Apply filters, optionals, unions, and aggregations in-memory +5. Project and return the requested variables + +rdflib (already a dependency) provides a SPARQL 1.1 parser and algebra +compiler. We use rdflib to parse queries into algebra trees, then evaluate +the algebra ourselves using the triples query client as the data source. + +## Technical Design + +### Architecture + +``` + pub/sub + [Client] ──request──> [SPARQL Query Service] ──triples-request──> [Triples Query Service] + [Client] <─response── [SPARQL Query Service] <─triples-response── [Triples Query Service] +``` + +The service is a FlowProcessor that: +- Consumes SPARQL query requests +- Uses TriplesClientSpec to issue triple pattern lookups +- Evaluates the SPARQL algebra in-memory +- Produces result responses + +### Components + +1. **SPARQL Query Service (FlowProcessor)** + - ConsumerSpec for incoming SPARQL requests + - ProducerSpec for outgoing results + - TriplesClientSpec for calling the triples query service + - Delegates parsing and evaluation to the components below + + Module: `trustgraph-flow/trustgraph/query/sparql/service.py` + +2. 
**SPARQL Parser (rdflib wrapper)** + - Uses `rdflib.plugins.sparql.prepareQuery` / `parseQuery` and + `rdflib.plugins.sparql.algebra.translateQuery` to produce an algebra tree + - Extracts PREFIX declarations, query type (SELECT/ASK/CONSTRUCT/DESCRIBE), + and the algebra root + + Module: `trustgraph-flow/trustgraph/query/sparql/parser.py` + +3. **Algebra Evaluator** + - Recursive evaluator over the rdflib algebra tree + - Each algebra node type maps to an evaluation function + - BGP nodes issue triple pattern queries via TriplesClient + - Join/Filter/Optional/Union etc. operate on in-memory solution sequences + + Module: `trustgraph-flow/trustgraph/query/sparql/algebra.py` + +4. **Solution Sequence** + - A solution is a dict mapping variable names to Term values + - Solution sequences are lists of solutions + - Join: hash join on shared variables + - LeftJoin (OPTIONAL): hash join preserving unmatched left rows + - Union: concatenation + - Filter: evaluate SPARQL expressions against each solution + - Projection/Distinct/Order/Slice: standard post-processing + + Module: `trustgraph-flow/trustgraph/query/sparql/solutions.py` + +### Data Models + +#### Request + +```python +@dataclass +class SparqlQueryRequest: + user: str = "" + collection: str = "" + query: str = "" # SPARQL query string + limit: int = 10000 # Safety limit on results +``` + +#### Response + +```python +@dataclass +class SparqlQueryResponse: + error: Error | None = None + query_type: str = "" # "select", "ask", "construct", "describe" + + # For SELECT queries + variables: list[str] = field(default_factory=list) + bindings: list[SparqlBinding] = field(default_factory=list) + + # For ASK queries + ask_result: bool = False + + # For CONSTRUCT/DESCRIBE queries + triples: list[Triple] = field(default_factory=list) + +@dataclass +class SparqlBinding: + values: list[Term | None] = field(default_factory=list) +``` + +### BGP Evaluation Strategy + +For each triple pattern in a BGP: +- Extract bound terms (concrete IRIs/literals) and variables +- Call `TriplesClient.query_stream(s, p, o)` with bound terms, None for + variables +- Map returned triples back to variable bindings + +For multi-pattern BGPs, join solutions incrementally: +- Order patterns by selectivity (patterns with more bound terms first) +- For each subsequent pattern, substitute bound variables from the current + solution sequence before querying +- This avoids full cross-products and reduces the number of triples queries + +### Streaming and Early Termination + +The triples query service supports streaming responses (batched delivery via +`TriplesClient.query_stream`). The SPARQL evaluator should use streaming +from the start, not as an optimisation. This is important because: + +- **Early termination**: when the SPARQL query has a LIMIT, or when only one + solution is needed (ASK queries), we can stop consuming triples as soon as + we have enough results. Without streaming, a wildcard pattern like + `?s ?p ?o` would fetch the entire graph before we could apply the limit. +- **Memory efficiency**: results are processed batch-by-batch rather than + materialising the full result set in memory before joining. + +The batch callback in `query_stream` returns a boolean to signal completion. +The evaluator should signal completion (return True) as soon as sufficient +solutions have been produced, allowing the underlying pub/sub consumer to +stop pulling batches. 
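+As an illustrative sketch of this loop (the exact `query_stream`
+signature and batch contents are assumptions here -- a batch callback
+returning a boolean, triples exposing `s`/`p`/`o` fields -- and the
+pattern representation is simplified), single-pattern evaluation with
+a result cap might look like:
+
+```python
+async def eval_pattern(triples, pattern, limit=10000):
+    """Evaluate one triple pattern, stopping once `limit` solutions exist.
+
+    `pattern` is an (s, p, o) tuple where each slot is either a bound
+    Term or a variable name string (illustrative convention only).
+    """
+    bound = [None if isinstance(t, str) else t for t in pattern]
+    solutions = []
+
+    async def on_batch(batch):
+        for triple in batch:
+            row = {}
+            for slot, value in zip(pattern, (triple.s, triple.p, triple.o)):
+                if isinstance(slot, str):   # variable slot: bind the value
+                    row[slot] = value
+            solutions.append(row)
+            if len(solutions) >= limit:
+                return True                 # enough solutions: stop streaming
+        return False                        # keep pulling batches
+
+    await triples.query_stream(
+        s=bound[0], p=bound[1], o=bound[2], on_batch=on_batch,
+    )
+    return solutions
+```
+
+The same callback-driven loop covers ASK evaluation, where the
+effective limit is a single solution.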
+ +### Parallel BGP Execution (Phase 2 Optimisation) + +Within a BGP, patterns that share variables benefit from sequential +evaluation with bound-variable substitution (query results from earlier +patterns narrow later queries). However, patterns with no shared variables +are independent and could be issued concurrently via `asyncio.gather`. + +A practical approach for a future optimisation pass: +- Analyse BGP patterns and identify connected components (groups of + patterns linked by shared variables) +- Execute independent components in parallel +- Within each component, evaluate patterns sequentially with substitution + +This is not needed for correctness -- the sequential approach works for all +cases -- but could significantly reduce latency for queries with independent +pattern groups. Flagged as a phase 2 optimisation. + +### FILTER Expression Evaluation + +rdflib's algebra represents FILTER expressions as expression trees. We +evaluate these against each solution row, supporting: +- Comparison operators (=, !=, <, >, <=, >=) +- Logical operators (&&, ||, !) +- SPARQL built-in functions (isIRI, isLiteral, isBlank, str, lang, + datatype, bound, regex, etc.) +- Arithmetic operators (+, -, *, /) + +## Implementation Order + +1. **Schema and service skeleton** -- define SparqlQueryRequest/Response + dataclasses, create the FlowProcessor subclass with ConsumerSpec, + ProducerSpec, and TriplesClientSpec wired up. Verify it starts and + connects. + +2. **SPARQL parsing** -- wrap rdflib's parser to produce algebra trees from + SPARQL strings. Handle parse errors gracefully. Unit test with a range of + query shapes. + +3. **BGP evaluation** -- implement single-pattern and multi-pattern BGP + evaluation using TriplesClient. This is the core building block. Test + with simple SELECT WHERE { ?s ?p ?o } queries. + +4. **Joins and solution sequences** -- implement hash join, left join (for + OPTIONAL), and union. Test with multi-pattern queries. + +5. **FILTER evaluation** -- implement the expression evaluator for FILTER + clauses. Start with comparisons and logical operators, then add built-in + functions incrementally. + +6. **Solution modifiers** -- DISTINCT, ORDER BY, LIMIT, OFFSET, projection. + +7. **ASK / CONSTRUCT / DESCRIBE** -- extend beyond SELECT. ASK is trivial + (non-empty result = true). CONSTRUCT builds triples from a template. + DESCRIBE fetches all triples for matched resources. + +8. **Aggregation** -- GROUP BY, HAVING, COUNT, SUM, AVG, MIN, MAX, + GROUP_CONCAT, SAMPLE. + +9. **BIND, VALUES, subqueries** -- remaining SPARQL 1.1 features. + +10. **API gateway integration** -- add SparqlQueryRequestor dispatcher, + request/response translators, and API endpoint so that the SPARQL + service is accessible via the HTTP gateway. + +11. **SDK support** -- add `sparql_query()` method to FlowInstance in the + Python API SDK, following the same pattern as `triples_query()`. + +12. **CLI command** -- add a `tg-sparql-query` CLI command that takes a + SPARQL query string (or reads from a file/stdin), submits it via the + SDK, and prints results in a readable format (table for SELECT, + true/false for ASK, Turtle for CONSTRUCT/DESCRIBE). + +## Performance Considerations + +In-memory join over pub/sub round-trips will be slower than native SPARQL on +a graph database. Key mitigations: + +- **Streaming with early termination**: use `query_stream` so that + limit-bound queries don't fetch entire result sets. A `SELECT ... 
LIMIT 1` + against a wildcard pattern fetches one batch, not the whole graph. +- **Bound-variable substitution**: when evaluating BGP patterns sequentially, + substitute known bindings into subsequent patterns to issue narrow queries + rather than broad ones followed by in-memory filtering. +- **Parallel independent patterns** (phase 2): patterns with no shared + variables can be issued concurrently. +- **Query complexity limits**: may need a cap on the number of triple pattern + queries issued per SPARQL query to prevent runaway evaluation. + +### Named Graph Mapping + +SPARQL's `GRAPH ?g { ... }` and `GRAPH { ... }` clauses map to the +triples query service's graph filter parameter: + +- `GRAPH { ?s ?p ?o }` — pass `g=uri` to the triples query +- Patterns outside any GRAPH clause — pass `g=""` (default graph only) +- `GRAPH ?g { ?s ?p ?o }` — pass `g="*"` (all graphs), then bind `?g` from + the returned triple's graph field + +The triples query interface does not support a wildcard graph natively in +the SPARQL sense, but `g="*"` (all graphs) combined with client-side +filtering on the returned graph values achieves the same effect. + +## Open Questions + +- **SPARQL 1.2**: rdflib's parser support for 1.2 features (property paths + are already in 1.1; 1.2 adds lateral joins, ADJUST, etc.). Start with + 1.1 and extend as rdflib support matures. diff --git a/specs/api/paths/flow/sparql-query.yaml b/specs/api/paths/flow/sparql-query.yaml new file mode 100644 index 00000000..2f343488 --- /dev/null +++ b/specs/api/paths/flow/sparql-query.yaml @@ -0,0 +1,145 @@ +post: + tags: + - Flow Services + summary: SPARQL query - execute SPARQL 1.1 queries against the knowledge graph + description: | + Execute a SPARQL 1.1 query against the knowledge graph. + + ## Supported Query Types + + - **SELECT**: Returns variable bindings as a table of results + - **ASK**: Returns true/false for existence checks + - **CONSTRUCT**: Returns a set of triples built from a template + - **DESCRIBE**: Returns triples describing matched resources + + ## SPARQL Features + + Supports standard SPARQL 1.1 features including: + - Basic Graph Patterns (BGPs) with triple pattern matching + - OPTIONAL, UNION, FILTER + - BIND, VALUES + - ORDER BY, LIMIT, OFFSET, DISTINCT + - GROUP BY with aggregates (COUNT, SUM, AVG, MIN, MAX, GROUP_CONCAT) + - Built-in functions (isIRI, STR, REGEX, CONTAINS, etc.) + + ## Query Examples + + Find all entities of a type: + ```sparql + SELECT ?s ?label WHERE { + ?s . + ?s ?label . 
+ } + LIMIT 10 + ``` + + Check if an entity exists: + ```sparql + ASK { ?p ?o } + ``` + + operationId: sparqlQueryService + security: + - bearerAuth: [] + parameters: + - name: flow + in: path + required: true + schema: + type: string + description: Flow instance ID + example: my-flow + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - query + properties: + query: + type: string + description: SPARQL 1.1 query string + user: + type: string + default: trustgraph + description: User/keyspace identifier + collection: + type: string + default: default + description: Collection identifier + limit: + type: integer + default: 10000 + description: Safety limit on number of results + examples: + selectQuery: + summary: SELECT query + value: + query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10" + user: trustgraph + collection: default + askQuery: + summary: ASK query + value: + query: "ASK { ?p ?o }" + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + properties: + query-type: + type: string + enum: [select, ask, construct, describe] + variables: + type: array + items: + type: string + description: Variable names (SELECT only) + bindings: + type: array + items: + type: object + properties: + values: + type: array + items: + $ref: '../../components/schemas/common/RdfValue.yaml' + description: Result rows (SELECT only) + ask-result: + type: boolean + description: Boolean result (ASK only) + triples: + type: array + description: Result triples (CONSTRUCT/DESCRIBE only) + error: + type: object + properties: + type: + type: string + message: + type: string + examples: + selectResult: + summary: SELECT result + value: + query-type: select + variables: [s, p, o] + bindings: + - values: + - {t: i, i: "http://example.com/alice"} + - {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"} + - {t: i, i: "http://example.com/Person"} + askResult: + summary: ASK result + value: + query-type: ask + ask-result: true + '401': + $ref: '../../components/responses/Unauthorized.yaml' + '500': + $ref: '../../components/responses/Error.yaml' diff --git a/tests/unit/test_query/test_sparql_expressions.py b/tests/unit/test_query/test_sparql_expressions.py new file mode 100644 index 00000000..63e9188f --- /dev/null +++ b/tests/unit/test_query/test_sparql_expressions.py @@ -0,0 +1,424 @@ +""" +Tests for SPARQL FILTER expression evaluator. 
+""" + +import pytest +from trustgraph.schema import Term, IRI, LITERAL, BLANK +from trustgraph.query.sparql.expressions import ( + evaluate_expression, _effective_boolean, _to_string, _to_numeric, + _comparable_value, +) + + +# --- Helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + +def lit(v, datatype="", language=""): + return Term(type=LITERAL, value=v, datatype=datatype, language=language) + +def blank(v): + return Term(type=BLANK, id=v) + +XSD = "http://www.w3.org/2001/XMLSchema#" + + +class TestEvaluateExpression: + """Test expression evaluation with rdflib algebra nodes.""" + + def test_variable_bound(self): + from rdflib.term import Variable + result = evaluate_expression(Variable("x"), {"x": lit("hello")}) + assert result.value == "hello" + + def test_variable_unbound(self): + from rdflib.term import Variable + result = evaluate_expression(Variable("x"), {}) + assert result is None + + def test_uriref_constant(self): + from rdflib import URIRef + result = evaluate_expression( + URIRef("http://example.com/a"), {} + ) + assert result.type == IRI + assert result.iri == "http://example.com/a" + + def test_literal_constant(self): + from rdflib import Literal + result = evaluate_expression(Literal("hello"), {}) + assert result.type == LITERAL + assert result.value == "hello" + + def test_boolean_constant(self): + assert evaluate_expression(True, {}) is True + assert evaluate_expression(False, {}) is False + + def test_numeric_constant(self): + assert evaluate_expression(42, {}) == 42 + assert evaluate_expression(3.14, {}) == 3.14 + + def test_none_returns_true(self): + assert evaluate_expression(None, {}) is True + + +class TestRelationalExpressions: + """Test comparison operators via CompValue nodes.""" + + def _make_relational(self, left, op, right): + from rdflib.plugins.sparql.parserutils import CompValue + return CompValue("RelationalExpression", + expr=left, op=op, other=right) + + def test_equal_literals(self): + from rdflib import Literal + expr = self._make_relational(Literal("a"), "=", Literal("a")) + assert evaluate_expression(expr, {}) is True + + def test_not_equal_literals(self): + from rdflib import Literal + expr = self._make_relational(Literal("a"), "!=", Literal("b")) + assert evaluate_expression(expr, {}) is True + + def test_less_than(self): + from rdflib import Literal + expr = self._make_relational(Literal("a"), "<", Literal("b")) + assert evaluate_expression(expr, {}) is True + + def test_greater_than(self): + from rdflib import Literal + expr = self._make_relational(Literal("b"), ">", Literal("a")) + assert evaluate_expression(expr, {}) is True + + def test_equal_with_variables(self): + from rdflib.term import Variable + expr = self._make_relational(Variable("x"), "=", Variable("y")) + sol = {"x": lit("same"), "y": lit("same")} + assert evaluate_expression(expr, sol) is True + + def test_unequal_with_variables(self): + from rdflib.term import Variable + expr = self._make_relational(Variable("x"), "=", Variable("y")) + sol = {"x": lit("one"), "y": lit("two")} + assert evaluate_expression(expr, sol) is False + + def test_none_operand_returns_false(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_relational(Variable("x"), "=", Literal("a")) + assert evaluate_expression(expr, {}) is False + + +class TestLogicalExpressions: + + def _make_and(self, exprs): + from rdflib.plugins.sparql.parserutils import CompValue + return CompValue("ConditionalAndExpression", + expr=exprs[0], other=exprs[1:]) + + def _make_or(self, 
exprs): + from rdflib.plugins.sparql.parserutils import CompValue + return CompValue("ConditionalOrExpression", + expr=exprs[0], other=exprs[1:]) + + def _make_not(self, expr): + from rdflib.plugins.sparql.parserutils import CompValue + return CompValue("UnaryNot", expr=expr) + + def test_and_true_true(self): + result = evaluate_expression(self._make_and([True, True]), {}) + assert result is True + + def test_and_true_false(self): + result = evaluate_expression(self._make_and([True, False]), {}) + assert result is False + + def test_or_false_true(self): + result = evaluate_expression(self._make_or([False, True]), {}) + assert result is True + + def test_or_false_false(self): + result = evaluate_expression(self._make_or([False, False]), {}) + assert result is False + + def test_not_true(self): + result = evaluate_expression(self._make_not(True), {}) + assert result is False + + def test_not_false(self): + result = evaluate_expression(self._make_not(False), {}) + assert result is True + + +class TestBuiltinFunctions: + + def _make_builtin(self, name, **kwargs): + from rdflib.plugins.sparql.parserutils import CompValue + return CompValue(f"Builtin_{name}", **kwargs) + + def test_bound_true(self): + from rdflib.term import Variable + expr = self._make_builtin("BOUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("hi")}) is True + + def test_bound_false(self): + from rdflib.term import Variable + expr = self._make_builtin("BOUND", arg=Variable("x")) + assert evaluate_expression(expr, {}) is False + + def test_isiri_true(self): + from rdflib.term import Variable + expr = self._make_builtin("isIRI", arg=Variable("x")) + assert evaluate_expression(expr, {"x": iri("http://x")}) is True + + def test_isiri_false(self): + from rdflib.term import Variable + expr = self._make_builtin("isIRI", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("hello")}) is False + + def test_isliteral_true(self): + from rdflib.term import Variable + expr = self._make_builtin("isLITERAL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("hello")}) is True + + def test_isliteral_false(self): + from rdflib.term import Variable + expr = self._make_builtin("isLITERAL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": iri("http://x")}) is False + + def test_isblank_true(self): + from rdflib.term import Variable + expr = self._make_builtin("isBLANK", arg=Variable("x")) + assert evaluate_expression(expr, {"x": blank("b1")}) is True + + def test_isblank_false(self): + from rdflib.term import Variable + expr = self._make_builtin("isBLANK", arg=Variable("x")) + assert evaluate_expression(expr, {"x": iri("http://x")}) is False + + def test_str(self): + from rdflib.term import Variable + expr = self._make_builtin("STR", arg=Variable("x")) + result = evaluate_expression(expr, {"x": iri("http://example.com/a")}) + assert result.type == LITERAL + assert result.value == "http://example.com/a" + + def test_lang(self): + from rdflib.term import Variable + expr = self._make_builtin("LANG", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("hello", language="en")} + ) + assert result.value == "en" + + def test_lang_no_tag(self): + from rdflib.term import Variable + expr = self._make_builtin("LANG", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_datatype(self): + from rdflib.term import Variable + expr = self._make_builtin("DATATYPE", arg=Variable("x")) + result = evaluate_expression( + expr, 
{"x": lit("42", datatype=XSD + "integer")} + ) + assert result.type == IRI + assert result.iri == XSD + "integer" + + def test_strlen(self): + from rdflib.term import Variable + expr = self._make_builtin("STRLEN", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result == 5 + + def test_ucase(self): + from rdflib.term import Variable + expr = self._make_builtin("UCASE", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "HELLO" + + def test_lcase(self): + from rdflib.term import Variable + expr = self._make_builtin("LCASE", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("HELLO")}) + assert result.value == "hello" + + def test_contains_true(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("CONTAINS", + arg1=Variable("x"), arg2=Literal("ell")) + assert evaluate_expression(expr, {"x": lit("hello")}) is True + + def test_contains_false(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("CONTAINS", + arg1=Variable("x"), arg2=Literal("xyz")) + assert evaluate_expression(expr, {"x": lit("hello")}) is False + + def test_strstarts_true(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRSTARTS", + arg1=Variable("x"), arg2=Literal("hel")) + assert evaluate_expression(expr, {"x": lit("hello")}) is True + + def test_strends_true(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRENDS", + arg1=Variable("x"), arg2=Literal("llo")) + assert evaluate_expression(expr, {"x": lit("hello")}) is True + + def test_regex_match(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REGEX", + text=Variable("x"), + pattern=Literal("^hel"), + flags=None) + assert evaluate_expression(expr, {"x": lit("hello")}) is True + + def test_regex_case_insensitive(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REGEX", + text=Variable("x"), + pattern=Literal("HELLO"), + flags=Literal("i")) + assert evaluate_expression(expr, {"x": lit("hello")}) is True + + def test_regex_no_match(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REGEX", + text=Variable("x"), + pattern=Literal("^world"), + flags=None) + assert evaluate_expression(expr, {"x": lit("hello")}) is False + + +class TestEffectiveBoolean: + + def test_true(self): + assert _effective_boolean(True) is True + + def test_false(self): + assert _effective_boolean(False) is False + + def test_none(self): + assert _effective_boolean(None) is False + + def test_nonzero_int(self): + assert _effective_boolean(42) is True + + def test_zero_int(self): + assert _effective_boolean(0) is False + + def test_nonempty_string(self): + assert _effective_boolean("hello") is True + + def test_empty_string(self): + assert _effective_boolean("") is False + + def test_iri_term(self): + assert _effective_boolean(iri("http://x")) is True + + def test_nonempty_literal(self): + assert _effective_boolean(lit("hello")) is True + + def test_empty_literal(self): + assert _effective_boolean(lit("")) is False + + def test_boolean_literal_true(self): + assert _effective_boolean( + lit("true", datatype=XSD + "boolean") + ) is True + + def test_boolean_literal_false(self): + assert _effective_boolean( + lit("false", datatype=XSD + "boolean") + ) is False + + def 
test_numeric_literal_nonzero(self): + assert _effective_boolean( + lit("42", datatype=XSD + "integer") + ) is True + + def test_numeric_literal_zero(self): + assert _effective_boolean( + lit("0", datatype=XSD + "integer") + ) is False + + +class TestToString: + + def test_none(self): + assert _to_string(None) == "" + + def test_string(self): + assert _to_string("hello") == "hello" + + def test_iri_term(self): + assert _to_string(iri("http://example.com")) == "http://example.com" + + def test_literal_term(self): + assert _to_string(lit("hello")) == "hello" + + def test_blank_term(self): + assert _to_string(blank("b1")) == "b1" + + +class TestToNumeric: + + def test_none(self): + assert _to_numeric(None) is None + + def test_int(self): + assert _to_numeric(42) == 42 + + def test_float(self): + assert _to_numeric(3.14) == 3.14 + + def test_integer_literal(self): + assert _to_numeric(lit("42")) == 42 + + def test_decimal_literal(self): + assert _to_numeric(lit("3.14")) == 3.14 + + def test_non_numeric_literal(self): + assert _to_numeric(lit("hello")) is None + + def test_numeric_string(self): + assert _to_numeric("42") == 42 + + def test_non_numeric_string(self): + assert _to_numeric("abc") is None + + +class TestComparableValue: + + def test_none(self): + assert _comparable_value(None) == (0, "") + + def test_int(self): + assert _comparable_value(42) == (2, 42) + + def test_iri(self): + assert _comparable_value(iri("http://x")) == (4, "http://x") + + def test_literal(self): + assert _comparable_value(lit("hello")) == (3, "hello") + + def test_numeric_literal(self): + assert _comparable_value(lit("42")) == (2, 42) + + def test_ordering(self): + vals = [lit("b"), lit("a"), lit("c")] + sorted_vals = sorted(vals, key=_comparable_value) + assert sorted_vals[0].value == "a" + assert sorted_vals[1].value == "b" + assert sorted_vals[2].value == "c" diff --git a/tests/unit/test_query/test_sparql_parser.py b/tests/unit/test_query/test_sparql_parser.py new file mode 100644 index 00000000..5ac9fad9 --- /dev/null +++ b/tests/unit/test_query/test_sparql_parser.py @@ -0,0 +1,205 @@ +""" +Tests for the SPARQL parser module. 
+""" + +import pytest +from trustgraph.query.sparql.parser import ( + parse_sparql, ParseError, rdflib_term_to_term, term_to_rdflib, +) +from trustgraph.schema import Term, IRI, LITERAL, BLANK + + +class TestParseSparql: + """Tests for parse_sparql function.""" + + def test_select_query_type(self): + parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }") + assert parsed.query_type == "select" + + def test_select_variables(self): + parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }") + assert parsed.variables == ["s", "p", "o"] + + def test_select_subset_variables(self): + parsed = parse_sparql("SELECT ?s ?o WHERE { ?s ?p ?o }") + assert parsed.variables == ["s", "o"] + + def test_ask_query_type(self): + parsed = parse_sparql( + "ASK { ?p ?o }" + ) + assert parsed.query_type == "ask" + + def test_ask_no_variables(self): + parsed = parse_sparql( + "ASK { ?p ?o }" + ) + assert parsed.variables == [] + + def test_construct_query_type(self): + parsed = parse_sparql( + "CONSTRUCT { ?s ?o } " + "WHERE { ?s ?o }" + ) + assert parsed.query_type == "construct" + + def test_describe_query_type(self): + parsed = parse_sparql( + "DESCRIBE " + ) + assert parsed.query_type == "describe" + + def test_select_with_limit(self): + parsed = parse_sparql( + "SELECT ?s WHERE { ?s ?p ?o } LIMIT 10" + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s"] + + def test_select_with_distinct(self): + parsed = parse_sparql( + "SELECT DISTINCT ?s WHERE { ?s ?p ?o }" + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s"] + + def test_select_with_filter(self): + parsed = parse_sparql( + 'SELECT ?s ?label WHERE { ' + ' ?s ?label . ' + ' FILTER(CONTAINS(STR(?label), "test")) ' + '}' + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s", "label"] + + def test_select_with_optional(self): + parsed = parse_sparql( + "SELECT ?s ?p ?o ?label WHERE { " + " ?s ?p ?o . 
" + " OPTIONAL { ?s ?label } " + "}" + ) + assert parsed.query_type == "select" + assert set(parsed.variables) == {"s", "p", "o", "label"} + + def test_select_with_union(self): + parsed = parse_sparql( + "SELECT ?s ?label WHERE { " + " { ?s ?label } " + " UNION " + " { ?s ?label } " + "}" + ) + assert parsed.query_type == "select" + + def test_select_with_order_by(self): + parsed = parse_sparql( + "SELECT ?s ?label WHERE { ?s ?label } " + "ORDER BY ?label" + ) + assert parsed.query_type == "select" + + def test_select_with_group_by(self): + parsed = parse_sparql( + "SELECT ?p (COUNT(?o) AS ?count) WHERE { ?s ?p ?o } " + "GROUP BY ?p ORDER BY DESC(?count)" + ) + assert parsed.query_type == "select" + + def test_select_with_prefixes(self): + parsed = parse_sparql( + "PREFIX rdfs: " + "SELECT ?s ?label WHERE { ?s rdfs:label ?label }" + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s", "label"] + + def test_algebra_not_none(self): + parsed = parse_sparql("SELECT ?s WHERE { ?s ?p ?o }") + assert parsed.algebra is not None + + def test_parse_error_invalid_sparql(self): + with pytest.raises(ParseError): + parse_sparql("NOT VALID SPARQL AT ALL") + + def test_parse_error_incomplete_query(self): + with pytest.raises(ParseError): + parse_sparql("SELECT ?s WHERE {") + + def test_parse_error_message(self): + with pytest.raises(ParseError, match="SPARQL parse error"): + parse_sparql("GIBBERISH") + + +class TestRdflibTermToTerm: + """Tests for rdflib-to-Term conversion.""" + + def test_uriref_to_term(self): + from rdflib import URIRef + term = rdflib_term_to_term(URIRef("http://example.com/alice")) + assert term.type == IRI + assert term.iri == "http://example.com/alice" + + def test_literal_to_term(self): + from rdflib import Literal + term = rdflib_term_to_term(Literal("hello")) + assert term.type == LITERAL + assert term.value == "hello" + + def test_typed_literal_to_term(self): + from rdflib import Literal, URIRef + term = rdflib_term_to_term( + Literal("42", datatype=URIRef("http://www.w3.org/2001/XMLSchema#integer")) + ) + assert term.type == LITERAL + assert term.value == "42" + assert term.datatype == "http://www.w3.org/2001/XMLSchema#integer" + + def test_lang_literal_to_term(self): + from rdflib import Literal + term = rdflib_term_to_term(Literal("hello", lang="en")) + assert term.type == LITERAL + assert term.value == "hello" + assert term.language == "en" + + def test_bnode_to_term(self): + from rdflib import BNode + term = rdflib_term_to_term(BNode("b1")) + assert term.type == BLANK + assert term.id == "b1" + + +class TestTermToRdflib: + """Tests for Term-to-rdflib conversion.""" + + def test_iri_term_to_uriref(self): + from rdflib import URIRef + result = term_to_rdflib(Term(type=IRI, iri="http://example.com/x")) + assert isinstance(result, URIRef) + assert str(result) == "http://example.com/x" + + def test_literal_term_to_literal(self): + from rdflib import Literal + result = term_to_rdflib(Term(type=LITERAL, value="hello")) + assert isinstance(result, Literal) + assert str(result) == "hello" + + def test_typed_literal_roundtrip(self): + from rdflib import URIRef + original = Term( + type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer" + ) + rdflib_term = term_to_rdflib(original) + assert rdflib_term.datatype == URIRef("http://www.w3.org/2001/XMLSchema#integer") + + def test_lang_literal_roundtrip(self): + original = Term(type=LITERAL, value="bonjour", language="fr") + rdflib_term = term_to_rdflib(original) + assert rdflib_term.language 
== "fr" + + def test_blank_term_to_bnode(self): + from rdflib import BNode + result = term_to_rdflib(Term(type=BLANK, id="b1")) + assert isinstance(result, BNode) diff --git a/tests/unit/test_query/test_sparql_solutions.py b/tests/unit/test_query/test_sparql_solutions.py new file mode 100644 index 00000000..5805ca84 --- /dev/null +++ b/tests/unit/test_query/test_sparql_solutions.py @@ -0,0 +1,345 @@ +""" +Tests for SPARQL solution sequence operations. +""" + +import pytest +from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.query.sparql.solutions import ( + hash_join, left_join, union, project, distinct, + order_by, slice_solutions, _terms_equal, _compatible, +) + + +# --- Test helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + +def lit(v): + return Term(type=LITERAL, value=v) + + +# --- Fixtures --- + +@pytest.fixture +def alice(): + return iri("http://example.com/alice") + +@pytest.fixture +def bob(): + return iri("http://example.com/bob") + +@pytest.fixture +def carol(): + return iri("http://example.com/carol") + +@pytest.fixture +def knows(): + return iri("http://example.com/knows") + +@pytest.fixture +def name_alice(): + return lit("Alice") + +@pytest.fixture +def name_bob(): + return lit("Bob") + + +class TestTermsEqual: + + def test_equal_iris(self): + assert _terms_equal(iri("http://x.com/a"), iri("http://x.com/a")) + + def test_unequal_iris(self): + assert not _terms_equal(iri("http://x.com/a"), iri("http://x.com/b")) + + def test_equal_literals(self): + assert _terms_equal(lit("hello"), lit("hello")) + + def test_unequal_literals(self): + assert not _terms_equal(lit("hello"), lit("world")) + + def test_iri_vs_literal(self): + assert not _terms_equal(iri("hello"), lit("hello")) + + def test_none_none(self): + assert _terms_equal(None, None) + + def test_none_vs_term(self): + assert not _terms_equal(None, iri("http://x.com/a")) + + +class TestCompatible: + + def test_no_shared_variables(self): + assert _compatible({"a": iri("http://x")}, {"b": iri("http://y")}) + + def test_shared_variable_same_value(self, alice): + assert _compatible({"s": alice, "x": lit("1")}, {"s": alice, "y": lit("2")}) + + def test_shared_variable_different_value(self, alice, bob): + assert not _compatible({"s": alice}, {"s": bob}) + + def test_empty_solutions(self): + assert _compatible({}, {}) + + def test_empty_vs_nonempty(self, alice): + assert _compatible({}, {"s": alice}) + + +class TestHashJoin: + + def test_join_on_shared_variable(self, alice, bob, name_alice, name_bob): + left = [ + {"s": alice, "p": iri("http://example.com/knows"), "o": bob}, + {"s": bob, "p": iri("http://example.com/knows"), "o": alice}, + ] + right = [ + {"s": alice, "label": name_alice}, + {"s": bob, "label": name_bob}, + ] + result = hash_join(left, right) + assert len(result) == 2 + # Check that joined solutions have all variables + for sol in result: + assert "s" in sol + assert "p" in sol + assert "o" in sol + assert "label" in sol + + def test_join_no_shared_variables_cross_product(self, alice, bob): + left = [{"a": alice}] + right = [{"b": bob}, {"b": alice}] + result = hash_join(left, right) + assert len(result) == 2 + + def test_join_no_matches(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob}] + result = hash_join(left, right) + assert len(result) == 0 + + def test_join_empty_left(self, alice): + result = hash_join([], [{"s": alice}]) + assert len(result) == 0 + + def test_join_empty_right(self, alice): + result = hash_join([{"s": alice}], []) + assert len(result) == 0 + + def 
test_join_multiple_matches(self, alice, name_alice): + left = [ + {"s": alice, "p": iri("http://e.com/a")}, + {"s": alice, "p": iri("http://e.com/b")}, + ] + right = [{"s": alice, "label": name_alice}] + result = hash_join(left, right) + assert len(result) == 2 + + def test_join_preserves_values(self, alice, name_alice): + left = [{"s": alice, "x": lit("1")}] + right = [{"s": alice, "y": lit("2")}] + result = hash_join(left, right) + assert len(result) == 1 + assert result[0]["x"].value == "1" + assert result[0]["y"].value == "2" + + +class TestLeftJoin: + + def test_left_join_with_matches(self, alice, bob, name_alice): + left = [{"s": alice}, {"s": bob}] + right = [{"s": alice, "label": name_alice}] + result = left_join(left, right) + assert len(result) == 2 + # Alice has label + alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"] + assert len(alice_sols) == 1 + assert "label" in alice_sols[0] + # Bob preserved without label + bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"] + assert len(bob_sols) == 1 + assert "label" not in bob_sols[0] + + def test_left_join_no_matches(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob, "label": lit("Bob")}] + result = left_join(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/alice" + assert "label" not in result[0] + + def test_left_join_empty_right(self, alice): + left = [{"s": alice}] + result = left_join(left, []) + assert len(result) == 1 + + def test_left_join_empty_left(self): + result = left_join([], [{"s": iri("http://x")}]) + assert len(result) == 0 + + def test_left_join_with_filter(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + right = [ + {"s": alice, "val": lit("yes")}, + {"s": bob, "val": lit("no")}, + ] + # Filter: only keep joins where val == "yes" + result = left_join( + left, right, + filter_fn=lambda sol: sol.get("val") and sol["val"].value == "yes" + ) + assert len(result) == 2 + # Alice matches filter + alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"] + assert "val" in alice_sols[0] + assert alice_sols[0]["val"].value == "yes" + # Bob doesn't match filter, preserved without val + bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"] + assert "val" not in bob_sols[0] + + +class TestUnion: + + def test_union_concatenates(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob}] + result = union(left, right) + assert len(result) == 2 + + def test_union_preserves_order(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob}] + result = union(left, right) + assert result[0]["s"].iri == "http://example.com/alice" + assert result[1]["s"].iri == "http://example.com/bob" + + def test_union_empty_left(self, alice): + result = union([], [{"s": alice}]) + assert len(result) == 1 + + def test_union_both_empty(self): + result = union([], []) + assert len(result) == 0 + + def test_union_allows_duplicates(self, alice): + result = union([{"s": alice}], [{"s": alice}]) + assert len(result) == 2 + + +class TestProject: + + def test_project_keeps_selected(self, alice, name_alice): + solutions = [{"s": alice, "label": name_alice, "extra": lit("x")}] + result = project(solutions, ["s", "label"]) + assert len(result) == 1 + assert "s" in result[0] + assert "label" in result[0] + assert "extra" not in result[0] + + def test_project_missing_variable(self, alice): + solutions = [{"s": alice}] + result = project(solutions, ["s", "missing"]) + assert len(result) == 1 + assert "s" in 
result[0] + assert "missing" not in result[0] + + def test_project_empty(self): + result = project([], ["s"]) + assert len(result) == 0 + + +class TestDistinct: + + def test_removes_duplicates(self, alice): + solutions = [{"s": alice}, {"s": alice}, {"s": alice}] + result = distinct(solutions) + assert len(result) == 1 + + def test_keeps_different(self, alice, bob): + solutions = [{"s": alice}, {"s": bob}] + result = distinct(solutions) + assert len(result) == 2 + + def test_empty(self): + result = distinct([]) + assert len(result) == 0 + + def test_multi_variable_distinct(self, alice, bob): + solutions = [ + {"s": alice, "o": bob}, + {"s": alice, "o": bob}, + {"s": alice, "o": alice}, + ] + result = distinct(solutions) + assert len(result) == 2 + + +class TestOrderBy: + + def test_order_by_ascending(self): + solutions = [ + {"label": lit("Charlie")}, + {"label": lit("Alice")}, + {"label": lit("Bob")}, + ] + key_fns = [(lambda sol: sol.get("label"), True)] + result = order_by(solutions, key_fns) + assert result[0]["label"].value == "Alice" + assert result[1]["label"].value == "Bob" + assert result[2]["label"].value == "Charlie" + + def test_order_by_descending(self): + solutions = [ + {"label": lit("Alice")}, + {"label": lit("Charlie")}, + {"label": lit("Bob")}, + ] + key_fns = [(lambda sol: sol.get("label"), False)] + result = order_by(solutions, key_fns) + assert result[0]["label"].value == "Charlie" + assert result[1]["label"].value == "Bob" + assert result[2]["label"].value == "Alice" + + def test_order_by_empty(self): + result = order_by([], [(lambda sol: sol.get("x"), True)]) + assert len(result) == 0 + + def test_order_by_no_keys(self, alice): + solutions = [{"s": alice}] + result = order_by(solutions, []) + assert len(result) == 1 + + +class TestSlice: + + def test_limit(self, alice, bob, carol): + solutions = [{"s": alice}, {"s": bob}, {"s": carol}] + result = slice_solutions(solutions, limit=2) + assert len(result) == 2 + + def test_offset(self, alice, bob, carol): + solutions = [{"s": alice}, {"s": bob}, {"s": carol}] + result = slice_solutions(solutions, offset=1) + assert len(result) == 2 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_offset_and_limit(self, alice, bob, carol): + solutions = [{"s": alice}, {"s": bob}, {"s": carol}] + result = slice_solutions(solutions, offset=1, limit=1) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_limit_zero(self, alice): + result = slice_solutions([{"s": alice}], limit=0) + assert len(result) == 0 + + def test_offset_beyond_length(self, alice): + result = slice_solutions([{"s": alice}], offset=10) + assert len(result) == 0 + + def test_no_slice(self, alice, bob): + solutions = [{"s": alice}, {"s": bob}] + result = slice_solutions(solutions) + assert len(result) == 2 diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index d89e16f6..0aa55347 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -1122,6 +1122,45 @@ class FlowInstance: return result + def sparql_query( + self, query, user="trustgraph", collection="default", + limit=10000 + ): + """ + Execute a SPARQL query against the knowledge graph. + + Args: + query: SPARQL 1.1 query string + user: User/keyspace identifier (default: "trustgraph") + collection: Collection identifier (default: "default") + limit: Safety limit on results (default: 10000) + + Returns: + dict with query results. 
Structure depends on query type: + - SELECT: {"query-type": "select", "variables": [...], "bindings": [...]} + - ASK: {"query-type": "ask", "ask-result": bool} + - CONSTRUCT/DESCRIBE: {"query-type": "construct", "triples": [...]} + + Raises: + ProtocolException: If an error occurs + """ + + input = { + "query": query, + "user": user, + "collection": collection, + "limit": limit, + } + + response = self.request("service/sparql", input) + + if "error" in response and response["error"]: + error_type = response["error"].get("type", "unknown") + error_message = response["error"].get("message", "Unknown error") + raise ProtocolException(f"{error_type}: {error_message}") + + return response + def nlp_query(self, question, max_results=100): """ Convert a natural language question to a GraphQL query. diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 9fbcbf16..30f5061c 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -27,6 +27,7 @@ from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, Q from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator +from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator # Register all service translators TranslatorRegistry.register_service( @@ -149,6 +150,12 @@ TranslatorRegistry.register_service( CollectionManagementResponseTranslator() ) +TranslatorRegistry.register_service( + "sparql-query", + SparqlQueryRequestTranslator(), + SparqlQueryResponseTranslator() +) + # Register single-direction translators for document loading TranslatorRegistry.register_request("document", DocumentTranslator()) TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py new file mode 100644 index 00000000..d1912429 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py @@ -0,0 +1,111 @@ +from typing import Dict, Any, Tuple +from ...schema import ( + SparqlQueryRequest, SparqlQueryResponse, SparqlBinding, + Error, Term, Triple, IRI, LITERAL, BLANK, +) +from .base import MessageTranslator +from .primitives import TermTranslator, TripleTranslator + + +class SparqlQueryRequestTranslator(MessageTranslator): + """Translator for SparqlQueryRequest schema objects.""" + + def decode(self, data: Dict[str, Any]) -> SparqlQueryRequest: + return SparqlQueryRequest( + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + query=data.get("query", ""), + limit=int(data.get("limit", 10000)), + ) + + def encode(self, obj: SparqlQueryRequest) -> Dict[str, Any]: + return { + "user": obj.user, + "collection": obj.collection, + "query": obj.query, + "limit": obj.limit, + } + + +class SparqlQueryResponseTranslator(MessageTranslator): + """Translator for SparqlQueryResponse schema objects.""" + + def __init__(self): + self.term_translator = TermTranslator() + self.triple_translator = TripleTranslator() + + def decode(self, data: Dict[str, Any]) -> SparqlQueryResponse: + raise NotImplementedError( + 
"Response translation to schema not typically needed" + ) + + def _encode_term(self, v): + """Encode a Term, handling both Term objects and dicts from + pub/sub deserialization.""" + if v is None: + return None + if isinstance(v, dict): + # Reconstruct Term from dict (pub/sub deserializes nested + # dataclasses as dicts) + term = Term( + type=v.get("type", ""), + iri=v.get("iri", ""), + id=v.get("id", ""), + value=v.get("value", ""), + datatype=v.get("datatype", ""), + language=v.get("language", ""), + ) + return self.term_translator.encode(term) + return self.term_translator.encode(v) + + def _encode_error(self, error): + """Encode an Error, handling both Error objects and dicts.""" + if isinstance(error, dict): + return { + "type": error.get("type", ""), + "message": error.get("message", ""), + } + return { + "type": error.type, + "message": error.message, + } + + def encode(self, obj: SparqlQueryResponse) -> Dict[str, Any]: + result = { + "query-type": obj.query_type, + } + + if obj.error: + result["error"] = self._encode_error(obj.error) + + if obj.query_type == "select": + result["variables"] = obj.variables + bindings = [] + for binding in obj.bindings: + # binding may be a SparqlBinding or a dict + if isinstance(binding, dict): + values = binding.get("values", []) + else: + values = binding.values + bindings.append({ + "values": [ + self._encode_term(v) for v in values + ] + }) + result["bindings"] = bindings + + elif obj.query_type == "ask": + result["ask-result"] = obj.ask_result + + elif obj.query_type in ("construct", "describe"): + result["triples"] = [ + self.triple_translator.encode(t) + for t in obj.triples + ] + + return result + + def encode_with_completion( + self, obj: SparqlQueryResponse + ) -> Tuple[Dict[str, Any], bool]: + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index f246bc31..550b7d12 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -13,4 +13,5 @@ from .rows_query import * from .diagnosis import * from .collection import * from .storage import * -from .tool_service import * \ No newline at end of file +from .tool_service import * +from .sparql_query import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/sparql_query.py b/trustgraph-base/trustgraph/schema/services/sparql_query.py new file mode 100644 index 00000000..105cc753 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/sparql_query.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, field + +from ..core.primitives import Error, Term, Triple +from ..core.topic import queue + +############################################################################ + +# SPARQL query + +@dataclass +class SparqlBinding: + """A single row of SPARQL SELECT results. + Values are ordered to match the variables list in SparqlQueryResponse. 
+ """ + values: list[Term | None] = field(default_factory=list) + +@dataclass +class SparqlQueryRequest: + user: str = "" + collection: str = "" + query: str = "" # SPARQL query string + limit: int = 10000 # Safety limit on results + +@dataclass +class SparqlQueryResponse: + error: Error | None = None + query_type: str = "" # "select", "ask", "construct", "describe" + + # For SELECT queries + variables: list[str] = field(default_factory=list) + bindings: list[SparqlBinding] = field(default_factory=list) + + # For ASK queries + ask_result: bool = False + + # For CONSTRUCT/DESCRIBE queries + triples: list[Triple] = field(default_factory=list) + +sparql_query_request_queue = queue('sparql-query', cls='request') +sparql_query_response_queue = queue('sparql-query', cls='response') diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 9fd6bed7..2b111cae 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -51,6 +51,7 @@ tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main" tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main" tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main" tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" +tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main" tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py new file mode 100644 index 00000000..9547193d --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -0,0 +1,230 @@ +""" +Execute a SPARQL query against the TrustGraph knowledge graph. +""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = 'trustgraph' +default_collection = 'default' + + +def format_select(response, output_format): + """Format SELECT query results.""" + variables = response.get("variables", []) + bindings = response.get("bindings", []) + + if not bindings: + return "No results." 
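+
+    # Wire-format note (as assumed by the checks below): each value is a
+    # dict where "t" marks the term type ("i" = IRI, "l" = literal),
+    # "i" carries the IRI string, and "v", "d" and "l" carry a literal's
+    # value, datatype and language tag, e.g.
+    # {"t": "i", "i": "http://example.com/x"}.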
+ + if output_format == "json": + rows = [] + for binding in bindings: + row = {} + for var, val in zip(variables, binding.get("values", [])): + if val is None: + row[var] = None + elif val.get("t") == "i": + row[var] = val.get("i", "") + elif val.get("t") == "l": + row[var] = val.get("v", "") + else: + row[var] = val.get("v", val.get("i", "")) + rows.append(row) + return json.dumps(rows, indent=2) + + # Table format + col_widths = [len(v) for v in variables] + rows = [] + for binding in bindings: + row = [] + for i, val in enumerate(binding.get("values", [])): + if val is None: + cell = "" + elif val.get("t") == "i": + cell = val.get("i", "") + elif val.get("t") == "l": + cell = val.get("v", "") + else: + cell = val.get("v", val.get("i", "")) + row.append(cell) + if i < len(col_widths): + col_widths[i] = max(col_widths[i], len(cell)) + rows.append(row) + + # Build table + header = " | ".join( + v.ljust(col_widths[i]) for i, v in enumerate(variables) + ) + separator = "-+-".join("-" * w for w in col_widths) + lines = [header, separator] + for row in rows: + line = " | ".join( + cell.ljust(col_widths[i]) if i < len(col_widths) else cell + for i, cell in enumerate(row) + ) + lines.append(line) + return "\n".join(lines) + + +def format_triples(response, output_format): + """Format CONSTRUCT/DESCRIBE results.""" + triples = response.get("triples", []) + + if not triples: + return "No triples." + + if output_format == "json": + return json.dumps(triples, indent=2) + + lines = [] + for t in triples: + s = _term_str(t.get("s")) + p = _term_str(t.get("p")) + o = _term_str(t.get("o")) + lines.append(f"{s} {p} {o} .") + return "\n".join(lines) + + +def _term_str(val): + """Convert a wire-format term to a display string.""" + if val is None: + return "?" + t = val.get("t", "") + if t == "i": + return f"<{val.get('i', '')}>" + elif t == "l": + v = val.get("v", "") + dt = val.get("d", "") + lang = val.get("l", "") + if lang: + return f'"{v}"@{lang}' + elif dt: + return f'"{v}"^^<{dt}>' + return f'"{v}"' + return str(val) + + +def sparql_query(url, token, flow_id, query, user, collection, limit, + output_format): + + api = Api(url=url, token=token).flow().id(flow_id) + + resp = api.sparql_query( + query=query, + user=user, + collection=collection, + limit=limit, + ) + + query_type = resp.get("query-type", "select") + + if query_type == "select": + print(format_select(resp, output_format)) + elif query_type == "ask": + print("true" if resp.get("ask-result") else "false") + elif query_type in ("construct", "describe"): + print(format_triples(resp, output_format)) + else: + print(json.dumps(resp, indent=2)) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-sparql-query', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=os.getenv("TRUSTGRAPH_TOKEN"), + help='API bearer token (default: TRUSTGRAPH_TOKEN env var)', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help='Flow ID (default: default)', + ) + + parser.add_argument( + '-q', '--query', + help='SPARQL query string', + ) + + parser.add_argument( + '-i', '--input', + help='Read SPARQL query from file (use - for stdin)', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})', + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection ID (default: 
{default_collection})', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10000, + help='Result limit (default: 10000)', + ) + + parser.add_argument( + '--format', + choices=['table', 'json'], + default='table', + help='Output format (default: table)', + ) + + args = parser.parse_args() + + # Get query from argument or file + query = args.query + if not query and args.input: + if args.input == '-': + query = sys.stdin.read() + else: + with open(args.input) as f: + query = f.read() + + if not query: + parser.error("Either -q/--query or -i/--input is required") + + try: + + sparql_query( + url=args.url, + token=args.token, + flow_id=args.flow_id, + query=query, + user=args.user, + collection=args.collection, + limit=args.limit, + output_format=args.format, + ) + + except Exception as e: + print(f"Exception: {e}", flush=True, file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 66363305..b2df4a4c 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -101,6 +101,7 @@ pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run" prompt-template = "trustgraph.prompt.template:run" rev-gateway = "trustgraph.rev_gateway:run" run-processing = "trustgraph.processing:run" +sparql-query = "trustgraph.query.sparql:run" structured-query = "trustgraph.retrieval.structured_query:run" structured-diag = "trustgraph.retrieval.structured_diag:run" text-completion-azure = "trustgraph.model.text_completion.azure:run" diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index d068ecef..a4bf8de9 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -22,6 +22,7 @@ from . document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor from . rows_query import RowsQueryRequestor from . nlp_query import NLPQueryRequestor +from . sparql_query import SparqlQueryRequestor from . structured_query import StructuredQueryRequestor from . structured_diag import StructuredDiagRequestor from . embeddings import EmbeddingsRequestor @@ -65,6 +66,7 @@ request_response_dispatchers = { "structured-query": StructuredQueryRequestor, "structured-diag": StructuredDiagRequestor, "row-embeddings": RowEmbeddingsQueryRequestor, + "sparql": SparqlQueryRequestor, } global_dispatchers = { diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py new file mode 100644 index 00000000..f81b9df6 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py @@ -0,0 +1,30 @@ +from ... schema import SparqlQueryRequest, SparqlQueryResponse +from ... messaging import TranslatorRegistry + +from . 
requestor import ServiceRequestor + +class SparqlQueryRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(SparqlQueryRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=SparqlQueryRequest, + response_schema=SparqlQueryResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("sparql-query") + self.response_translator = TranslatorRegistry.get_response_translator("sparql-query") + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/query/sparql/__init__.py b/trustgraph-flow/trustgraph/query/sparql/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/query/sparql/__main__.py b/trustgraph-flow/trustgraph/query/sparql/__main__.py new file mode 100644 index 00000000..da5a9021 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py new file mode 100644 index 00000000..eda83efb --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -0,0 +1,541 @@ +""" +SPARQL algebra evaluator. + +Recursively evaluates an rdflib SPARQL algebra tree by issuing triple +pattern queries via TriplesClient (streaming) and performing in-memory +joins, filters, and projections. +""" + +import logging +from collections import defaultdict + +from rdflib.term import Variable, URIRef, Literal, BNode +from rdflib.plugins.sparql.parserutils import CompValue + +from ... schema import Term, Triple, IRI, LITERAL, BLANK +from ... knowledge import Uri +from ... knowledge import Literal as KgLiteral +from . parser import rdflib_term_to_term +from . solutions import ( + hash_join, left_join, union, project, distinct, + order_by, slice_solutions, _term_key, +) +from . expressions import evaluate_expression, _effective_boolean + +logger = logging.getLogger(__name__) + + +class EvaluationError(Exception): + """Raised when SPARQL evaluation fails.""" + pass + + +async def evaluate(node, triples_client, user, collection, limit=10000): + """ + Evaluate a SPARQL algebra node. 
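+
+    For example, "SELECT ?s WHERE { ?s ?p ?o }" compiles to a Project
+    node wrapping a BGP node: the BGP handler issues the triple pattern
+    query and Project then restricts each solution to ?s.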
+ + Args: + node: rdflib CompValue algebra node + triples_client: TriplesClient instance for triple pattern queries + user: user/keyspace identifier + collection: collection identifier + limit: safety limit on results + + Returns: + list of solutions (dicts mapping variable names to Term values) + """ + if not isinstance(node, CompValue): + logger.warning(f"Expected CompValue, got {type(node)}: {node}") + return [{}] + + name = node.name + handler = _HANDLERS.get(name) + + if handler is None: + logger.warning(f"Unsupported algebra node: {name}") + return [{}] + + return await handler(node, triples_client, user, collection, limit) + + +# --- Node handlers --- + +async def _eval_select_query(node, tc, user, collection, limit): + """Evaluate a SelectQuery node.""" + return await evaluate(node.p, tc, user, collection, limit) + + +async def _eval_project(node, tc, user, collection, limit): + """Evaluate a Project node (SELECT variable projection).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + variables = [str(v) for v in node.PV] + return project(solutions, variables) + + +async def _eval_bgp(node, tc, user, collection, limit): + """ + Evaluate a Basic Graph Pattern. + + Issues streaming triple pattern queries and joins results. Patterns + are ordered by selectivity (more bound terms first) and evaluated + sequentially with bound-variable substitution. + """ + triples = node.triples + if not triples: + return [{}] + + # Sort patterns by selectivity: more bound terms = more selective + def selectivity(pattern): + return sum(1 for t in pattern if not isinstance(t, Variable)) + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [{}] + + for _, pattern in sorted_patterns: + s_tmpl, p_tmpl, o_tmpl = pattern + + new_solutions = [] + + for sol in solutions: + # Substitute known bindings into the pattern + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + # Query the triples store + results = await _query_pattern( + tc, s_val, p_val, o_val, user, collection, limit + ) + + # Map results back to variable bindings, + # converting Uri/Literal to Term objects + for triple in results: + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + + if not solutions: + break + + return solutions[:limit] + + +async def _eval_join(node, tc, user, collection, limit): + """Evaluate a Join node.""" + left = await evaluate(node.p1, tc, user, collection, limit) + right = await evaluate(node.p2, tc, user, collection, limit) + return hash_join(left, right)[:limit] + + +async def _eval_left_join(node, tc, user, collection, limit): + """Evaluate a LeftJoin node (OPTIONAL).""" + left_sols = await evaluate(node.p1, tc, user, collection, limit) + right_sols = await evaluate(node.p2, tc, user, collection, limit) + + filter_fn = None + if hasattr(node, "expr") and node.expr is not None: + expr = node.expr + if not (isinstance(expr, CompValue) and expr.name == "TrueFilter"): + filter_fn = lambda sol: _effective_boolean( + evaluate_expression(expr, sol) + ) + + return left_join(left_sols, right_sols, filter_fn)[:limit] + + +async def _eval_union(node, tc, user, collection, limit): + """Evaluate a Union node.""" + left = await 
evaluate(node.p1, tc, user, collection, limit) + right = await evaluate(node.p2, tc, user, collection, limit) + return union(left, right)[:limit] + + +async def _eval_filter(node, tc, user, collection, limit): + """Evaluate a Filter node.""" + solutions = await evaluate(node.p, tc, user, collection, limit) + expr = node.expr + return [ + sol for sol in solutions + if _effective_boolean(evaluate_expression(expr, sol)) + ] + + +async def _eval_distinct(node, tc, user, collection, limit): + """Evaluate a Distinct node.""" + solutions = await evaluate(node.p, tc, user, collection, limit) + return distinct(solutions) + + +async def _eval_reduced(node, tc, user, collection, limit): + """Evaluate a Reduced node (like Distinct but implementation-defined).""" + # Treat same as Distinct + solutions = await evaluate(node.p, tc, user, collection, limit) + return distinct(solutions) + + +async def _eval_order_by(node, tc, user, collection, limit): + """Evaluate an OrderBy node.""" + solutions = await evaluate(node.p, tc, user, collection, limit) + + key_fns = [] + for cond in node.expr: + if isinstance(cond, CompValue) and cond.name == "OrderCondition": + ascending = cond.order != "DESC" + expr = cond.expr + key_fns.append(( + lambda sol, e=expr: evaluate_expression(e, sol), + ascending, + )) + else: + # Simple variable or expression + key_fns.append(( + lambda sol, e=cond: evaluate_expression(e, sol), + True, + )) + + return order_by(solutions, key_fns) + + +async def _eval_slice(node, tc, user, collection, limit): + """Evaluate a Slice node (LIMIT/OFFSET).""" + # Pass tighter limit downstream if possible + inner_limit = limit + if node.length is not None: + offset = node.start or 0 + inner_limit = min(limit, offset + node.length) + + solutions = await evaluate(node.p, tc, user, collection, inner_limit) + return slice_solutions(solutions, node.start or 0, node.length) + + +async def _eval_extend(node, tc, user, collection, limit): + """Evaluate an Extend node (BIND).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + var_name = str(node.var) + expr = node.expr + + result = [] + for sol in solutions: + val = evaluate_expression(expr, sol) + new_sol = dict(sol) + if isinstance(val, Term): + new_sol[var_name] = val + elif isinstance(val, (int, float)): + new_sol[var_name] = Term(type=LITERAL, value=str(val)) + elif isinstance(val, str): + new_sol[var_name] = Term(type=LITERAL, value=val) + elif isinstance(val, bool): + new_sol[var_name] = Term( + type=LITERAL, value=str(val).lower(), + datatype="http://www.w3.org/2001/XMLSchema#boolean" + ) + elif val is not None: + new_sol[var_name] = Term(type=LITERAL, value=str(val)) + result.append(new_sol) + + return result + + +async def _eval_group(node, tc, user, collection, limit): + """Evaluate a Group node (GROUP BY with aggregation).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + + # Extract grouping expressions + group_exprs = [] + if hasattr(node, "expr") and node.expr: + for expr in node.expr: + if isinstance(expr, CompValue) and expr.name == "GroupAs": + group_exprs.append((expr.expr, str(expr.var) if hasattr(expr, "var") and expr.var else None)) + elif isinstance(expr, Variable): + group_exprs.append((expr, str(expr))) + else: + group_exprs.append((expr, None)) + + # Group solutions + groups = defaultdict(list) + for sol in solutions: + key_parts = [] + for expr, _ in group_exprs: + val = evaluate_expression(expr, sol) + key_parts.append(_term_key(val) if isinstance(val, Term) else val) + 
groups[tuple(key_parts)].append(sol) + + if not group_exprs: + # No GROUP BY - entire result is one group + groups[()].extend(solutions) + + # Build grouped solutions (one per group) + result = [] + for key, group_sols in groups.items(): + sol = {} + # Include group key variables + if group_sols: + for (expr, var_name), k in zip(group_exprs, key): + if var_name and group_sols: + sol[var_name] = evaluate_expression(expr, group_sols[0]) + sol["__group__"] = group_sols + result.append(sol) + + return result + + +async def _eval_aggregate_join(node, tc, user, collection, limit): + """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + + result = [] + for sol in solutions: + group = sol.get("__group__", [sol]) + new_sol = {k: v for k, v in sol.items() if k != "__group__"} + + # Apply aggregate functions + if hasattr(node, "A") and node.A: + for agg in node.A: + var_name = str(agg.res) + agg_val = _compute_aggregate(agg, group) + new_sol[var_name] = agg_val + + result.append(new_sol) + + return result + + +async def _eval_graph(node, tc, user, collection, limit): + """Evaluate a Graph node (GRAPH clause).""" + term = node.term + + if isinstance(term, URIRef): + # GRAPH { ... } — fixed graph + # We'd need to pass graph to triples queries + # For now, evaluate inner pattern normally + logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") + return await evaluate(node.p, tc, user, collection, limit) + elif isinstance(term, Variable): + # GRAPH ?g { ... } — variable graph + logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") + return await evaluate(node.p, tc, user, collection, limit) + else: + return await evaluate(node.p, tc, user, collection, limit) + + +async def _eval_values(node, tc, user, collection, limit): + """Evaluate a VALUES clause (inline data).""" + variables = [str(v) for v in node.var] + solutions = [] + + for row in node.value: + sol = {} + for var_name, val in zip(variables, row): + if val is not None and str(val) != "UNDEF": + sol[var_name] = rdflib_term_to_term(val) + solutions.append(sol) + + return solutions + + +async def _eval_to_multiset(node, tc, user, collection, limit): + """Evaluate a ToMultiSet node (subquery).""" + return await evaluate(node.p, tc, user, collection, limit) + + +# --- Aggregate computation --- + +def _compute_aggregate(agg, group): + """Compute a single aggregate function over a group of solutions.""" + agg_name = agg.name if hasattr(agg, "name") else "" + + # Get the expression to aggregate + expr = agg.vars if hasattr(agg, "vars") else None + + if agg_name == "Aggregate_Count": + if hasattr(agg, "distinct") and agg.distinct: + vals = set() + for sol in group: + if expr: + val = evaluate_expression(expr, sol) + if val is not None: + vals.add(_term_key(val) if isinstance(val, Term) else val) + else: + vals.add(id(sol)) + return Term(type=LITERAL, value=str(len(vals)), + datatype="http://www.w3.org/2001/XMLSchema#integer") + return Term(type=LITERAL, value=str(len(group)), + datatype="http://www.w3.org/2001/XMLSchema#integer") + + if agg_name == "Aggregate_Sum": + total = 0 + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + num = _try_numeric(val) + if num is not None: + total += num + return Term(type=LITERAL, value=str(total), + datatype="http://www.w3.org/2001/XMLSchema#decimal") + + if agg_name == "Aggregate_Avg": + total = 0 + count = 0 + for sol in group: + val = evaluate_expression(expr, sol) if expr else 
None + num = _try_numeric(val) + if num is not None: + total += num + count += 1 + avg = total / count if count > 0 else 0 + return Term(type=LITERAL, value=str(avg), + datatype="http://www.w3.org/2001/XMLSchema#decimal") + + if agg_name == "Aggregate_Min": + min_val = None + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + if val is not None: + cmp = _term_key(val) if isinstance(val, Term) else val + if min_val is None or cmp < min_val[0]: + min_val = (cmp, val) + if min_val: + val = min_val[1] + if isinstance(val, Term): + return val + return Term(type=LITERAL, value=str(val)) + return None + + if agg_name == "Aggregate_Max": + max_val = None + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + if val is not None: + cmp = _term_key(val) if isinstance(val, Term) else val + if max_val is None or cmp > max_val[0]: + max_val = (cmp, val) + if max_val: + val = max_val[1] + if isinstance(val, Term): + return val + return Term(type=LITERAL, value=str(val)) + return None + + if agg_name == "Aggregate_GroupConcat": + separator = agg.separator if hasattr(agg, "separator") else " " + vals = [] + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + if val is not None: + if isinstance(val, Term): + vals.append(val.value if val.type == LITERAL else val.iri) + else: + vals.append(str(val)) + return Term(type=LITERAL, value=separator.join(vals)) + + if agg_name == "Aggregate_Sample": + if group: + val = evaluate_expression(expr, group[0]) if expr else None + if isinstance(val, Term): + return val + if val is not None: + return Term(type=LITERAL, value=str(val)) + return None + + logger.warning(f"Unsupported aggregate: {agg_name}") + return None + + +# --- Helper functions --- + +def _to_term(val): + """ + Convert a value to a schema Term. Handles Uri and Literal from the + knowledge module (returned by TriplesClient) as well as plain strings. + """ + if val is None: + return None + if isinstance(val, Term): + return val + if isinstance(val, Uri): + return Term(type=IRI, iri=str(val)) + if isinstance(val, KgLiteral): + return Term(type=LITERAL, value=str(val)) + if isinstance(val, str): + if val.startswith("http://") or val.startswith("https://") or val.startswith("urn:"): + return Term(type=IRI, iri=val) + return Term(type=LITERAL, value=val) + return Term(type=LITERAL, value=str(val)) + + +def _resolve_term(tmpl, solution): + """ + Resolve a triple pattern term. If it's a variable and bound in the + solution, return the bound Term. Otherwise return None (wildcard) + for variables, or convert concrete terms. + """ + if isinstance(tmpl, Variable): + name = str(tmpl) + if name in solution: + return solution[name] + return None + else: + return rdflib_term_to_term(tmpl) + + +async def _query_pattern(tc, s, p, o, user, collection, limit): + """ + Issue a streaming triple pattern query via TriplesClient. + + Returns a list of Triple-like objects with s, p, o attributes. + """ + results = await tc.query( + s=s, p=p, o=o, + limit=limit, + user=user, + collection=collection, + ) + return results + + +def _try_numeric(val): + """Try to convert a value to a number, return None on failure.""" + if val is None: + return None + if isinstance(val, (int, float)): + return val + if isinstance(val, Term) and val.type == LITERAL: + try: + if "." 
in val.value: + return float(val.value) + return int(val.value) + except (ValueError, TypeError): + return None + return None + + +# --- Handler registry --- + +_HANDLERS = { + "SelectQuery": _eval_select_query, + "Project": _eval_project, + "BGP": _eval_bgp, + "Join": _eval_join, + "LeftJoin": _eval_left_join, + "Union": _eval_union, + "Filter": _eval_filter, + "Distinct": _eval_distinct, + "Reduced": _eval_reduced, + "OrderBy": _eval_order_by, + "Slice": _eval_slice, + "Extend": _eval_extend, + "Group": _eval_group, + "AggregateJoin": _eval_aggregate_join, + "Graph": _eval_graph, + "values": _eval_values, + "ToMultiSet": _eval_to_multiset, +} diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py new file mode 100644 index 00000000..eac1199c --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -0,0 +1,481 @@ +""" +SPARQL FILTER expression evaluator. + +Evaluates rdflib algebra expression nodes against a solution (variable +binding) to produce a value or boolean result. +""" + +import re +import logging +import operator + +from rdflib.term import Variable, URIRef, Literal, BNode +from rdflib.plugins.sparql.parserutils import CompValue + +from ... schema import Term, IRI, LITERAL, BLANK +from . parser import rdflib_term_to_term + +logger = logging.getLogger(__name__) + + +class ExpressionError(Exception): + """Raised when a SPARQL expression cannot be evaluated.""" + pass + + +def evaluate_expression(expr, solution): + """ + Evaluate a SPARQL expression against a solution binding. + + Args: + expr: rdflib algebra expression node + solution: dict mapping variable names to Term values + + Returns: + The result value (Term, bool, number, string, or None) + """ + if expr is None: + return True + + # rdflib Variable + if isinstance(expr, Variable): + name = str(expr) + return solution.get(name) + + # rdflib concrete terms + if isinstance(expr, URIRef): + return Term(type=IRI, iri=str(expr)) + + if isinstance(expr, Literal): + return rdflib_term_to_term(expr) + + if isinstance(expr, BNode): + return Term(type=BLANK, id=str(expr)) + + # Boolean constants + if isinstance(expr, bool): + return expr + + # Numeric constants + if isinstance(expr, (int, float)): + return expr + + # String constants + if isinstance(expr, str): + return expr + + # CompValue nodes from rdflib algebra + if isinstance(expr, CompValue): + return _evaluate_comp_value(expr, solution) + + # List/tuple (e.g. 
function arguments) + if isinstance(expr, (list, tuple)): + return [evaluate_expression(e, solution) for e in expr] + + logger.warning(f"Unknown expression type: {type(expr)}: {expr}") + return None + + +def _evaluate_comp_value(node, solution): + """Evaluate a CompValue expression node.""" + name = node.name + + # Relational expressions: =, !=, <, >, <=, >= + if name == "RelationalExpression": + return _eval_relational(node, solution) + + # Conditional AND / OR + if name == "ConditionalAndExpression": + return _eval_conditional_and(node, solution) + + if name == "ConditionalOrExpression": + return _eval_conditional_or(node, solution) + + # Unary NOT + if name == "UnaryNot": + val = evaluate_expression(node.expr, solution) + return not _effective_boolean(val) + + # Unary plus/minus + if name == "UnaryPlus": + return _to_numeric(evaluate_expression(node.expr, solution)) + + if name == "UnaryMinus": + val = _to_numeric(evaluate_expression(node.expr, solution)) + return -val if val is not None else None + + # Arithmetic + if name == "AdditiveExpression": + return _eval_additive(node, solution) + + if name == "MultiplicativeExpression": + return _eval_multiplicative(node, solution) + + # SPARQL built-in functions + if name.startswith("Builtin_"): + return _eval_builtin(name, node, solution) + + # Function call + if name == "Function": + return _eval_function(node, solution) + + # Exists / NotExists + if name == "Builtin_EXISTS": + # EXISTS requires graph pattern evaluation - not handled here + logger.warning("EXISTS not supported in filter expressions") + return True + + if name == "Builtin_NOTEXISTS": + logger.warning("NOT EXISTS not supported in filter expressions") + return True + + # TrueFilter (used with OPTIONAL) + if name == "TrueFilter": + return True + + # IN / NOT IN + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + + logger.warning(f"Unknown CompValue expression: {name}") + return None + + +def _eval_relational(node, solution): + """Evaluate a relational expression (=, !=, <, >, <=, >=).""" + left = evaluate_expression(node.expr, solution) + right = evaluate_expression(node.other, solution) + op = node.op + + if left is None or right is None: + return False + + left_cmp = _comparable_value(left) + right_cmp = _comparable_value(right) + + ops = { + "=": operator.eq, "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + ">": operator.gt, + "<=": operator.le, + ">=": operator.ge, + } + + op_fn = ops.get(str(op)) + if op_fn is None: + logger.warning(f"Unknown relational operator: {op}") + return False + + try: + return op_fn(left_cmp, right_cmp) + except TypeError: + return False + + +def _eval_conditional_and(node, solution): + """Evaluate AND expression.""" + result = _effective_boolean(evaluate_expression(node.expr, solution)) + if not result: + return False + for other in node.other: + result = _effective_boolean(evaluate_expression(other, solution)) + if not result: + return False + return True + + +def _eval_conditional_or(node, solution): + """Evaluate OR expression.""" + result = _effective_boolean(evaluate_expression(node.expr, solution)) + if result: + return True + for other in node.other: + result = _effective_boolean(evaluate_expression(other, solution)) + if result: + return True + return False + + +def _eval_additive(node, solution): + """Evaluate additive expression (a + b - c ...).""" + result = _to_numeric(evaluate_expression(node.expr, solution)) + if result is None: + return 
None + for op, operand in zip(node.op, node.other): + val = _to_numeric(evaluate_expression(operand, solution)) + if val is None: + return None + if str(op) == "+": + result = result + val + elif str(op) == "-": + result = result - val + return result + + +def _eval_multiplicative(node, solution): + """Evaluate multiplicative expression (a * b / c ...).""" + result = _to_numeric(evaluate_expression(node.expr, solution)) + if result is None: + return None + for op, operand in zip(node.op, node.other): + val = _to_numeric(evaluate_expression(operand, solution)) + if val is None: + return None + if str(op) == "*": + result = result * val + elif str(op) == "/": + if val == 0: + return None + result = result / val + return result + + +def _eval_builtin(name, node, solution): + """Evaluate SPARQL built-in functions.""" + builtin = name[len("Builtin_"):] + + if builtin == "BOUND": + var_name = str(node.arg) + return var_name in solution and solution[var_name] is not None + + if builtin == "isIRI" or builtin == "isURI": + val = evaluate_expression(node.arg, solution) + return isinstance(val, Term) and val.type == IRI + + if builtin == "isLITERAL": + val = evaluate_expression(node.arg, solution) + return isinstance(val, Term) and val.type == LITERAL + + if builtin == "isBLANK": + val = evaluate_expression(node.arg, solution) + return isinstance(val, Term) and val.type == BLANK + + if builtin == "STR": + val = evaluate_expression(node.arg, solution) + return Term(type=LITERAL, value=_to_string(val)) + + if builtin == "LANG": + val = evaluate_expression(node.arg, solution) + if isinstance(val, Term) and val.type == LITERAL: + return Term(type=LITERAL, value=val.language or "") + return Term(type=LITERAL, value="") + + if builtin == "DATATYPE": + val = evaluate_expression(node.arg, solution) + if isinstance(val, Term) and val.type == LITERAL and val.datatype: + return Term(type=IRI, iri=val.datatype) + return Term(type=IRI, iri="http://www.w3.org/2001/XMLSchema#string") + + if builtin == "REGEX": + text = _to_string(evaluate_expression(node.text, solution)) + pattern = _to_string(evaluate_expression(node.pattern, solution)) + flags_str = "" + if hasattr(node, "flags") and node.flags is not None: + flags_str = _to_string(evaluate_expression(node.flags, solution)) + + re_flags = 0 + if "i" in flags_str: + re_flags |= re.IGNORECASE + if "m" in flags_str: + re_flags |= re.MULTILINE + if "s" in flags_str: + re_flags |= re.DOTALL + + try: + return bool(re.search(pattern, text, re_flags)) + except re.error: + return False + + if builtin == "STRLEN": + val = _to_string(evaluate_expression(node.arg, solution)) + return len(val) + + if builtin == "UCASE": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=val.upper()) + + if builtin == "LCASE": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=val.lower()) + + if builtin == "CONTAINS": + string = _to_string(evaluate_expression(node.arg1, solution)) + pattern = _to_string(evaluate_expression(node.arg2, solution)) + return pattern in string + + if builtin == "STRSTARTS": + string = _to_string(evaluate_expression(node.arg1, solution)) + prefix = _to_string(evaluate_expression(node.arg2, solution)) + return string.startswith(prefix) + + if builtin == "STRENDS": + string = _to_string(evaluate_expression(node.arg1, solution)) + suffix = _to_string(evaluate_expression(node.arg2, solution)) + return string.endswith(suffix) + + if builtin == "CONCAT": + args = 
[_to_string(evaluate_expression(a, solution)) for a in node.arg] + return Term(type=LITERAL, value="".join(args)) + + if builtin == "IF": + cond = _effective_boolean(evaluate_expression(node.arg1, solution)) + if cond: + return evaluate_expression(node.arg2, solution) + else: + return evaluate_expression(node.arg3, solution) + + if builtin == "COALESCE": + for arg in node.arg: + val = evaluate_expression(arg, solution) + if val is not None: + return val + return None + + if builtin == "sameTerm": + left = evaluate_expression(node.arg1, solution) + right = evaluate_expression(node.arg2, solution) + if not isinstance(left, Term) or not isinstance(right, Term): + return False + from . solutions import _term_key + return _term_key(left) == _term_key(right) + + logger.warning(f"Unsupported built-in function: {builtin}") + return None + + +def _eval_function(node, solution): + """Evaluate a SPARQL function call.""" + # Cast functions (xsd:integer, xsd:string, etc.) + iri = str(node.iri) if hasattr(node, "iri") else "" + args = [evaluate_expression(a, solution) for a in node.expr] + + xsd = "http://www.w3.org/2001/XMLSchema#" + if iri == xsd + "integer": + try: + return int(_to_numeric(args[0])) + except (TypeError, ValueError): + return None + elif iri == xsd + "decimal" or iri == xsd + "double" or iri == xsd + "float": + try: + return float(_to_numeric(args[0])) + except (TypeError, ValueError): + return None + elif iri == xsd + "string": + return Term(type=LITERAL, value=_to_string(args[0])) + elif iri == xsd + "boolean": + return _effective_boolean(args[0]) + + logger.warning(f"Unsupported function: {iri}") + return None + + +def _eval_in(node, solution): + """Evaluate IN expression.""" + val = evaluate_expression(node.expr, solution) + for item in node.other: + other = evaluate_expression(item, solution) + if _comparable_value(val) == _comparable_value(other): + return True + return False + + +# --- Value conversion helpers --- + +def _effective_boolean(val): + """Convert a value to its effective boolean value (EBV).""" + if isinstance(val, bool): + return val + if val is None: + return False + if isinstance(val, (int, float)): + return val != 0 + if isinstance(val, str): + return len(val) > 0 + if isinstance(val, Term): + if val.type == LITERAL: + v = val.value + if val.datatype == "http://www.w3.org/2001/XMLSchema#boolean": + return v.lower() in ("true", "1") + if val.datatype in ( + "http://www.w3.org/2001/XMLSchema#integer", + "http://www.w3.org/2001/XMLSchema#decimal", + "http://www.w3.org/2001/XMLSchema#double", + "http://www.w3.org/2001/XMLSchema#float", + ): + try: + return float(v) != 0 + except ValueError: + return False + return len(v) > 0 + return True + return bool(val) + + +def _to_string(val): + """Convert a value to a string.""" + if val is None: + return "" + if isinstance(val, str): + return val + if isinstance(val, Term): + if val.type == IRI: + return val.iri + elif val.type == LITERAL: + return val.value + elif val.type == BLANK: + return val.id + return str(val) + + +def _to_numeric(val): + """Convert a value to a number.""" + if val is None: + return None + if isinstance(val, (int, float)): + return val + if isinstance(val, Term) and val.type == LITERAL: + try: + if "." in val.value: + return float(val.value) + return int(val.value) + except (ValueError, TypeError): + return None + if isinstance(val, str): + try: + if "." 
in val: + return float(val) + return int(val) + except (ValueError, TypeError): + return None + return None + + +def _comparable_value(val): + """ + Convert a value to a form suitable for comparison. + Returns a tuple (type, value) for consistent ordering. + """ + if val is None: + return (0, "") + if isinstance(val, bool): + return (1, val) + if isinstance(val, (int, float)): + return (2, val) + if isinstance(val, str): + return (3, val) + if isinstance(val, Term): + if val.type == IRI: + return (4, val.iri) + elif val.type == LITERAL: + # Try numeric comparison for numeric types + num = _to_numeric(val) + if num is not None: + return (2, num) + return (3, val.value) + elif val.type == BLANK: + return (5, val.id) + return (6, str(val)) diff --git a/trustgraph-flow/trustgraph/query/sparql/parser.py b/trustgraph-flow/trustgraph/query/sparql/parser.py new file mode 100644 index 00000000..7de18460 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/parser.py @@ -0,0 +1,139 @@ +""" +SPARQL parser wrapping rdflib's SPARQL 1.1 parser and algebra compiler. +Parses a SPARQL query string into an algebra tree for evaluation. +""" + +import logging + +from rdflib.plugins.sparql import prepareQuery +from rdflib.plugins.sparql.algebra import translateQuery +from rdflib.plugins.sparql.parserutils import CompValue +from rdflib.term import Variable, URIRef, Literal, BNode + +from ... schema import Term, Triple, IRI, LITERAL, BLANK + +logger = logging.getLogger(__name__) + + +class ParseError(Exception): + """Raised when a SPARQL query cannot be parsed.""" + pass + + +class ParsedQuery: + """Result of parsing a SPARQL query string.""" + + def __init__(self, algebra, query_type, variables=None): + self.algebra = algebra + self.query_type = query_type # "select", "ask", "construct", "describe" + self.variables = variables or [] # projected variable names (SELECT) + + +def rdflib_term_to_term(t): + """Convert an rdflib term (URIRef, Literal, BNode) to a schema Term.""" + if isinstance(t, URIRef): + return Term(type=IRI, iri=str(t)) + elif isinstance(t, Literal): + term = Term(type=LITERAL, value=str(t)) + if t.datatype: + term.datatype = str(t.datatype) + if t.language: + term.language = t.language + return term + elif isinstance(t, BNode): + return Term(type=BLANK, id=str(t)) + else: + return Term(type=LITERAL, value=str(t)) + + +def term_to_rdflib(t): + """Convert a schema Term to an rdflib term.""" + if t.type == IRI: + return URIRef(t.iri) + elif t.type == LITERAL: + kwargs = {} + if t.datatype: + kwargs["datatype"] = URIRef(t.datatype) + if t.language: + kwargs["lang"] = t.language + return Literal(t.value, **kwargs) + elif t.type == BLANK: + return BNode(t.id) + else: + return Literal(t.value) + + +def parse_sparql(query_string): + """ + Parse a SPARQL query string into a ParsedQuery. 
+ + Args: + query_string: SPARQL 1.1 query string + + Returns: + ParsedQuery with algebra tree, query type, and projected variables + + Raises: + ParseError: if the query cannot be parsed + """ + try: + prepared = prepareQuery(query_string) + except Exception as e: + raise ParseError(f"SPARQL parse error: {e}") from e + + algebra = prepared.algebra + + # Determine query type and extract variables + query_type = _detect_query_type(algebra) + variables = _extract_variables(algebra, query_type) + + return ParsedQuery( + algebra=algebra, + query_type=query_type, + variables=variables, + ) + + +def _detect_query_type(algebra): + """Detect the SPARQL query type from the algebra root.""" + name = algebra.name + + if name == "SelectQuery": + return "select" + elif name == "AskQuery": + return "ask" + elif name == "ConstructQuery": + return "construct" + elif name == "DescribeQuery": + return "describe" + + # The top-level algebra node may be a modifier (Project, Slice, etc.) + # wrapping the actual query. Check for common patterns. + if name in ("Project", "Distinct", "Reduced", "OrderBy", "Slice"): + return "select" + + logger.warning(f"Unknown algebra root type: {name}, assuming select") + return "select" + + +def _extract_variables(algebra, query_type): + """Extract projected variable names from the algebra.""" + if query_type != "select": + return [] + + # For SELECT queries, the Project node has PV (projected variables) + if hasattr(algebra, "PV"): + return [str(v) for v in algebra.PV] + + # Walk down through modifiers to find Project + node = algebra + while hasattr(node, "p"): + node = node.p + if hasattr(node, "PV"): + return [str(v) for v in node.PV] + + # Fallback: collect all variables from the algebra + if hasattr(algebra, "_vars"): + return [str(v) for v in algebra._vars] + + return [] diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py new file mode 100644 index 00000000..e815540f --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -0,0 +1,230 @@ +""" +SPARQL query service. Accepts SPARQL queries, decomposes them into triple +pattern lookups via the triples query pub/sub interface, performs in-memory +joins/filters/projections, and returns SPARQL result bindings. +""" + +import logging + +from ... schema import SparqlQueryRequest, SparqlQueryResponse +from ... schema import SparqlBinding, Error, Term, Triple +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import TriplesClientSpec + +from . parser import parse_sparql, ParseError +from . 
algebra import evaluate, EvaluationError + +logger = logging.getLogger(__name__) + +default_ident = "sparql-query" +default_concurrency = 10 + + +class Processor(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) + + super(Processor, self).__init__( + **params | { + "id": id, + "concurrency": concurrency, + } + ) + + self.register_specification( + ConsumerSpec( + name="request", + schema=SparqlQueryRequest, + handler=self.on_message, + concurrency=concurrency, + ) + ) + + self.register_specification( + ProducerSpec( + name="response", + schema=SparqlQueryResponse, + ) + ) + + self.register_specification( + TriplesClientSpec( + request_name="triples-request", + response_name="triples-response", + ) + ) + + async def on_message(self, msg, consumer, flow): + + try: + + request = msg.value() + id = msg.properties()["id"] + + logger.debug(f"Handling SPARQL query request {id}...") + + response = await self.execute_sparql(request, flow) + + await flow("response").send(response, properties={"id": id}) + + logger.debug("SPARQL query request completed") + + except Exception as e: + + logger.error( + f"Exception in SPARQL query service: {e}", exc_info=True + ) + + r = SparqlQueryResponse( + error=Error( + type="sparql-query-error", + message=str(e), + ), + ) + + await flow("response").send(r, properties={"id": id}) + + async def execute_sparql(self, request, flow): + """Parse and evaluate a SPARQL query.""" + + # Parse the SPARQL query + try: + parsed = parse_sparql(request.query) + except ParseError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-parse-error", + message=str(e), + ), + ) + + # Get the triples client from the flow + triples_client = flow("triples-request") + + # Evaluate the algebra + try: + solutions = await evaluate( + parsed.algebra, + triples_client, + user=request.user or "trustgraph", + collection=request.collection or "default", + limit=request.limit or 10000, + ) + except EvaluationError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ) + + # Build response based on query type + if parsed.query_type == "select": + return self._build_select_response(parsed, solutions) + elif parsed.query_type == "ask": + return self._build_ask_response(solutions) + elif parsed.query_type == "construct": + return self._build_construct_response(parsed, solutions) + elif parsed.query_type == "describe": + return self._build_describe_response(parsed, solutions) + else: + return SparqlQueryResponse( + error=Error( + type="sparql-unsupported", + message=f"Unsupported query type: {parsed.query_type}", + ), + ) + + def _build_select_response(self, parsed, solutions): + """Build response for SELECT queries.""" + variables = parsed.variables + + bindings = [] + for sol in solutions: + values = [sol.get(v) for v in variables] + bindings.append(SparqlBinding(values=values)) + + return SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=bindings, + ) + + def _build_ask_response(self, solutions): + """Build response for ASK queries.""" + return SparqlQueryResponse( + query_type="ask", + ask_result=len(solutions) > 0, + ) + + def _build_construct_response(self, parsed, solutions): + """Build response for CONSTRUCT queries.""" + # CONSTRUCT template is in the algebra + template = [] + if hasattr(parsed.algebra, "template"): + template = parsed.algebra.template + + triples = [] + seen = set() + + for sol 
in solutions: + for s_tmpl, p_tmpl, o_tmpl in template: + from rdflib.term import Variable + from . parser import rdflib_term_to_term + + s = self._resolve_construct_term(s_tmpl, sol) + p = self._resolve_construct_term(p_tmpl, sol) + o = self._resolve_construct_term(o_tmpl, sol) + + if s is not None and p is not None and o is not None: + key = ( + s.type, s.iri or s.value, + p.type, p.iri or p.value, + o.type, o.iri or o.value, + ) + if key not in seen: + seen.add(key) + triples.append(Triple(s=s, p=p, o=o)) + + return SparqlQueryResponse( + query_type="construct", + triples=triples, + ) + + def _build_describe_response(self, parsed, solutions): + """Build response for DESCRIBE queries.""" + # DESCRIBE returns all triples about the described resources + # For now, return empty - would need additional triples queries + return SparqlQueryResponse( + query_type="describe", + triples=[], + ) + + def _resolve_construct_term(self, tmpl, solution): + """Resolve a CONSTRUCT template term.""" + from rdflib.term import Variable + from . parser import rdflib_term_to_term + + if isinstance(tmpl, Variable): + return solution.get(str(tmpl)) + else: + return rdflib_term_to_term(tmpl) + + @staticmethod + def add_args(parser): + FlowProcessor.add_args(parser) + + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests ' + f'(default: {default_concurrency})' + ) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/sparql/solutions.py b/trustgraph-flow/trustgraph/query/sparql/solutions.py new file mode 100644 index 00000000..d1ea8373 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/solutions.py @@ -0,0 +1,248 @@ +""" +Solution sequence operations for SPARQL evaluation. + +A solution is a dict mapping variable names (str) to Term values. +A solution sequence is a list of solutions. +""" + +import logging +from collections import defaultdict + +from ... schema import Term, IRI, LITERAL, BLANK + +logger = logging.getLogger(__name__) + + +def _term_key(term): + """Create a hashable key from a Term for join/distinct operations.""" + if term is None: + return None + if term.type == IRI: + return ("i", term.iri) + elif term.type == LITERAL: + return ("l", term.value, term.datatype, term.language) + elif term.type == BLANK: + return ("b", term.id) + else: + return ("?", str(term)) + + +def _solution_key(solution, variables): + """Create a hashable key from a solution for the given variables.""" + return tuple(_term_key(solution.get(v)) for v in variables) + + +def _terms_equal(a, b): + """Check if two Terms are equal.""" + if a is None and b is None: + return True + if a is None or b is None: + return False + return _term_key(a) == _term_key(b) + + +def _compatible(sol_a, sol_b): + """Check if two solutions are compatible (agree on shared variables).""" + shared = set(sol_a.keys()) & set(sol_b.keys()) + return all(_terms_equal(sol_a[v], sol_b[v]) for v in shared) + + +def _merge(sol_a, sol_b): + """Merge two compatible solutions into one.""" + result = dict(sol_a) + result.update(sol_b) + return result + + +def hash_join(left, right): + """ + Inner join two solution sequences on shared variables. + Uses hash join for efficiency. 
+ """ + if not left or not right: + return [] + + left_vars = set() + for sol in left: + left_vars.update(sol.keys()) + + right_vars = set() + for sol in right: + right_vars.update(sol.keys()) + + shared = sorted(left_vars & right_vars) + + if not shared: + # Cross product + return [_merge(l, r) for l in left for r in right] + + # Build hash table on the smaller side + if len(left) <= len(right): + index = defaultdict(list) + for sol in left: + key = _solution_key(sol, shared) + index[key].append(sol) + + results = [] + for sol_r in right: + key = _solution_key(sol_r, shared) + for sol_l in index.get(key, []): + results.append(_merge(sol_l, sol_r)) + return results + else: + index = defaultdict(list) + for sol in right: + key = _solution_key(sol, shared) + index[key].append(sol) + + results = [] + for sol_l in left: + key = _solution_key(sol_l, shared) + for sol_r in index.get(key, []): + results.append(_merge(sol_l, sol_r)) + return results + + +def left_join(left, right, filter_fn=None): + """ + Left outer join (OPTIONAL semantics). + Every left solution is preserved. If it joins with right solutions + (and passes the optional filter), the merged solutions are included. + Otherwise the original left solution is kept. + """ + if not left: + return [] + + if not right: + return list(left) + + right_vars = set() + for sol in right: + right_vars.update(sol.keys()) + + left_vars = set() + for sol in left: + left_vars.update(sol.keys()) + + shared = sorted(left_vars & right_vars) + + # Build hash table on right side + index = defaultdict(list) + for sol in right: + key = _solution_key(sol, shared) if shared else () + index[key].append(sol) + + results = [] + for sol_l in left: + key = _solution_key(sol_l, shared) if shared else () + matches = index.get(key, []) + + matched = False + for sol_r in matches: + merged = _merge(sol_l, sol_r) + if filter_fn is None or filter_fn(merged): + results.append(merged) + matched = True + + if not matched: + results.append(dict(sol_l)) + + return results + + +def union(left, right): + """Union two solution sequences (concatenation).""" + return list(left) + list(right) + + +def project(solutions, variables): + """Keep only the specified variables in each solution.""" + return [ + {v: sol[v] for v in variables if v in sol} + for sol in solutions + ] + + +def distinct(solutions): + """Remove duplicate solutions.""" + seen = set() + results = [] + for sol in solutions: + key = tuple(sorted( + (k, _term_key(v)) for k, v in sol.items() + )) + if key not in seen: + seen.add(key) + results.append(sol) + return results + + +def order_by(solutions, key_fns): + """ + Sort solutions by the given key functions. + + key_fns is a list of (fn, ascending) tuples where fn extracts + a comparable value from a solution. + """ + if not key_fns: + return solutions + + def sort_key(sol): + keys = [] + for fn, ascending in key_fns: + val = fn(sol) + # Convert to comparable form + if val is None: + comparable = ("", "") + elif isinstance(val, Term): + comparable = _term_key(val) + else: + comparable = ("v", str(val)) + keys.append(comparable) + return keys + + # Handle ascending/descending + # For simplicity, sort ascending then reverse individual keys + # This works for single sort keys; for multiple mixed keys we + # need a wrapper + result = sorted(solutions, key=sort_key) + + # If any key is descending, we need a more complex approach. + # Check if all are same direction for the simple case. 
+ if key_fns and all(not asc for _, asc in key_fns): + result.reverse() + elif key_fns and not all(asc for _, asc in key_fns): + # Mixed ascending/descending - use negation wrapper + result = _mixed_sort(solutions, key_fns) + + return result + + +def _mixed_sort(solutions, key_fns): + """Sort with mixed ascending/descending keys.""" + import functools + + def compare(a, b): + for fn, ascending in key_fns: + va = fn(a) + vb = fn(b) + ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "") + kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "") + + if ka < kb: + return -1 if ascending else 1 + elif ka > kb: + return 1 if ascending else -1 + + return 0 + + return sorted(solutions, key=functools.cmp_to_key(compare)) + + +def slice_solutions(solutions, offset=0, limit=None): + """Apply OFFSET and LIMIT to a solution sequence.""" + if offset: + solutions = solutions[offset:] + if limit is not None: + solutions = solutions[:limit] + return solutions From ee65d90fdd91ca72d66ad741f653bf1a206d22d9 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Apr 2026 17:54:07 +0100 Subject: [PATCH 29/37] SPARQL service supports batching/streaming (#755) --- .../trustgraph/api/socket_client.py | 25 +++ .../messaging/translators/sparql_query.py | 6 +- .../schema/services/sparql_query.py | 4 + .../trustgraph/cli/invoke_sparql_query.py | 195 +++++++++--------- .../trustgraph/query/sparql/service.py | 34 ++- 5 files changed, 169 insertions(+), 95 deletions(-) diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 847513d3..9c37a9b1 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -812,6 +812,31 @@ class SocketFlowInstance: else: yield response + def sparql_query_stream( + self, + query: str, + user: str = "trustgraph", + collection: str = "default", + limit: int = 10000, + batch_size: int = 20, + **kwargs: Any + ) -> Iterator[Dict[str, Any]]: + """Execute a SPARQL query with streaming batches.""" + request = { + "query": query, + "user": user, + "collection": collection, + "limit": limit, + "streaming": True, + "batch-size": batch_size, + } + request.update(kwargs) + + for response in self.client._send_request_sync( + "sparql", self.flow_id, request, streaming_raw=True + ): + yield response + def rows_query( self, query: str, diff --git a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py index d1912429..a8b13865 100644 --- a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py @@ -16,6 +16,8 @@ class SparqlQueryRequestTranslator(MessageTranslator): collection=data.get("collection", "default"), query=data.get("query", ""), limit=int(data.get("limit", 10000)), + streaming=data.get("streaming", False), + batch_size=int(data.get("batch-size", 20)), ) def encode(self, obj: SparqlQueryRequest) -> Dict[str, Any]: @@ -24,6 +26,8 @@ class SparqlQueryRequestTranslator(MessageTranslator): "collection": obj.collection, "query": obj.query, "limit": obj.limit, + "streaming": obj.streaming, + "batch-size": obj.batch_size, } @@ -108,4 +112,4 @@ class SparqlQueryResponseTranslator(MessageTranslator): def encode_with_completion( self, obj: SparqlQueryResponse ) -> Tuple[Dict[str, Any], bool]: - return self.encode(obj), True + return self.encode(obj), 
obj.is_final diff --git a/trustgraph-base/trustgraph/schema/services/sparql_query.py b/trustgraph-base/trustgraph/schema/services/sparql_query.py index 105cc753..62c02c93 100644 --- a/trustgraph-base/trustgraph/schema/services/sparql_query.py +++ b/trustgraph-base/trustgraph/schema/services/sparql_query.py @@ -20,6 +20,8 @@ class SparqlQueryRequest: collection: str = "" query: str = "" # SPARQL query string limit: int = 10000 # Safety limit on results + streaming: bool = False # Enable streaming mode + batch_size: int = 20 # Bindings per batch in streaming mode @dataclass class SparqlQueryResponse: @@ -36,5 +38,7 @@ class SparqlQueryResponse: # For CONSTRUCT/DESCRIBE queries triples: list[Triple] = field(default_factory=list) + is_final: bool = True # False for intermediate batches in streaming + sparql_query_request_queue = queue('sparql-query', cls='request') sparql_query_response_queue = queue('sparql-query', cls='response') diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py index 9547193d..82e48456 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -13,85 +13,20 @@ default_user = 'trustgraph' default_collection = 'default' -def format_select(response, output_format): - """Format SELECT query results.""" - variables = response.get("variables", []) - bindings = response.get("bindings", []) - - if not bindings: - return "No results." - - if output_format == "json": - rows = [] - for binding in bindings: - row = {} - for var, val in zip(variables, binding.get("values", [])): - if val is None: - row[var] = None - elif val.get("t") == "i": - row[var] = val.get("i", "") - elif val.get("t") == "l": - row[var] = val.get("v", "") - else: - row[var] = val.get("v", val.get("i", "")) - rows.append(row) - return json.dumps(rows, indent=2) - - # Table format - col_widths = [len(v) for v in variables] - rows = [] - for binding in bindings: - row = [] - for i, val in enumerate(binding.get("values", [])): - if val is None: - cell = "" - elif val.get("t") == "i": - cell = val.get("i", "") - elif val.get("t") == "l": - cell = val.get("v", "") - else: - cell = val.get("v", val.get("i", "")) - row.append(cell) - if i < len(col_widths): - col_widths[i] = max(col_widths[i], len(cell)) - rows.append(row) - - # Build table - header = " | ".join( - v.ljust(col_widths[i]) for i, v in enumerate(variables) - ) - separator = "-+-".join("-" * w for w in col_widths) - lines = [header, separator] - for row in rows: - line = " | ".join( - cell.ljust(col_widths[i]) if i < len(col_widths) else cell - for i, cell in enumerate(row) - ) - lines.append(line) - return "\n".join(lines) - - -def format_triples(response, output_format): - """Format CONSTRUCT/DESCRIBE results.""" - triples = response.get("triples", []) - - if not triples: - return "No triples." 
- - if output_format == "json": - return json.dumps(triples, indent=2) - - lines = [] - for t in triples: - s = _term_str(t.get("s")) - p = _term_str(t.get("p")) - o = _term_str(t.get("o")) - lines.append(f"{s} {p} {o} .") - return "\n".join(lines) +def _term_cell(val): + """Extract display string from a wire-format term.""" + if val is None: + return "" + t = val.get("t", "") + if t == "i": + return val.get("i", "") + elif t == "l": + return val.get("v", "") + return val.get("v", val.get("i", "")) def _term_str(val): - """Convert a wire-format term to a display string.""" + """Convert a wire-format term to a Turtle-style display string.""" if val is None: return "?" t = val.get("t", "") @@ -110,27 +45,93 @@ def _term_str(val): def sparql_query(url, token, flow_id, query, user, collection, limit, - output_format): + batch_size, output_format): - api = Api(url=url, token=token).flow().id(flow_id) + socket = Api(url=url, token=token).socket() + flow = socket.flow(flow_id) - resp = api.sparql_query( - query=query, - user=user, - collection=collection, - limit=limit, - ) + variables = None + all_rows = [] - query_type = resp.get("query-type", "select") + try: - if query_type == "select": - print(format_select(resp, output_format)) - elif query_type == "ask": - print("true" if resp.get("ask-result") else "false") - elif query_type in ("construct", "describe"): - print(format_triples(resp, output_format)) - else: - print(json.dumps(resp, indent=2)) + for response in flow.sparql_query_stream( + query=query, + user=user, + collection=collection, + limit=limit, + batch_size=batch_size, + ): + query_type = response.get("query-type", "select") + + # ASK queries - just print and return + if query_type == "ask": + print("true" if response.get("ask-result") else "false") + return + + # CONSTRUCT/DESCRIBE - print triples + if query_type in ("construct", "describe"): + triples = response.get("triples", []) + if not triples: + print("No triples.") + elif output_format == "json": + print(json.dumps(triples, indent=2)) + else: + for t in triples: + s = _term_str(t.get("s")) + p = _term_str(t.get("p")) + o = _term_str(t.get("o")) + print(f"{s} {p} {o} .") + return + + # SELECT - accumulate bindings across batches + if variables is None: + variables = response.get("variables", []) + + bindings = response.get("bindings", []) + for binding in bindings: + values = binding.get("values", []) + all_rows.append([_term_cell(v) for v in values]) + + # Output SELECT results + if variables is None: + print("No results.") + return + + if not all_rows: + print("No results.") + return + + if output_format == "json": + rows = [] + for row in all_rows: + rows.append({ + var: cell for var, cell in zip(variables, row) + }) + print(json.dumps(rows, indent=2)) + else: + # Table format + col_widths = [len(v) for v in variables] + for row in all_rows: + for i, cell in enumerate(row): + if i < len(col_widths): + col_widths[i] = max(col_widths[i], len(cell)) + + header = " | ".join( + v.ljust(col_widths[i]) for i, v in enumerate(variables) + ) + separator = "-+-".join("-" * w for w in col_widths) + print(header) + print(separator) + for row in all_rows: + line = " | ".join( + cell.ljust(col_widths[i]) if i < len(col_widths) else cell + for i, cell in enumerate(row) + ) + print(line) + + finally: + socket.close() def main(): @@ -187,6 +188,13 @@ def main(): help='Result limit (default: 10000)', ) + parser.add_argument( + '-b', '--batch-size', + type=int, + default=20, + help='Streaming batch size (default: 20)', + ) + 
parser.add_argument( '--format', choices=['table', 'json'], @@ -218,6 +226,7 @@ def main(): user=args.user, collection=args.collection, limit=args.limit, + batch_size=args.batch_size, output_format=args.format, ) diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index e815540f..38488032 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -68,7 +68,12 @@ class Processor(FlowProcessor): response = await self.execute_sparql(request, flow) - await flow("response").send(response, properties={"id": id}) + if request.streaming and response.query_type == "select": + await self.send_streaming(response, flow, id, request) + else: + await flow("response").send( + response, properties={"id": id} + ) logger.debug("SPARQL query request completed") @@ -87,6 +92,33 @@ class Processor(FlowProcessor): await flow("response").send(r, properties={"id": id}) + async def send_streaming(self, response, flow, id, request): + """Send SELECT results in batches.""" + + bindings = response.bindings + batch_size = request.batch_size if request.batch_size > 0 else 20 + + for i in range(0, len(bindings), batch_size): + batch = bindings[i:i + batch_size] + is_final = (i + batch_size >= len(bindings)) + r = SparqlQueryResponse( + query_type=response.query_type, + variables=response.variables, + bindings=batch, + is_final=is_final, + ) + await flow("response").send(r, properties={"id": id}) + + # Handle empty results + if len(bindings) == 0: + r = SparqlQueryResponse( + query_type=response.query_type, + variables=response.variables, + bindings=[], + is_final=True, + ) + await flow("response").send(r, properties={"id": id}) + async def execute_sparql(self, request, flow): """Parse and evaluate a SPARQL query.""" From 10a931f04c15762417a3af7ecf999d1a5172f8a5 Mon Sep 17 00:00:00 2001 From: Alex Jenkins Date: Mon, 6 Apr 2026 10:10:14 +0000 Subject: [PATCH 30/37] Feat: Auto-pull missing Ollama models (#757) * fix deadlink in readme Signed-off-by: Jenkins, Kenneth Alexander * feat: Auto-pull Ollama models Signed-off-by: Jenkins, Kenneth Alexander * fix: Restore namespace __init__.py files for package resolution Signed-off-by: Jenkins, Kenneth Alexander * fix CI Signed-off-by: Jenkins, Kenneth Alexander --- .../test_ollama_processor.py | 2 +- .../trustgraph/embeddings/ollama/processor.py | 28 ++++++++++++++++ .../model/text_completion/ollama/llm.py | 33 +++++++++++++++++-- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_text_completion/test_ollama_processor.py b/tests/unit/test_text_completion/test_ollama_processor.py index 0bf5e0ab..69baf85f 100644 --- a/tests/unit/test_text_completion/test_ollama_processor.py +++ b/tests/unit/test_text_completion/test_ollama_processor.py @@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - assert processor.default_model == 'gemma2:9b' # default_model + assert processor.default_model == 'granite4:350m' # default_model # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env) mock_client_class.assert_called_once() diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index a65b4ff7..c63db33c 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -7,6 +7,9 @@ from ... 
base import EmbeddingsService from ollama import Client import os +import logging + +logger = logging.getLogger(__name__) default_ident = "embeddings" @@ -29,6 +32,28 @@ class Processor(EmbeddingsService): self.client = Client(host=ollama) self.default_model = model + self._checked_models = set() + + def _ensure_model(self, model_name): + """Check if model exists locally, pull it if not.""" + if model_name in self._checked_models: + return + + try: + self.client.show(model_name) + self._checked_models.add(model_name) + except Exception as e: + status_code = getattr(e, 'status_code', None) + if status_code == 404 or "not found" in str(e).lower(): + logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...") + try: + self.client.pull(model_name) + self._checked_models.add(model_name) + logger.info(f"Successfully pulled Ollama model '{model_name}'.") + except Exception as pull_e: + logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}") + else: + logger.warning(f"Failed to check Ollama model '{model_name}': {e}") async def on_embeddings(self, texts, model=None): @@ -37,6 +62,9 @@ class Processor(EmbeddingsService): use_model = model or self.default_model + # Ensure the model exists/is pulled + self._ensure_model(use_model) + # Ollama handles batch input efficiently embeds = self.client.embed( model = use_model, diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py index 3616e428..f6c5dcb8 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py @@ -16,7 +16,7 @@ from .... base import LlmService, LlmResult, LlmChunk default_ident = "text-completion" -default_model = 'gemma2:9b' +default_model = 'granite4:350m' default_temperature = 0.0 default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') @@ -39,11 +39,36 @@ class Processor(LlmService): self.default_model = model self.temperature = temperature self.llm = Client(host=ollama) + self._checked_models = set() + + def _ensure_model(self, model_name): + """Check if model exists locally, pull it if not.""" + if model_name in self._checked_models: + return + + try: + self.llm.show(model_name) + self._checked_models.add(model_name) + except Exception as e: + status_code = getattr(e, 'status_code', None) + if status_code == 404 or "not found" in str(e).lower(): + logger.info(f"Ollama model '{model_name}' not found locally. 
Pulling, this may take a while...") + try: + self.llm.pull(model_name) + self._checked_models.add(model_name) + logger.info(f"Successfully pulled Ollama model '{model_name}'.") + except Exception as pull_e: + logger.error(f"Failed to pull Ollama model '{model_name}': {pull_e}") + else: + logger.warning(f"Failed to check Ollama model '{model_name}': {e}") async def generate_content(self, system, prompt, model=None, temperature=None): # Use provided model or fall back to default model_name = model or self.default_model + + # Ensure the model exists/is pulled + self._ensure_model(model_name) # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature @@ -86,6 +111,10 @@ class Processor(LlmService): async def generate_content_stream(self, system, prompt, model=None, temperature=None): """Stream content generation from Ollama""" model_name = model or self.default_model + + # Ensure the model exists/is pulled + self._ensure_model(model_name) + effective_temperature = temperature if temperature is not None else self.temperature logger.debug(f"Using model (streaming): {model_name}") @@ -142,7 +171,7 @@ class Processor(LlmService): parser.add_argument( '-m', '--model', - default="gemma2", + default="granite4:350m", help=f'LLM model (default: {default_model})' ) From d4723566cb75cf1334d5d19b116287bbf08fbd17 Mon Sep 17 00:00:00 2001 From: "V.Sreeram" Date: Mon, 6 Apr 2026 15:43:59 +0530 Subject: [PATCH 31/37] fix: prevent duplicate dispatcher creation race condition in invoke_global_service (#715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: prevent duplicate dispatcher creation race condition in invoke_global_service Concurrent coroutines could all pass the `if key in self.dispatchers` check before any of them wrote the result back, because `await dispatcher.start()` yields to the event loop. This caused multiple Pulsar consumers to be created on the same shared subscription, distributing responses round-robin and dropping ~2/3 of them — manifesting as a permanent spinner in the Workbench UI. Apply a double-checked asyncio.Lock in both `invoke_global_service` and `invoke_flow_service` so only one dispatcher is ever created per service key. * test: add concurrent-dispatch tests for race condition fix Add asyncio.gather-based tests that verify invoke_global_service and invoke_flow_service create exactly one dispatcher under concurrent calls, preventing the duplicate Pulsar consumer bug. 
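For readers unfamiliar with the idiom, the fix is double-checked locking adapted to asyncio: check, acquire the lock, check again before creating. A minimal standalone sketch of that pattern follows; the class and attribute names here are illustrative only, not the gateway's actual code:

```python
import asyncio

class DispatcherCache:
    """Illustrative double-checked asyncio lock: at most one dispatcher per key."""

    def __init__(self, factory):
        self._factory = factory          # async callable: key -> started dispatcher
        self._dispatchers = {}
        self._lock = asyncio.Lock()

    async def get(self, key):
        # Fast path: once the dispatcher exists, no locking is needed.
        if key not in self._dispatchers:
            async with self._lock:
                # Re-check inside the lock: another coroutine may have created
                # the dispatcher while this one was awaiting the lock.
                if key not in self._dispatchers:
                    # The factory awaits (e.g. starting the dispatcher), yielding
                    # to the event loop; this is the window the race exploited.
                    self._dispatchers[key] = await self._factory(key)
        return self._dispatchers[key]
```

Without the second check inside the lock, every coroutine that queued on the lock would still create its own dispatcher once it acquired it; the re-check is what collapses concurrent callers onto a single instance.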
--- .../test_gateway/test_dispatch_manager.py | 92 +++++++++++++++++- .../trustgraph/gateway/dispatch/manager.py | 94 +++++++++---------- 2 files changed, 133 insertions(+), 53 deletions(-) diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 33f1229d..83969fdd 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -49,7 +49,8 @@ class TestDispatcherManager: assert manager.prefix == "api-gateway" # default prefix assert manager.flows == {} assert manager.dispatchers == {} - + assert isinstance(manager.dispatcher_lock, asyncio.Lock) + # Verify manager was added as handler to config receiver mock_config_receiver.add_handler.assert_called_once_with(manager) @@ -543,18 +544,99 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver) - + # Setup test flow with interface but unsupported kind manager.flows["test_flow"] = { "interfaces": { "invalid-kind": {"request": "req", "response": "resp"} } } - + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \ patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers: mock_rr_dispatchers.__contains__.return_value = False mock_sender_dispatchers.__contains__.return_value = False - + with pytest.raises(RuntimeError, match="Invalid kind"): - await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") \ No newline at end of file + await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind") + + @pytest.mark.asyncio + async def test_invoke_global_service_concurrent_calls_create_single_dispatcher(self): + """Concurrent calls for the same service must create exactly one dispatcher. + + Before the fix, await dispatcher.start() yielded to the event loop and + multiple coroutines could all pass the 'key not in self.dispatchers' check + before any of them wrote the result back, creating duplicate Pulsar consumers. + """ + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver) + + async def slow_start(): + # Yield to the event loop so other coroutines get a chance to run, + # reproducing the window that caused the original race condition. + await asyncio.sleep(0) + + with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: + mock_dispatcher_class = Mock() + mock_dispatcher = Mock() + mock_dispatcher.start = AsyncMock(side_effect=slow_start) + mock_dispatcher.process = AsyncMock(return_value="result") + mock_dispatcher_class.return_value = mock_dispatcher + mock_dispatchers.__getitem__.return_value = mock_dispatcher_class + + results = await asyncio.gather(*[ + manager.invoke_global_service("data", "responder", "config") + for _ in range(5) + ]) + + assert mock_dispatcher_class.call_count == 1, ( + "Dispatcher class instantiated more than once — duplicate consumer bug" + ) + assert mock_dispatcher.start.call_count == 1 + assert manager.dispatchers[(None, "config")] is mock_dispatcher + assert all(r == "result" for r in results) + + @pytest.mark.asyncio + async def test_invoke_flow_service_concurrent_calls_create_single_dispatcher(self): + """Concurrent calls for the same flow+kind must create exactly one dispatcher. 
+ + invoke_flow_service has the same check-then-create pattern as + invoke_global_service and is protected by the same dispatcher_lock. + """ + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver) + + manager.flows["test_flow"] = { + "interfaces": { + "agent": { + "request": "agent_request_queue", + "response": "agent_response_queue", + } + } + } + + async def slow_start(): + await asyncio.sleep(0) + + with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers: + mock_dispatcher_class = Mock() + mock_dispatcher = Mock() + mock_dispatcher.start = AsyncMock(side_effect=slow_start) + mock_dispatcher.process = AsyncMock(return_value="result") + mock_dispatcher_class.return_value = mock_dispatcher + mock_rr_dispatchers.__getitem__.return_value = mock_dispatcher_class + mock_rr_dispatchers.__contains__.return_value = True + + results = await asyncio.gather(*[ + manager.invoke_flow_service("data", "responder", "test_flow", "agent") + for _ in range(5) + ]) + + assert mock_dispatcher_class.call_count == 1, ( + "Dispatcher class instantiated more than once — duplicate consumer bug" + ) + assert mock_dispatcher.start.call_count == 1 + assert manager.dispatchers[("test_flow", "agent")] is mock_dispatcher + assert all(r == "result" for r in results) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index a4bf8de9..ef3d5507 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -118,6 +118,7 @@ class DispatcherManager: self.flows = {} self.dispatchers = {} + self.dispatcher_lock = asyncio.Lock() async def start_flow(self, id, flow): logger.info(f"Starting flow {id}") @@ -165,30 +166,28 @@ class DispatcherManager: key = (None, kind) - if key in self.dispatchers: - return await self.dispatchers[key].process(data, responder) + if key not in self.dispatchers: + async with self.dispatcher_lock: + if key not in self.dispatchers: + request_queue = None + response_queue = None + if kind in self.queue_overrides: + request_queue = self.queue_overrides[kind].get("request") + response_queue = self.queue_overrides[kind].get("response") - # Get queue overrides if specified for this service - request_queue = None - response_queue = None - if kind in self.queue_overrides: - request_queue = self.queue_overrides[kind].get("request") - response_queue = self.queue_overrides[kind].get("response") + dispatcher = global_dispatchers[kind]( + backend = self.backend, + timeout = 120, + consumer = f"{self.prefix}-{kind}-request", + subscriber = f"{self.prefix}-{kind}-request", + request_queue = request_queue, + response_queue = response_queue, + ) - dispatcher = global_dispatchers[kind]( - backend = self.backend, - timeout = 120, - consumer = f"{self.prefix}-{kind}-request", - subscriber = f"{self.prefix}-{kind}-request", - request_queue = request_queue, - response_queue = response_queue, - ) + await dispatcher.start() + self.dispatchers[key] = dispatcher - await dispatcher.start() - - self.dispatchers[key] = dispatcher - - return await dispatcher.process(data, responder) + return await self.dispatchers[key].process(data, responder) def dispatch_flow_import(self): return self.process_flow_import @@ -299,36 +298,35 @@ class DispatcherManager: key = (flow, kind) - if key in self.dispatchers: - return await self.dispatchers[key].process(data, 
responder) + if key not in self.dispatchers: + async with self.dispatcher_lock: + if key not in self.dispatchers: + intf_defs = self.flows[flow]["interfaces"] - intf_defs = self.flows[flow]["interfaces"] + if kind not in intf_defs: + raise RuntimeError("This kind not supported by flow") - if kind not in intf_defs: - raise RuntimeError("This kind not supported by flow") + qconfig = intf_defs[kind] - qconfig = intf_defs[kind] + if kind in request_response_dispatchers: + dispatcher = request_response_dispatchers[kind]( + backend = self.backend, + request_queue = qconfig["request"], + response_queue = qconfig["response"], + timeout = 120, + consumer = f"{self.prefix}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{flow}-{kind}-request", + ) + elif kind in sender_dispatchers: + dispatcher = sender_dispatchers[kind]( + backend = self.backend, + queue = qconfig, + ) + else: + raise RuntimeError("Invalid kind") - if kind in request_response_dispatchers: - dispatcher = request_response_dispatchers[kind]( - backend = self.backend, - request_queue = qconfig["request"], - response_queue = qconfig["response"], - timeout = 120, - consumer = f"{self.prefix}-{flow}-{kind}-request", - subscriber = f"{self.prefix}-{flow}-{kind}-request", - ) - elif kind in sender_dispatchers: - dispatcher = sender_dispatchers[kind]( - backend = self.backend, - queue = qconfig, - ) - else: - raise RuntimeError("Invalid kind") - - await dispatcher.start() + await dispatcher.start() + self.dispatchers[key] = dispatcher - self.dispatchers[key] = dispatcher - - return await dispatcher.process(data, responder) + return await self.dispatchers[key].process(data, responder) From 4acd853023d1c0d5544c0d5daa60ef04c63f8c17 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 6 Apr 2026 16:57:27 +0100 Subject: [PATCH 32/37] Config push notify pattern: replace stateful pub/sub with signal+ fetch (#760) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the config push mechanism that broadcast the full config blob on a 'state' class pub/sub queue with a lightweight notify signal containing only the version number and affected config types. Processors fetch the full config via request/response from the config service when notified. This eliminates the need for the pub/sub 'state' queue class and stateful pub/sub services entirely. The config push queue moves from 'state' to 'flow' class — a simple transient signal rather than a retained message. This solves the RabbitMQ late-subscriber problem where restarting processes never received the current config because their fresh queue had no historical messages. 
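In outline the consumer side looks like the sketch below; the full treatment, including the real schema and handler registration API, is in the tech spec added by this patch, and the helper names here are illustrative only:

```python
from dataclasses import dataclass, field

@dataclass
class ConfigPush:               # notify payload: version + affected types, no config blob
    version: int = 0
    types: list[str] = field(default_factory=list)   # empty list means "everything"

async def on_config_notify(push, current_version, fetch_config, handlers):
    """Illustrative consumer: drop stale notifies, fetch once, dispatch by type."""
    if push.version <= current_version:
        return current_version                 # already seen this version or newer
    config, version = await fetch_config()     # request/response to the config service
    notify_types = set(push.types)
    for handler, wanted in handlers:           # wanted is None for "all types"
        if wanted is None or not notify_types or notify_types & wanted:
            await handler(config, version)
    return version
```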
Key changes: - ConfigPush schema: config dict replaced with types list - Subscribe-then-fetch startup with retry: processors subscribe to notify queue, fetch config via request/response, then process buffered notifies with version comparison to avoid race conditions - register_config_handler() accepts optional types parameter so handlers only fire when their config types change - Short-lived config request/response clients to avoid subscriber contention on non-persistent response topics - Config service passes affected types through put/delete/flow operations - Gateway ConfigReceiver rewritten with same notify pattern and retry loop Tests updated New tests: - register_config_handler: without types, with types, multiple types, multiple handlers - on_config_notify: old/same version skipped, irrelevant types skipped (version still updated), relevant type triggers fetch, handler without types always called, mixed handler filtering, empty types invokes all, fetch failure handled gracefully - fetch_config: returns config+version, raises on error response, stops client even on exception - fetch_and_apply_config: applies to all handlers on startup, retries on failure --- dev-tools/library_client.py | 237 ++++++++ docs/tech-specs/config-push-poke.md | 282 ++++++++++ .../test_base/test_async_processor_config.py | 323 +++++++++++ tests/unit/test_base/test_flow_processor.py | 4 +- .../unit/test_gateway/test_config_receiver.py | 509 ++++++++---------- tests/unit/test_pubsub/test_queue_naming.py | 2 +- .../trustgraph/base/async_processor.py | 172 +++++- .../trustgraph/base/flow_processor.py | 4 +- .../trustgraph/schema/services/config.py | 4 +- .../trustgraph/agent/mcp_tool/service.py | 2 +- .../trustgraph/config/service/config.py | 19 +- .../trustgraph/config/service/flow.py | 12 +- .../trustgraph/config/service/service.py | 13 +- trustgraph-flow/trustgraph/cores/service.py | 2 +- .../embeddings/row_embeddings/embeddings.py | 4 +- .../trustgraph/extract/kg/agent/extract.py | 2 +- .../trustgraph/extract/kg/ontology/extract.py | 2 +- .../trustgraph/extract/kg/rows/processor.py | 2 +- .../trustgraph/gateway/config/receiver.py | 222 ++++++-- .../trustgraph/librarian/service.py | 2 +- .../trustgraph/metering/counter.py | 2 +- .../trustgraph/prompt/template/service.py | 2 +- .../query/rows/cassandra/service.py | 2 +- .../trustgraph/retrieval/nlp_query/service.py | 2 +- .../retrieval/structured_diag/service.py | 2 +- .../storage/doc_embeddings/milvus/write.py | 2 +- .../storage/doc_embeddings/pinecone/write.py | 2 +- .../storage/doc_embeddings/qdrant/write.py | 2 +- .../storage/graph_embeddings/milvus/write.py | 2 +- .../graph_embeddings/pinecone/write.py | 2 +- .../storage/graph_embeddings/qdrant/write.py | 2 +- .../storage/row_embeddings/qdrant/write.py | 2 +- .../storage/rows/cassandra/write.py | 4 +- .../storage/triples/cassandra/write.py | 2 +- .../storage/triples/falkordb/write.py | 2 +- .../storage/triples/memgraph/write.py | 2 +- .../trustgraph/storage/triples/neo4j/write.py | 2 +- 37 files changed, 1449 insertions(+), 406 deletions(-) create mode 100644 dev-tools/library_client.py create mode 100644 docs/tech-specs/config-push-poke.md create mode 100644 tests/unit/test_base/test_async_processor_config.py diff --git a/dev-tools/library_client.py b/dev-tools/library_client.py new file mode 100644 index 00000000..ae9d6857 --- /dev/null +++ b/dev-tools/library_client.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 + +""" +Client utility for browsing and loading documents from the TrustGraph +public document 
library. + +Usage: + python library_client.py list + python library_client.py search + python library_client.py load-all + python library_client.py load-doc + python library_client.py load-match +""" + +import json +import urllib.request +import sys +import os +import argparse + +from trustgraph.api import Api +from trustgraph.api.types import Uri, Literal, Triple + +BUCKET_URL = "https://storage.googleapis.com/trustgraph-library" +INDEX_URL = f"{BUCKET_URL}/index.json" + +default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + + +def fetch_index(): + with urllib.request.urlopen(INDEX_URL) as resp: + return json.loads(resp.read()) + + +def fetch_document_metadata(doc_id): + url = f"{BUCKET_URL}/{doc_id}.json" + with urllib.request.urlopen(url) as resp: + return json.loads(resp.read()) + + +def fetch_document_content(doc_id): + url = f"{BUCKET_URL}/{doc_id}.epub" + with urllib.request.urlopen(url) as resp: + return resp.read() + + +def search_index(index, query): + query = query.lower() + results = [] + for doc in index: + title = doc.get("title", "").lower() + comments = doc.get("comments", "").lower() + tags = [t.lower() for t in doc.get("tags", [])] + if (query in title or query in comments or + any(query in t for t in tags)): + results.append(doc) + return results + + +def print_index(index): + if not index: + return + + # Calculate column widths + id_width = max(len(str(doc.get("id", ""))) for doc in index) + title_width = max(len(doc.get("title", "")) for doc in index) + + # Cap title width for readability + title_width = min(title_width, 60) + id_width = max(id_width, 2) + + try: + term_width = os.get_terminal_size().columns + except OSError: + term_width = 120 + + tags_width = max(term_width - id_width - title_width - 6, 20) + + header = f"{'ID':<{id_width}} {'Title':<{title_width}} {'Tags':<{tags_width}}" + print(header) + print("-" * len(header)) + + for doc in index: + eid = str(doc.get("id", "")) + title = doc.get("title", "") + if len(title) > title_width: + title = title[:title_width - 3] + "..." + tags = ", ".join(doc.get("tags", [])) + if len(tags) > tags_width: + tags = tags[:tags_width - 3] + "..." 
+ print(f"{eid:<{id_width}} {title:<{title_width}} {tags}") + + +def convert_value(v): + """Convert a JSON triple value to a Uri or Literal.""" + if v["type"] == "uri": + return Uri(v["value"]) + else: + return Literal(v["value"]) + + +def convert_metadata(metadata_json): + """Convert JSON metadata triples to Triple objects.""" + triples = [] + for t in metadata_json: + triples.append(Triple( + s=convert_value(t["s"]), + p=convert_value(t["p"]), + o=convert_value(t["o"]), + )) + return triples + + +def load_document(api, user, doc_entry): + """Fetch metadata and content for a document, then load into TrustGraph.""" + doc_id = doc_entry["id"] + title = doc_entry["title"] + + print(f" [{doc_id}] {title}") + + print(f" fetching metadata...") + doc_json = fetch_document_metadata(doc_id) + doc = doc_json[0] + + print(f" fetching content...") + content = fetch_document_content(doc_id) + + print(f" loading into TrustGraph ({len(content) // 1024}KB)...") + metadata = convert_metadata(doc["metadata"]) + + api.add_document( + id=doc["id"], + metadata=metadata, + user=user, + kind=doc["kind"], + title=doc["title"], + comments=doc["comments"], + tags=doc["tags"], + document=content, + ) + + print(f" done.") + + +def load_documents(api, user, docs): + """Load a list of documents.""" + print(f"Loading {len(docs)} document(s)...\n") + for doc in docs: + try: + load_document(api, user, doc) + except Exception as e: + print(f" FAILED: {e}", file=sys.stderr) + print() + print("Complete.") + + +def main(): + parser = argparse.ArgumentParser( + description="Browse and load documents from the TrustGraph public document library.", + ) + + parser.add_argument( + "-u", "--url", default=default_url, + help=f"TrustGraph API URL (default: {default_url})", + ) + parser.add_argument( + "-U", "--user", default=default_user, + help=f"User ID (default: {default_user})", + ) + parser.add_argument( + "-t", "--token", default=default_token, + help="Authentication token (default: $TRUSTGRAPH_TOKEN)", + ) + + sub = parser.add_subparsers(dest="command") + + sub.add_parser("list", help="List all documents") + + search_parser = sub.add_parser("search", help="Search documents") + search_parser.add_argument("query", help="Text to search for") + + sub.add_parser("load-all", help="Load all documents into TrustGraph") + + load_doc_parser = sub.add_parser("load-doc", help="Load a document by ID") + load_doc_parser.add_argument("id", help="Document ID (ebook number)") + + load_match_parser = sub.add_parser( + "load-match", help="Load all documents matching a search term", + ) + load_match_parser.add_argument("query", help="Text to search for") + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + sys.exit(1) + + index = fetch_index() + + if args.command in ("list", "search"): + if args.command == "list": + print_index(index) + else: + results = search_index(index, args.query) + if results: + print_index(results) + else: + print("No matches found.", file=sys.stderr) + sys.exit(1) + return + + # Load commands need the API + api = Api(args.url, token=args.token).library() + + if args.command == "load-all": + load_documents(api, args.user, index) + + elif args.command == "load-doc": + matches = [d for d in index if str(d.get("id")) == args.id] + if not matches: + print(f"No document with ID '{args.id}' found.", file=sys.stderr) + sys.exit(1) + load_documents(api, args.user, matches) + + elif args.command == "load-match": + results = search_index(index, args.query) + if results: + load_documents(api, args.user, 
results) + else: + print("No matches found.", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/docs/tech-specs/config-push-poke.md b/docs/tech-specs/config-push-poke.md new file mode 100644 index 00000000..4273e46d --- /dev/null +++ b/docs/tech-specs/config-push-poke.md @@ -0,0 +1,282 @@ +# Config Push "Notify" Pattern Technical Specification + +## Overview + +Replace the current config push mechanism — which broadcasts the full config +blob on a `state` class queue — with a lightweight "notify" notification +containing only the version number and affected types. Processors that care +about those types fetch the full config via the existing request/response +interface. + +This solves the RabbitMQ late-subscriber problem: when a process restarts, +its fresh queue has no historical messages, so it never receives the current +config state. With the notify pattern, the push queue is only a signal — the +source of truth is the config service's request/response API, which is +always available. + +## Problem + +On Pulsar, `state` class queues are persistent topics. A new subscriber +with `InitialPosition.Earliest` reads from message 0 and receives the +last config push. On RabbitMQ, each subscriber gets a fresh per-subscriber +queue (named with a new UUID). Messages published before the queue existed +are gone. A restarting processor never gets the current config. + +## Design + +### The Notify Message + +The `ConfigPush` schema changes from carrying the full config to carrying +just a version number and the list of affected config types: + +```python +@dataclass +class ConfigPush: + version: int = 0 + types: list[str] = field(default_factory=list) +``` + +When the config service handles a `put` or `delete`, it knows which types +were affected (from the request's `values[].type` or `keys[].type`). It +includes those in the notify. On startup, the config service sends a notify +with an empty types list (meaning "everything"). + +### Subscribe-then-Fetch Startup (No Race Condition) + +The critical ordering to avoid missing an update: + +1. **Subscribe** to the config push queue. Buffer incoming notify messages. +2. **Fetch** the full config via request/response (`operation: "config"`). + This returns the config dict and a version number. +3. **Apply** the fetched config to all registered handlers. +4. **Process** buffered notifys. For any notify with `version > fetched_version`, + re-fetch and re-apply. Discard notifys with `version <= fetched_version`. +5. **Enter steady state**. Process future notifys as they arrive. + +This is safe because: +- If an update happens before the subscription, the fetch picks it up. +- If an update happens between subscribe and fetch, it's in the buffer. +- If an update happens after the fetch, it arrives on the queue normally. +- Version comparison ensures no duplicate processing. + +### Processor API + +The current API requires processors to understand the full config dict +structure. The new API should be cleaner — processors declare which config +types they care about and provide a handler that receives only the relevant +config subset. + +#### Current API + +```python +# In processor __init__: +self.register_config_handler(self.on_configure_flows) + +# Handler receives the entire config dict: +async def on_configure_flows(self, config, version): + if "active-flow" not in config: + return + if self.id in config["active-flow"]: + flow_config = json.loads(config["active-flow"][self.id]) + # ... 
+``` + +#### New API + +```python +# In processor __init__: +self.register_config_handler( + handler=self.on_configure_flows, + types=["active-flow"], +) + +# Handler receives only the relevant config subset, same signature: +async def on_configure_flows(self, config, version): + # config still contains the full dict, but handler is only called + # when "active-flow" type changes (or on startup) + if "active-flow" not in config: + return + # ... +``` + +The `types` parameter is optional. If omitted, the handler is called for +every config change (backward compatible). If specified, the handler is +only invoked when the notify's `types` list intersects with the handler's +types, or on startup (empty types list = everything). + +#### Internal Registration Structure + +```python +# In AsyncProcessor: +def register_config_handler(self, handler, types=None): + self.config_handlers.append({ + "handler": handler, + "types": set(types) if types else None, # None = all types + }) +``` + +#### Notify Processing Logic + +```python +async def on_config_notify(self, message, consumer, flow): + notify_version = message.value().version + notify_types = set(message.value().types) + + # Skip if we already have this version or newer + if notify_version <= self.config_version: + return + + # Fetch full config from config service + config, version = await self.config_client.config() + self.config_version = version + + # Determine which handlers to invoke + for entry in self.config_handlers: + handler_types = entry["types"] + if handler_types is None: + # Handler cares about everything + await entry["handler"](config, version) + elif not notify_types or notify_types & handler_types: + # notify_types empty = startup (invoke all), + # or intersection with handler's types + await entry["handler"](config, version) +``` + +### Config Service Changes + +#### Push Method + +The `push()` method changes to send only version + types: + +```python +async def push(self, types=None): + version = await self.config.get_version() + resp = ConfigPush( + version=version, + types=types or [], + ) + await self.config_push_producer.send(resp) +``` + +#### Put/Delete Handlers + +Extract affected types and pass to push: + +```python +async def handle_put(self, v): + types = list(set(k.type for k in v.values)) + for k in v.values: + await self.table_store.put_config(k.type, k.key, k.value) + await self.inc_version() + await self.push(types=types) + +async def handle_delete(self, v): + types = list(set(k.type for k in v.keys)) + for k in v.keys: + await self.table_store.delete_key(k.type, k.key) + await self.inc_version() + await self.push(types=types) +``` + +#### Queue Class Change + +The config push queue changes from `state` class to `flow` class. The push +is now a transient signal — the source of truth is the config service's +request/response API, not the queue. `flow` class is persistent (survives +broker restarts) but doesn't require last-message retention, which was the +root cause of the RabbitMQ problem. + +```python +config_push_queue = queue('config', cls='flow') # was cls='state' +``` + +#### Startup Push + +On startup, the config service sends a notify with empty types list +(signalling "everything changed"): + +```python +async def start(self): + await self.push(types=[]) # Empty = all types + await self.config_request_consumer.start() +``` + +### AsyncProcessor Changes + +The `AsyncProcessor` needs a config request/response client alongside the +push consumer. 
The startup sequence becomes: + +```python +async def start(self): + # 1. Start the push consumer (begins buffering notifys) + await self.config_sub_task.start() + + # 2. Fetch current config via request/response + config, version = await self.config_client.config() + self.config_version = version + + # 3. Apply to all handlers (startup = all handlers invoked) + for entry in self.config_handlers: + await entry["handler"](config, version) + + # 4. Buffered notifys are now processed by on_config_notify, + # which skips versions <= self.config_version +``` + +The config client needs to be created in `__init__` using the existing +request/response queue infrastructure. The `ConfigClient` from +`trustgraph.clients.config_client` already exists but uses a synchronous +blocking pattern. An async variant or integration with the processor's +pub/sub backend is needed. + +### Existing Config Handler Types + +For reference, the config types currently used by handlers: + +| Handler | Type(s) | Used By | +|---------|---------|---------| +| `on_configure_flows` | `active-flow` | All FlowProcessor subclasses | +| `on_collection_config` | `collection` | Storage services (triples, embeddings, rows) | +| `on_prompt_config` | `prompt` | Prompt template service, agent extract | +| `on_schema_config` | `schema` | Rows storage, row embeddings, NLP query, structured diag | +| `on_cost_config` | `token-costs` | Metering service | +| `on_ontology_config` | `ontology` | Ontology extraction | +| `on_librarian_config` | `librarian` | Librarian service | +| `on_mcp_config` | `mcp-tool` | MCP tool service | +| `on_knowledge_config` | `kg-core` | Cores service | + +## Implementation Order + +1. **Update ConfigPush schema** — change `config` field to `types` field. + +2. **Update config service** — modify `push()` to send version + types. + Modify `handle_put`/`handle_delete` to extract affected types. + +3. **Add async config query to AsyncProcessor** — create a + request/response client for config queries within the processor's + event loop. + +4. **Implement subscribe-then-fetch startup** — reorder + `AsyncProcessor.start()` to subscribe first, then fetch, then + process buffered notifys with version comparison. + +5. **Update register_config_handler** — add optional `types` parameter. + Update `on_config_notify` to filter by type intersection. + +6. **Update existing handlers** — add `types` parameter to all + `register_config_handler` calls across the codebase. + +7. **Backward compatibility** — handlers without `types` parameter + continue to work (invoked for all changes). + +## Risks + +- **Thundering herd**: if many processors restart simultaneously, they + all hit the config service API at once. Mitigated by the config service + already being designed for request/response load, and the number of + processors being small (tens, not thousands). + +- **Config service availability**: processors now depend on the config + service being up at startup, not just having received a push. This is + already the case in practice — without config, processors can't do + anything useful. 
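+
+A minimal sketch of a startup retry loop that mitigates the availability
+risk, in which the processor simply polls the config service until it
+answers (the 2-second delay is illustrative, not prescriptive):
+
+```python
+async def fetch_and_apply_config(self):
+    # Keep trying until the config service responds; without config the
+    # processor cannot do anything useful anyway.
+    while self.running:
+        try:
+            config, version = await self.fetch_config()
+            self.config_version = version
+            # Startup semantics: apply to every registered handler
+            for entry in self.config_handlers:
+                await entry["handler"](config, version)
+            return
+        except Exception as e:
+            logger.warning(f"Config fetch failed: {e}, retrying in 2s...")
+            await asyncio.sleep(2)
+```
+
+The loop exits on the first successful fetch; after that, steady-state
+updates arrive through the notify consumer, and the version check in
+`on_config_notify` prevents reprocessing.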
diff --git a/tests/unit/test_base/test_async_processor_config.py b/tests/unit/test_base/test_async_processor_config.py new file mode 100644 index 00000000..f1a83fef --- /dev/null +++ b/tests/unit/test_base/test_async_processor_config.py @@ -0,0 +1,323 @@ +""" +Tests for AsyncProcessor config notify pattern: +- register_config_handler with types filtering +- on_config_notify version comparison and type matching +- fetch_config with short-lived client +- fetch_and_apply_config retry logic +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch, Mock +from trustgraph.schema import Term, IRI, LITERAL + + +# Patch heavy dependencies before importing AsyncProcessor +@pytest.fixture +def processor(): + """Create an AsyncProcessor with mocked dependencies.""" + with patch('trustgraph.base.async_processor.get_pubsub') as mock_pubsub, \ + patch('trustgraph.base.async_processor.Consumer') as mock_consumer, \ + patch('trustgraph.base.async_processor.ProcessorMetrics') as mock_pm, \ + patch('trustgraph.base.async_processor.ConsumerMetrics') as mock_cm: + + mock_pubsub.return_value = MagicMock() + mock_consumer.return_value = MagicMock() + mock_pm.return_value = MagicMock() + mock_cm.return_value = MagicMock() + + from trustgraph.base.async_processor import AsyncProcessor + p = AsyncProcessor( + id="test-processor", + taskgroup=AsyncMock(), + ) + return p + + +class TestRegisterConfigHandler: + + def test_register_without_types(self, processor): + handler = AsyncMock() + processor.register_config_handler(handler) + + assert len(processor.config_handlers) == 1 + assert processor.config_handlers[0]["handler"] is handler + assert processor.config_handlers[0]["types"] is None + + def test_register_with_types(self, processor): + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) + + assert processor.config_handlers[0]["types"] == {"prompt"} + + def test_register_multiple_types(self, processor): + handler = AsyncMock() + processor.register_config_handler( + handler, types=["schema", "collection"] + ) + + assert processor.config_handlers[0]["types"] == { + "schema", "collection" + } + + def test_register_multiple_handlers(self, processor): + h1 = AsyncMock() + h2 = AsyncMock() + processor.register_config_handler(h1, types=["prompt"]) + processor.register_config_handler(h2, types=["schema"]) + + assert len(processor.config_handlers) == 2 + + +class TestOnConfigNotify: + + @pytest.mark.asyncio + async def test_skip_old_version(self, processor): + processor.config_version = 5 + + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) + + msg = Mock() + msg.value.return_value = Mock(version=3, types=["prompt"]) + + await processor.on_config_notify(msg, None, None) + + handler.assert_not_called() + + @pytest.mark.asyncio + async def test_skip_same_version(self, processor): + processor.config_version = 5 + + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) + + msg = Mock() + msg.value.return_value = Mock(version=5, types=["prompt"]) + + await processor.on_config_notify(msg, None, None) + + handler.assert_not_called() + + @pytest.mark.asyncio + async def test_skip_irrelevant_types(self, processor): + processor.config_version = 1 + + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) + + msg = Mock() + msg.value.return_value = Mock(version=2, types=["schema"]) + + await processor.on_config_notify(msg, None, None) + + handler.assert_not_called() + # Version should 
still be updated + assert processor.config_version == 2 + + @pytest.mark.asyncio + async def test_fetch_on_relevant_type(self, processor): + processor.config_version = 1 + + handler = AsyncMock() + processor.register_config_handler(handler, types=["prompt"]) + + # Mock fetch_config + mock_config = {"prompt": {"key": "value"}} + with patch.object( + processor, 'fetch_config', + new_callable=AsyncMock, + return_value=(mock_config, 2) + ): + msg = Mock() + msg.value.return_value = Mock(version=2, types=["prompt"]) + + await processor.on_config_notify(msg, None, None) + + handler.assert_called_once_with(mock_config, 2) + assert processor.config_version == 2 + + @pytest.mark.asyncio + async def test_handler_without_types_always_called(self, processor): + processor.config_version = 1 + + handler = AsyncMock() + processor.register_config_handler(handler) # No types = all + + mock_config = {"anything": {}} + with patch.object( + processor, 'fetch_config', + new_callable=AsyncMock, + return_value=(mock_config, 2) + ): + msg = Mock() + msg.value.return_value = Mock(version=2, types=["whatever"]) + + await processor.on_config_notify(msg, None, None) + + handler.assert_called_once_with(mock_config, 2) + + @pytest.mark.asyncio + async def test_mixed_handlers_type_filtering(self, processor): + processor.config_version = 1 + + prompt_handler = AsyncMock() + schema_handler = AsyncMock() + all_handler = AsyncMock() + + processor.register_config_handler(prompt_handler, types=["prompt"]) + processor.register_config_handler(schema_handler, types=["schema"]) + processor.register_config_handler(all_handler) + + mock_config = {"prompt": {}} + with patch.object( + processor, 'fetch_config', + new_callable=AsyncMock, + return_value=(mock_config, 2) + ): + msg = Mock() + msg.value.return_value = Mock(version=2, types=["prompt"]) + + await processor.on_config_notify(msg, None, None) + + prompt_handler.assert_called_once() + schema_handler.assert_not_called() + all_handler.assert_called_once() + + @pytest.mark.asyncio + async def test_empty_types_invokes_all(self, processor): + """Empty types list (startup signal) should invoke all handlers.""" + processor.config_version = 1 + + h1 = AsyncMock() + h2 = AsyncMock() + processor.register_config_handler(h1, types=["prompt"]) + processor.register_config_handler(h2, types=["schema"]) + + mock_config = {} + with patch.object( + processor, 'fetch_config', + new_callable=AsyncMock, + return_value=(mock_config, 2) + ): + msg = Mock() + msg.value.return_value = Mock(version=2, types=[]) + + await processor.on_config_notify(msg, None, None) + + h1.assert_called_once() + h2.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_failure_handled(self, processor): + processor.config_version = 1 + + handler = AsyncMock() + processor.register_config_handler(handler) + + with patch.object( + processor, 'fetch_config', + new_callable=AsyncMock, + side_effect=RuntimeError("Connection failed") + ): + msg = Mock() + msg.value.return_value = Mock(version=2, types=["prompt"]) + + # Should not raise + await processor.on_config_notify(msg, None, None) + + handler.assert_not_called() + + +class TestFetchConfig: + + @pytest.mark.asyncio + async def test_fetch_returns_config_and_version(self, processor): + mock_resp = Mock() + mock_resp.error = None + mock_resp.config = {"prompt": {"key": "val"}} + mock_resp.version = 42 + + mock_client = AsyncMock() + mock_client.request.return_value = mock_resp + + with patch.object( + processor, '_create_config_client', return_value=mock_client + 
): + config, version = await processor.fetch_config() + + assert config == {"prompt": {"key": "val"}} + assert version == 42 + mock_client.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_raises_on_error_response(self, processor): + mock_resp = Mock() + mock_resp.error = Mock(message="not found") + mock_resp.config = {} + mock_resp.version = 0 + + mock_client = AsyncMock() + mock_client.request.return_value = mock_resp + + with patch.object( + processor, '_create_config_client', return_value=mock_client + ): + with pytest.raises(RuntimeError, match="Config error"): + await processor.fetch_config() + + mock_client.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_stops_client_on_exception(self, processor): + mock_client = AsyncMock() + mock_client.request.side_effect = TimeoutError("timeout") + + with patch.object( + processor, '_create_config_client', return_value=mock_client + ): + with pytest.raises(TimeoutError): + await processor.fetch_config() + + mock_client.stop.assert_called_once() + + +class TestFetchAndApplyConfig: + + @pytest.mark.asyncio + async def test_applies_config_to_all_handlers(self, processor): + h1 = AsyncMock() + h2 = AsyncMock() + processor.register_config_handler(h1, types=["prompt"]) + processor.register_config_handler(h2, types=["schema"]) + + mock_config = {"prompt": {}, "schema": {}} + with patch.object( + processor, 'fetch_config', + new_callable=AsyncMock, + return_value=(mock_config, 10) + ): + await processor.fetch_and_apply_config() + + # On startup, all handlers are invoked regardless of type + h1.assert_called_once_with(mock_config, 10) + h2.assert_called_once_with(mock_config, 10) + assert processor.config_version == 10 + + @pytest.mark.asyncio + async def test_retries_on_failure(self, processor): + call_count = 0 + mock_config = {"prompt": {}} + + async def mock_fetch(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("not ready") + return mock_config, 5 + + with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \ + patch('asyncio.sleep', new_callable=AsyncMock): + await processor.fetch_and_apply_config() + + assert call_count == 3 + assert processor.config_version == 5 diff --git a/tests/unit/test_base/test_flow_processor.py b/tests/unit/test_base/test_flow_processor.py index 70835e00..8672831a 100644 --- a/tests/unit/test_base/test_flow_processor.py +++ b/tests/unit/test_base/test_flow_processor.py @@ -35,7 +35,9 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): mock_async_init.assert_called_once() # Verify register_config_handler was called with the correct handler - mock_register_config.assert_called_once_with(processor.on_configure_flows) + mock_register_config.assert_called_once_with( + processor.on_configure_flows, types=["active-flow"] + ) # Verify FlowProcessor-specific initialization assert hasattr(processor, 'flows') diff --git a/tests/unit/test_gateway/test_config_receiver.py b/tests/unit/test_gateway/test_config_receiver.py index 803ff4c6..49dc48d8 100644 --- a/tests/unit/test_gateway/test_config_receiver.py +++ b/tests/unit/test_gateway/test_config_receiver.py @@ -5,7 +5,7 @@ Tests for Gateway Config Receiver import pytest import asyncio import json -from unittest.mock import Mock, patch, Mock, MagicMock +from unittest.mock import Mock, patch, MagicMock, AsyncMock import uuid from trustgraph.gateway.config.receiver import ConfigReceiver @@ -23,174 +23,237 @@ class TestConfigReceiver: def test_config_receiver_initialization(self): 
"""Test ConfigReceiver initialization""" mock_backend = Mock() - + config_receiver = ConfigReceiver(mock_backend) - + assert config_receiver.backend == mock_backend assert config_receiver.flow_handlers == [] assert config_receiver.flows == {} + assert config_receiver.config_version == 0 def test_add_handler(self): """Test adding flow handlers""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - + handler1 = Mock() handler2 = Mock() - + config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) - + assert len(config_receiver.flow_handlers) == 2 assert handler1 in config_receiver.flow_handlers assert handler2 in config_receiver.flow_handlers @pytest.mark.asyncio - async def test_on_config_with_new_flows(self): - """Test on_config method with new flows""" + async def test_on_config_notify_new_version(self): + """Test on_config_notify triggers fetch for newer version""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Track calls manually instead of using AsyncMock - start_flow_calls = [] - - async def mock_start_flow(*args): - start_flow_calls.append(args) - - config_receiver.start_flow = mock_start_flow - - # Create mock message with flows + config_receiver.config_version = 1 + + # Mock fetch_and_apply + fetch_calls = [] + async def mock_fetch(**kwargs): + fetch_calls.append(kwargs) + config_receiver.fetch_and_apply = mock_fetch + + # Create notify message with newer version mock_msg = Mock() - mock_msg.value.return_value = Mock( - version="1.0", - config={ - "flow": { - "flow1": '{"name": "test_flow_1", "steps": []}', - "flow2": '{"name": "test_flow_2", "steps": []}' - } - } - ) - - await config_receiver.on_config(mock_msg, None, None) - - # Verify flows were added - assert "flow1" in config_receiver.flows - assert "flow2" in config_receiver.flows - assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []} - assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []} - - # Verify start_flow was called for each new flow - assert len(start_flow_calls) == 2 - assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls - assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls + mock_msg.value.return_value = Mock(version=2, types=["flow"]) + + await config_receiver.on_config_notify(mock_msg, None, None) + + assert len(fetch_calls) == 1 @pytest.mark.asyncio - async def test_on_config_with_removed_flows(self): - """Test on_config method with removed flows""" + async def test_on_config_notify_old_version_ignored(self): + """Test on_config_notify ignores older versions""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Pre-populate with existing flows - config_receiver.flows = { - "flow1": {"name": "test_flow_1", "steps": []}, - "flow2": {"name": "test_flow_2", "steps": []} - } - - # Track calls manually instead of using AsyncMock - stop_flow_calls = [] - - async def mock_stop_flow(*args): - stop_flow_calls.append(args) - - config_receiver.stop_flow = mock_stop_flow - - # Create mock message with only flow1 (flow2 removed) + config_receiver.config_version = 5 + + fetch_calls = [] + async def mock_fetch(**kwargs): + fetch_calls.append(kwargs) + config_receiver.fetch_and_apply = mock_fetch + + # Create notify message with older version mock_msg = Mock() - mock_msg.value.return_value = Mock( - version="1.0", - config={ - "flow": { - "flow1": '{"name": "test_flow_1", "steps": []}' - } - } - ) - - await config_receiver.on_config(mock_msg, None, 
None) - - # Verify flow2 was removed - assert "flow1" in config_receiver.flows - assert "flow2" not in config_receiver.flows - - # Verify stop_flow was called for removed flow - assert len(stop_flow_calls) == 1 - assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []}) + mock_msg.value.return_value = Mock(version=3, types=["flow"]) + + await config_receiver.on_config_notify(mock_msg, None, None) + + assert len(fetch_calls) == 0 @pytest.mark.asyncio - async def test_on_config_with_no_flows(self): - """Test on_config method with no flows in config""" + async def test_on_config_notify_irrelevant_types_ignored(self): + """Test on_config_notify ignores types the gateway doesn't care about""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Mock the start_flow and stop_flow methods with async functions - async def mock_start_flow(*args): - pass - async def mock_stop_flow(*args): - pass - config_receiver.start_flow = mock_start_flow - config_receiver.stop_flow = mock_stop_flow - - # Create mock message without flows + config_receiver.config_version = 1 + + fetch_calls = [] + async def mock_fetch(**kwargs): + fetch_calls.append(kwargs) + config_receiver.fetch_and_apply = mock_fetch + + # Create notify message with non-flow type mock_msg = Mock() - mock_msg.value.return_value = Mock( - version="1.0", - config={} - ) - - await config_receiver.on_config(mock_msg, None, None) - - # Verify no flows were added - assert config_receiver.flows == {} - - # Since no flows were in the config, the flow methods shouldn't be called - # (We can't easily assert this with simple async functions, but the test - # passes if no exceptions are thrown) + mock_msg.value.return_value = Mock(version=2, types=["prompt"]) + + await config_receiver.on_config_notify(mock_msg, None, None) + + # Version should be updated but no fetch + assert len(fetch_calls) == 0 + assert config_receiver.config_version == 2 @pytest.mark.asyncio - async def test_on_config_exception_handling(self): - """Test on_config method handles exceptions gracefully""" + async def test_on_config_notify_flow_type_triggers_fetch(self): + """Test on_config_notify fetches for flow-related types""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Create mock message that will cause an exception + config_receiver.config_version = 1 + + fetch_calls = [] + async def mock_fetch(**kwargs): + fetch_calls.append(kwargs) + config_receiver.fetch_and_apply = mock_fetch + + for type_name in ["flow", "active-flow"]: + fetch_calls.clear() + config_receiver.config_version = 1 + + mock_msg = Mock() + mock_msg.value.return_value = Mock(version=2, types=[type_name]) + + await config_receiver.on_config_notify(mock_msg, None, None) + + assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}" + + @pytest.mark.asyncio + async def test_on_config_notify_exception_handling(self): + """Test on_config_notify handles exceptions gracefully""" + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) + + # Create notify message that causes an exception mock_msg = Mock() mock_msg.value.side_effect = Exception("Test exception") - - # This should not raise an exception - await config_receiver.on_config(mock_msg, None, None) - - # Verify flows remain empty + + # Should not raise + await config_receiver.on_config_notify(mock_msg, None, None) + + @pytest.mark.asyncio + async def test_fetch_and_apply_with_new_flows(self): + """Test fetch_and_apply starts new flows""" + mock_backend = Mock() + 
config_receiver = ConfigReceiver(mock_backend) + + # Mock config_client + mock_resp = Mock() + mock_resp.error = None + mock_resp.version = 5 + mock_resp.config = { + "flow": { + "flow1": '{"name": "test_flow_1"}', + "flow2": '{"name": "test_flow_2"}' + } + } + + mock_client = AsyncMock() + mock_client.request.return_value = mock_resp + config_receiver.config_client = mock_client + + start_flow_calls = [] + async def mock_start_flow(id, flow): + start_flow_calls.append((id, flow)) + config_receiver.start_flow = mock_start_flow + + await config_receiver.fetch_and_apply() + + assert config_receiver.config_version == 5 + assert "flow1" in config_receiver.flows + assert "flow2" in config_receiver.flows + assert len(start_flow_calls) == 2 + + @pytest.mark.asyncio + async def test_fetch_and_apply_with_removed_flows(self): + """Test fetch_and_apply stops removed flows""" + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) + + # Pre-populate with existing flows + config_receiver.flows = { + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"} + } + + # Config now only has flow1 + mock_resp = Mock() + mock_resp.error = None + mock_resp.version = 5 + mock_resp.config = { + "flow": { + "flow1": '{"name": "test_flow_1"}' + } + } + + mock_client = AsyncMock() + mock_client.request.return_value = mock_resp + config_receiver.config_client = mock_client + + stop_flow_calls = [] + async def mock_stop_flow(id, flow): + stop_flow_calls.append((id, flow)) + config_receiver.stop_flow = mock_stop_flow + + await config_receiver.fetch_and_apply() + + assert "flow1" in config_receiver.flows + assert "flow2" not in config_receiver.flows + assert len(stop_flow_calls) == 1 + assert stop_flow_calls[0][0] == "flow2" + + @pytest.mark.asyncio + async def test_fetch_and_apply_with_no_flows(self): + """Test fetch_and_apply with empty config""" + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) + + mock_resp = Mock() + mock_resp.error = None + mock_resp.version = 1 + mock_resp.config = {} + + mock_client = AsyncMock() + mock_client.request.return_value = mock_resp + config_receiver.config_client = mock_client + + await config_receiver.fetch_and_apply() + assert config_receiver.flows == {} + assert config_receiver.config_version == 1 @pytest.mark.asyncio async def test_start_flow_with_handlers(self): """Test start_flow method with multiple handlers""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Add mock handlers + handler1 = Mock() handler1.start_flow = Mock() handler2 = Mock() handler2.start_flow = Mock() - + config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) - + flow_data = {"name": "test_flow", "steps": []} - + await config_receiver.start_flow("flow1", flow_data) - - # Verify all handlers were called + handler1.start_flow.assert_called_once_with("flow1", flow_data) handler2.start_flow.assert_called_once_with("flow1", flow_data) @@ -199,19 +262,17 @@ class TestConfigReceiver: """Test start_flow method handles handler exceptions""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Add mock handler that raises exception + handler = Mock() handler.start_flow = Mock(side_effect=Exception("Handler error")) - + config_receiver.add_handler(handler) - + flow_data = {"name": "test_flow", "steps": []} - - # This should not raise an exception + + # Should not raise await config_receiver.start_flow("flow1", flow_data) - - # Verify handler was called + 
handler.start_flow.assert_called_once_with("flow1", flow_data) @pytest.mark.asyncio @@ -219,21 +280,19 @@ class TestConfigReceiver: """Test stop_flow method with multiple handlers""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Add mock handlers + handler1 = Mock() handler1.stop_flow = Mock() handler2 = Mock() handler2.stop_flow = Mock() - + config_receiver.add_handler(handler1) config_receiver.add_handler(handler2) - + flow_data = {"name": "test_flow", "steps": []} - + await config_receiver.stop_flow("flow1", flow_data) - - # Verify all handlers were called + handler1.stop_flow.assert_called_once_with("flow1", flow_data) handler2.stop_flow.assert_called_once_with("flow1", flow_data) @@ -242,167 +301,77 @@ class TestConfigReceiver: """Test stop_flow method handles handler exceptions""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Add mock handler that raises exception + handler = Mock() handler.stop_flow = Mock(side_effect=Exception("Handler error")) - + config_receiver.add_handler(handler) - + flow_data = {"name": "test_flow", "steps": []} - - # This should not raise an exception + + # Should not raise await config_receiver.stop_flow("flow1", flow_data) - - # Verify handler was called + handler.stop_flow.assert_called_once_with("flow1", flow_data) - @pytest.mark.asyncio - async def test_config_loader_creates_consumer(self): - """Test config_loader method creates Pulsar consumer""" - mock_backend = Mock() - - config_receiver = ConfigReceiver(mock_backend) - # Temporarily restore the real config_loader for this test - config_receiver.config_loader = _real_config_loader.__get__(config_receiver) - - # Mock Consumer class - with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \ - patch('uuid.uuid4') as mock_uuid: - - mock_uuid.return_value = "test-uuid" - mock_consumer = Mock() - async def mock_start(): - pass - mock_consumer.start = mock_start - mock_consumer_class.return_value = mock_consumer - - # Create a task that will complete quickly - async def quick_task(): - await config_receiver.config_loader() - - # Run the task with a timeout to prevent hanging - try: - await asyncio.wait_for(quick_task(), timeout=0.1) - except asyncio.TimeoutError: - # This is expected since the method runs indefinitely - pass - - # Verify Consumer was created with correct parameters - mock_consumer_class.assert_called_once() - call_args = mock_consumer_class.call_args - - assert call_args[1]['backend'] == mock_backend - assert call_args[1]['subscriber'] == "gateway-test-uuid" - assert call_args[1]['handler'] == config_receiver.on_config - assert call_args[1]['start_of_messages'] is True - @patch('asyncio.create_task') @pytest.mark.asyncio async def test_start_creates_config_loader_task(self, mock_create_task): """Test start method creates config loader task""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Mock create_task to avoid actually creating tasks with real coroutines + mock_task = Mock() mock_create_task.return_value = mock_task - + await config_receiver.start() - - # Verify task was created + mock_create_task.assert_called_once() - - # Verify the argument passed to create_task is a coroutine - call_args = mock_create_task.call_args[0] - assert len(call_args) == 1 # Should have one argument (the coroutine) @pytest.mark.asyncio - async def test_on_config_mixed_flow_operations(self): - """Test on_config with mixed add/remove operations""" + async def 
test_fetch_and_apply_mixed_flow_operations(self): + """Test fetch_and_apply with mixed add/remove operations""" mock_backend = Mock() config_receiver = ConfigReceiver(mock_backend) - - # Pre-populate with existing flows + + # Pre-populate config_receiver.flows = { - "flow1": {"name": "test_flow_1", "steps": []}, - "flow2": {"name": "test_flow_2", "steps": []} + "flow1": {"name": "test_flow_1"}, + "flow2": {"name": "test_flow_2"} } - - # Track calls manually instead of using Mock - start_flow_calls = [] - stop_flow_calls = [] - - async def mock_start_flow(*args): - start_flow_calls.append(args) - - async def mock_stop_flow(*args): - stop_flow_calls.append(args) - - # Directly assign to avoid patch.object detecting async methods - original_start_flow = config_receiver.start_flow - original_stop_flow = config_receiver.stop_flow + + # Config removes flow1, keeps flow2, adds flow3 + mock_resp = Mock() + mock_resp.error = None + mock_resp.version = 5 + mock_resp.config = { + "flow": { + "flow2": '{"name": "test_flow_2"}', + "flow3": '{"name": "test_flow_3"}' + } + } + + mock_client = AsyncMock() + mock_client.request.return_value = mock_resp + config_receiver.config_client = mock_client + + start_calls = [] + stop_calls = [] + + async def mock_start_flow(id, flow): + start_calls.append((id, flow)) + async def mock_stop_flow(id, flow): + stop_calls.append((id, flow)) + config_receiver.start_flow = mock_start_flow config_receiver.stop_flow = mock_stop_flow - - try: - - # Create mock message with flow1 removed and flow3 added - mock_msg = Mock() - mock_msg.value.return_value = Mock( - version="1.0", - config={ - "flow": { - "flow2": '{"name": "test_flow_2", "steps": []}', - "flow3": '{"name": "test_flow_3", "steps": []}' - } - } - ) - - await config_receiver.on_config(mock_msg, None, None) - - # Verify final state - assert "flow1" not in config_receiver.flows - assert "flow2" in config_receiver.flows - assert "flow3" in config_receiver.flows - - # Verify operations - assert len(start_flow_calls) == 1 - assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []}) - assert len(stop_flow_calls) == 1 - assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []}) - - finally: - # Restore original methods - config_receiver.start_flow = original_start_flow - config_receiver.stop_flow = original_stop_flow - @pytest.mark.asyncio - async def test_on_config_invalid_json_flow_data(self): - """Test on_config handles invalid JSON in flow data""" - mock_backend = Mock() - config_receiver = ConfigReceiver(mock_backend) - - # Mock the start_flow method with an async function - async def mock_start_flow(*args): - pass - config_receiver.start_flow = mock_start_flow - - # Create mock message with invalid JSON - mock_msg = Mock() - mock_msg.value.return_value = Mock( - version="1.0", - config={ - "flow": { - "flow1": '{"invalid": json}', # Invalid JSON - "flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON - } - } - ) - - # This should handle the exception gracefully - await config_receiver.on_config(mock_msg, None, None) - - # The entire operation should fail due to JSON parsing error - # So no flows should be added - assert config_receiver.flows == {} \ No newline at end of file + await config_receiver.fetch_and_apply() + + assert "flow1" not in config_receiver.flows + assert "flow2" in config_receiver.flows + assert "flow3" in config_receiver.flows + assert len(start_calls) == 1 + assert start_calls[0][0] == "flow3" + assert len(stop_calls) == 1 + assert stop_calls[0][0] == 
"flow1" diff --git a/tests/unit/test_pubsub/test_queue_naming.py b/tests/unit/test_pubsub/test_queue_naming.py index edd3dfca..8ab09e5a 100644 --- a/tests/unit/test_pubsub/test_queue_naming.py +++ b/tests/unit/test_pubsub/test_queue_naming.py @@ -153,7 +153,7 @@ class TestQueueDefinitions: def test_config_push(self): from trustgraph.schema.services.config import config_push_queue - assert config_push_queue == 'state:tg:config' + assert config_push_queue == 'flow:tg:config' def test_librarian_request(self): from trustgraph.schema.services.library import librarian_request_queue diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 7f7dbdcd..d59bff59 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -1,7 +1,8 @@ # Base class for processors. Implements: -# - Pulsar client, subscribe and consume basic +# - Pub/sub client, subscribe and consume basic # - the async startup logic +# - Config notify handling with subscribe-then-fetch pattern # - Initialising metrics import asyncio @@ -12,12 +13,17 @@ import logging import os from prometheus_client import start_http_server, Info -from .. schema import ConfigPush, config_push_queue +from .. schema import ConfigPush, ConfigRequest, ConfigResponse +from .. schema import config_push_queue, config_request_queue +from .. schema import config_response_queue from .. log_level import LogLevel from . pubsub import get_pubsub, add_pubsub_args from . producer import Producer from . consumer import Consumer -from . metrics import ProcessorMetrics, ConsumerMetrics +from . subscriber import Subscriber +from . request_response_spec import RequestResponse +from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics +from . metrics import SubscriberMetrics from . logging import add_logging_args, setup_logging default_config_queue = config_push_queue @@ -57,9 +63,13 @@ class AsyncProcessor: "config_push_queue", default_config_queue ) - # This records registered configuration handlers + # This records registered configuration handlers, each entry is: + # { "handler": async_fn, "types": set_or_none } self.config_handlers = [] + # Track the current config version for dedup + self.config_version = 0 + # Create a random ID for this subscription to the configuration # service config_subscriber_id = str(uuid.uuid4()) @@ -68,8 +78,7 @@ class AsyncProcessor: processor = self.id, flow = None, name = "config", ) - # Subscribe to config queue — exclusive so every processor - # gets its own copy of config pushes (broadcast pattern) + # Subscribe to config notify queue self.config_sub_task = Consumer( taskgroup = self.taskgroup, @@ -80,21 +89,93 @@ class AsyncProcessor: topic = self.config_push_queue, schema = ConfigPush, - handler = self.on_config_change, + handler = self.on_config_notify, metrics = config_consumer_metrics, - start_of_messages = True, + start_of_messages = False, consumer_type = 'exclusive', ) self.running = True - # This is called to start dynamic behaviour. 
An over-ride point for - # extra functionality + def _create_config_client(self): + """Create a short-lived config request/response client.""" + config_rr_id = str(uuid.uuid4()) + + config_req_metrics = ProducerMetrics( + processor = self.id, flow = None, name = "config-request", + ) + config_resp_metrics = SubscriberMetrics( + processor = self.id, flow = None, name = "config-response", + ) + + return RequestResponse( + backend = self.pubsub_backend, + subscription = f"{self.id}--config--{config_rr_id}", + consumer_name = self.id, + request_topic = config_request_queue, + request_schema = ConfigRequest, + request_metrics = config_req_metrics, + response_topic = config_response_queue, + response_schema = ConfigResponse, + response_metrics = config_resp_metrics, + ) + + async def fetch_config(self): + """Fetch full config from config service using a short-lived + request/response client. Returns (config, version) or raises.""" + client = self._create_config_client() + try: + await client.start() + resp = await client.request( + ConfigRequest(operation="config"), + timeout=10, + ) + if resp.error: + raise RuntimeError(f"Config error: {resp.error.message}") + return resp.config, resp.version + finally: + await client.stop() + + # This is called to start dynamic behaviour. + # Implements the subscribe-then-fetch pattern to avoid race conditions. async def start(self): + + # 1. Start the notify consumer (begins buffering incoming notifys) await self.config_sub_task.start() + # 2. Fetch current config via request/response + await self.fetch_and_apply_config() + + # 3. Any buffered notifys with version > fetched version will be + # processed by on_config_notify, which does the version check + + async def fetch_and_apply_config(self): + """Fetch full config from config service and apply to all handlers. + Retries until successful — config service may not be ready yet.""" + + while self.running: + + try: + config, version = await self.fetch_config() + + logger.info(f"Fetched config version {version}") + + self.config_version = version + + # Apply to all handlers (startup = invoke all) + for entry in self.config_handlers: + await entry["handler"](config, version) + + return + + except Exception as e: + logger.warning( + f"Config fetch failed: {e}, retrying in 2s..." + ) + await asyncio.sleep(2) + # This is called to stop all threads. 
An over-ride point for extra # functionality def stop(self): @@ -110,20 +191,66 @@ class AsyncProcessor: def pulsar_host(self): return self._pulsar_host # Register a new event handler for configuration change - def register_config_handler(self, handler): - self.config_handlers.append(handler) + def register_config_handler(self, handler, types=None): + self.config_handlers.append({ + "handler": handler, + "types": set(types) if types else None, + }) - # Called when a new configuration message push occurs - async def on_config_change(self, message, consumer, flow): + # Called when a config notify message arrives + async def on_config_notify(self, message, consumer, flow): - # Get configuration data and version number - config = message.value().config - version = message.value().version + notify_version = message.value().version + notify_types = set(message.value().types) - # Invoke message handlers - logger.info(f"Config change event: version={version}") - for ch in self.config_handlers: - await ch(config, version) + # Skip if we already have this version or newer + if notify_version <= self.config_version: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"already at v{self.config_version}" + ) + return + + # Check if any handler cares about the affected types + if notify_types: + any_interested = False + for entry in self.config_handlers: + handler_types = entry["types"] + if handler_types is None or notify_types & handler_types: + any_interested = True + break + + if not any_interested: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"no handlers for types {notify_types}" + ) + self.config_version = notify_version + return + + logger.info( + f"Config notify v{notify_version} types={list(notify_types)}, " + f"fetching config..." + ) + + # Fetch full config using short-lived client + try: + config, version = await self.fetch_config() + + self.config_version = version + + # Invoke handlers that care about the affected types + for entry in self.config_handlers: + handler_types = entry["types"] + if handler_types is None: + await entry["handler"](config, version) + elif not notify_types or notify_types & handler_types: + await entry["handler"](config, version) + + except Exception as e: + logger.error( + f"Failed to fetch config on notify: {e}", exc_info=True + ) # This is the 'main' body of the handler. It is a point to override # if needed. By default does nothing. 
Processors are implemented @@ -181,7 +308,7 @@ class AsyncProcessor: prog=ident, description=doc ) - + parser.add_argument( '--id', default=ident, @@ -271,4 +398,3 @@ class AsyncProcessor: default=8000, help=f'Pulsar host (default: 8000)', ) - diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py index 1caeaec0..4579a8c2 100644 --- a/trustgraph-base/trustgraph/base/flow_processor.py +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -26,7 +26,9 @@ class FlowProcessor(AsyncProcessor): super(FlowProcessor, self).__init__(**params) # Register configuration handler - self.register_config_handler(self.on_configure_flows) + self.register_config_handler( + self.on_configure_flows, types=["active-flow"] + ) # Initialise flow information state self.flows = {} diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index 36e55674..fb219bd9 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -58,11 +58,11 @@ class ConfigResponse: @dataclass class ConfigPush: version: int = 0 - config: dict[str, dict[str, str]] = field(default_factory=dict) + types: list[str] = field(default_factory=list) config_request_queue = queue('config', cls='request') config_response_queue = queue('config', cls='response') -config_push_queue = queue('config', cls='state') +config_push_queue = queue('config', cls='flow') ############################################################################ diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index 23789b96..0bc5d7e3 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -24,7 +24,7 @@ class Service(ToolService): **params ) - self.register_config_handler(self.on_mcp_config) + self.register_config_handler(self.on_mcp_config, types=["mcp-tool"]) self.mcp_services = {} diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py index c0a5be1e..6c897f6b 100644 --- a/trustgraph-flow/trustgraph/config/service/config.py +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -148,18 +148,7 @@ class Configuration: async def handle_delete(self, v): - # for k in v.keys: - # if k.type not in self or k.key not in self[k.type]: - # return ConfigResponse( - # version = None, - # values = None, - # directory = None, - # config = None, - # error = Error( - # type = "key-error", - # message = f"Key error" - # ) - # ) + types = list(set(k.type for k in v.keys)) for k in v.keys: @@ -167,20 +156,22 @@ class Configuration: await self.inc_version() - await self.push() + await self.push(types=types) return ConfigResponse( ) async def handle_put(self, v): + types = list(set(k.type for k in v.values)) + for k in v.values: await self.table_store.put_config(k.type, k.key, k.value) await self.inc_version() - await self.push() + await self.push(types=types) return ConfigResponse( ) diff --git a/trustgraph-flow/trustgraph/config/service/flow.py b/trustgraph-flow/trustgraph/config/service/flow.py index ab02fa30..775c8b4e 100644 --- a/trustgraph-flow/trustgraph/config/service/flow.py +++ b/trustgraph-flow/trustgraph/config/service/flow.py @@ -126,12 +126,12 @@ class FlowConfig: await self.config.inc_version() - await self.config.push() + await self.config.push(types=["flow-blueprint"]) return 
FlowResponse( error = None, ) - + async def handle_delete_blueprint(self, msg): logger.debug(f"Flow config message: {msg}") @@ -140,7 +140,7 @@ class FlowConfig: await self.config.inc_version() - await self.config.push() + await self.config.push(types=["flow-blueprint"]) return FlowResponse( error = None, @@ -270,7 +270,7 @@ class FlowConfig: await self.config.inc_version() - await self.config.push() + await self.config.push(types=["active-flow", "flow"]) return FlowResponse( error = None, @@ -332,12 +332,12 @@ class FlowConfig: await self.config.inc_version() - await self.config.push() + await self.config.push(types=["active-flow", "flow"]) return FlowResponse( error = None, ) - + async def handle(self, msg): logger.debug(f"Handling flow message: {msg.operation}") diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index 42b256df..5c235bb2 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -167,25 +167,22 @@ class Processor(AsyncProcessor): async def start(self): - await self.push() + await self.push() # Startup poke: empty types = everything await self.config_request_consumer.start() await self.flow_request_consumer.start() - - async def push(self): - config = await self.config.get_config() + async def push(self, types=None): + version = await self.config.get_version() resp = ConfigPush( version = version, - config = config, + types = types or [], ) await self.config_push_producer.send(resp) - # Race condition, should make sure version & config sync - - logger.info(f"Pushed configuration version {await self.config.get_version()}") + logger.info(f"Pushed config poke version {version}, types={resp.types}") async def on_config_request(self, msg, consumer, flow): diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index b8cc5f9e..3bc4e9b6 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -108,7 +108,7 @@ class Processor(AsyncProcessor): flow_config = self, ) - self.register_config_handler(self.on_knowledge_config) + self.register_config_handler(self.on_knowledge_config, types=["kg-core"]) self.flows = {} diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index 1365cb14..362bdec9 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -66,8 +66,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): ) # Register config handlers - self.register_config_handler(self.on_schema_config) - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_schema_config, types=["schema"]) + self.register_config_handler(self.on_collection_config, types=["collection"]) # Schema storage: name -> RowSchema self.schemas: Dict[str, RowSchema] = {} diff --git a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py index 5ce343c6..ce8d6aae 100644 --- a/trustgraph-flow/trustgraph/extract/kg/agent/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/agent/extract.py @@ -43,7 +43,7 @@ class Processor(FlowProcessor): self.template_id = template_id self.config_key = config_key - self.register_config_handler(self.on_prompt_config) + 
self.register_config_handler(self.on_prompt_config, types=["prompt"]) self.register_specification( ConsumerSpec( diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 5078d817..29808cae 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -107,7 +107,7 @@ class Processor(FlowProcessor): ) # Register config handler for ontology updates - self.register_config_handler(self.on_ontology_config) + self.register_config_handler(self.on_ontology_config, types=["ontology"]) # Shared components (not flow-specific) self.ontology_loader = OntologyLoader() diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 02aa7d78..8fd494b0 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -82,7 +82,7 @@ class Processor(FlowProcessor): ) # Register config handler for schema updates - self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_schema_config, types=["schema"]) # Schema storage: name -> RowSchema self.schemas: Dict[str, RowSchema] = {} diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index d956c7c6..97f4e7eb 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -1,36 +1,27 @@ """ -API gateway. Offers HTTP services which are translated to interaction on the -Pulsar bus. +API gateway config receiver. Subscribes to config notify notifications and +fetches full config via request/response to manage flow lifecycle. """ module = "api-gateway" -# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there -# are active listeners - -# FIXME: Connection errors in publishers / subscribers cause those threads -# to fail and are not failed or retried - import asyncio -import argparse -from aiohttp import web -import logging -import os -import base64 import uuid - -# Module logger -logger = logging.getLogger(__name__) +import logging import json -from prometheus_client import start_http_server - -from ... schema import ConfigPush, config_push_queue -from ... base import Consumer +from ... schema import ConfigPush, ConfigRequest, ConfigResponse +from ... schema import config_push_queue, config_request_queue +from ... schema import config_response_queue +from ... base import Consumer, Producer +from ... base.subscriber import Subscriber +from ... base.request_response_spec import RequestResponse +from ... 
base.metrics import ProducerMetrics, SubscriberMetrics logger = logging.getLogger("config.receiver") logger.setLevel(logging.INFO) + class ConfigReceiver: def __init__(self, backend): @@ -41,34 +32,107 @@ class ConfigReceiver: self.flows = {} + self.config_version = 0 + def add_handler(self, h): self.flow_handlers.append(h) - async def on_config(self, msg, proc, flow): + async def on_config_notify(self, msg, proc, flow): try: v = msg.value() + notify_version = v.version + notify_types = set(v.types) - logger.info(f"Config version: {v.version}") + # Skip if we already have this version or newer + if notify_version <= self.config_version: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"already at v{self.config_version}" + ) + return - flows = v.config.get("flow", {}) + # Gateway cares about flow config + if notify_types and "flow" not in notify_types and "active-flow" not in notify_types: + logger.debug( + f"Ignoring config notify v{notify_version}, " + f"no flow types in {notify_types}" + ) + self.config_version = notify_version + return - wanted = list(flows.keys()) - current = list(self.flows.keys()) + logger.info( + f"Config notify v{notify_version}, fetching config..." + ) - for k in wanted: - if k not in current: - self.flows[k] = json.loads(flows[k]) - await self.start_flow(k, self.flows[k]) - - for k in current: - if k not in wanted: - await self.stop_flow(k, self.flows[k]) - del self.flows[k] + await self.fetch_and_apply() except Exception as e: - logger.error(f"Config processing exception: {e}", exc_info=True) + logger.error( + f"Config notify processing exception: {e}", exc_info=True + ) + + async def fetch_and_apply(self, retry=False): + """Fetch full config and apply flow changes. + If retry=True, keeps retrying until successful.""" + + while True: + + try: + logger.info("Fetching config from config service...") + + resp = await self.config_client.request( + ConfigRequest(operation="config"), + timeout=10, + ) + + logger.info(f"Config response received") + + if resp.error: + if retry: + logger.warning( + f"Config fetch error: {resp.error.message}, " + f"retrying in 2s..." + ) + await asyncio.sleep(2) + continue + logger.error( + f"Config fetch error: {resp.error.message}" + ) + return + + self.config_version = resp.version + config = resp.config + + flows = config.get("flow", {}) + + wanted = list(flows.keys()) + current = list(self.flows.keys()) + + for k in wanted: + if k not in current: + self.flows[k] = json.loads(flows[k]) + await self.start_flow(k, self.flows[k]) + + for k in current: + if k not in wanted: + await self.stop_flow(k, self.flows[k]) + del self.flows[k] + + return + + except Exception as e: + if retry: + logger.warning( + f"Config fetch failed: {e}, retrying in 2s..." 
+ ) + await asyncio.sleep(2) + continue + logger.error( + f"Config fetch exception: {e}", exc_info=True + ) + return async def start_flow(self, id, flow): @@ -79,7 +143,9 @@ class ConfigReceiver: try: await handler.start_flow(id, flow) except Exception as e: - logger.error(f"Config processing exception: {e}", exc_info=True) + logger.error( + f"Config processing exception: {e}", exc_info=True + ) async def stop_flow(self, id, flow): @@ -90,32 +156,80 @@ class ConfigReceiver: try: await handler.stop_flow(id, flow) except Exception as e: - logger.error(f"Config processing exception: {e}", exc_info=True) + logger.error( + f"Config processing exception: {e}", exc_info=True + ) async def config_loader(self): - async with asyncio.TaskGroup() as tg: + while True: - id = str(uuid.uuid4()) + try: - self.config_cons = Consumer( - taskgroup = tg, - flow = None, - backend = self.backend, - subscriber = f"gateway-{id}", - topic = config_push_queue, - schema = ConfigPush, - handler = self.on_config, - start_of_messages = True, - ) + async with asyncio.TaskGroup() as tg: - await self.config_cons.start() + id = str(uuid.uuid4()) - logger.debug("Waiting for config updates...") + # Config request/response client + config_req_metrics = ProducerMetrics( + processor="api-gateway", flow=None, + name="config-request", + ) + config_resp_metrics = SubscriberMetrics( + processor="api-gateway", flow=None, + name="config-response", + ) - logger.info("Config consumer finished") + self.config_client = RequestResponse( + backend=self.backend, + subscription=f"api-gateway--config--{id}", + consumer_name="api-gateway", + request_topic=config_request_queue, + request_schema=ConfigRequest, + request_metrics=config_req_metrics, + response_topic=config_response_queue, + response_schema=ConfigResponse, + response_metrics=config_resp_metrics, + ) + + logger.info("Starting config request/response client...") + await self.config_client.start() + logger.info("Config request/response client started") + + # Subscribe to notify queue + self.config_cons = Consumer( + taskgroup=tg, + flow=None, + backend=self.backend, + subscriber=f"gateway-{id}", + topic=config_push_queue, + schema=ConfigPush, + handler=self.on_config_notify, + start_of_messages=False, + ) + + logger.info("Starting config notify consumer...") + await self.config_cons.start() + logger.info("Config notify consumer started") + + # Fetch current config (subscribe-then-fetch pattern) + # Retry until config service is available + await self.fetch_and_apply(retry=True) + + logger.info( + "Config loader initialised, waiting for notifys..." 
+ ) + + logger.warning("Config consumer exited, restarting...") + + except Exception as e: + logger.error( + f"Config loader exception: {e}, restarting in 4s...", + exc_info=True + ) + + await asyncio.sleep(4) async def start(self): - - asyncio.create_task(self.config_loader()) + asyncio.create_task(self.config_loader()) diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index 4f8f5465..15cc97fa 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -246,7 +246,7 @@ class Processor(AsyncProcessor): taskgroup = self.taskgroup, ) - self.register_config_handler(self.on_librarian_config) + self.register_config_handler(self.on_librarian_config, types=["librarian"]) self.flows = {} diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 7851232a..f120a812 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -40,7 +40,7 @@ class Processor(FlowProcessor): } ) - self.register_config_handler(self.on_cost_config) + self.register_config_handler(self.on_cost_config, types=["token-costs"]) self.register_specification( ConsumerSpec( diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index 5fc177d5..97298e13 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -65,7 +65,7 @@ class Processor(FlowProcessor): ) ) - self.register_config_handler(self.on_prompt_config) + self.register_config_handler(self.on_prompt_config, types=["prompt"]) # Null configuration, should reload quickly self.manager = PromptManager() diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 2337642f..f928a911 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -84,7 +84,7 @@ class Processor(FlowProcessor): ) # Register config handler for schema updates - self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_schema_config, types=["schema"]) # Schema storage: name -> RowSchema self.schemas: Dict[str, RowSchema] = {} diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py index 04dae978..b567cc7b 100644 --- a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py @@ -64,7 +64,7 @@ class Processor(FlowProcessor): ) # Register config handler for schema updates - self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_schema_config, types=["schema"]) # Schema storage: name -> RowSchema self.schemas: Dict[str, RowSchema] = {} diff --git a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py index d69c8f17..b878bf61 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_diag/service.py @@ -70,7 +70,7 @@ class Processor(FlowProcessor): ) # Register config handler for schema updates - self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_schema_config, types=["schema"]) # Schema storage: name 
-> RowSchema self.schemas: Dict[str, RowSchema] = {} diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index e282f876..f5c12441 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -31,7 +31,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): self.vecstore = DocVectors(store_uri) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) async def store_document_embeddings(self, message): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index ea091d35..31a70f23 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -58,7 +58,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): self.last_index_name = None # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) def create_index(self, index_name, dim): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index a87f2128..e5e7e705 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -37,7 +37,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): self.qdrant = QdrantClient(url=store_uri, api_key=api_key) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) async def store_document_embeddings(self, message): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 0f27adf9..9346c948 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -45,7 +45,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): self.vecstore = EntityVectors(store_uri) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) async def store_graph_embeddings(self, message): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index d907e873..6a95a38d 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -72,7 +72,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): self.last_index_name = None # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) def create_index(self, index_name, dim): diff --git 
a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index f887d487..9a7672f8 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -52,7 +52,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): self.qdrant = QdrantClient(url=store_uri, api_key=api_key) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) async def store_graph_embeddings(self, message): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 42e59012..a6ec4ff7 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -61,7 +61,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): ) # Register config handler for collection management - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) # Cache of created Qdrant collections self.created_collections: Set[str] = set() diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index d15916b6..673cba4d 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -75,8 +75,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): ) # Register config handlers - self.register_config_handler(self.on_schema_config) - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_schema_config, types=["schema"]) + self.register_config_handler(self.on_collection_config, types=["collection"]) # Cache of known keyspaces and whether tables exist self.known_keyspaces: Set[str] = set() diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index d31d6223..2a240f0b 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -144,7 +144,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.tg = None # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) async def store_triples(self, message): diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index ac8d05c4..86f9a6e3 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -57,7 +57,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.io = FalkorDB.from_url(graph_url).select_graph(database) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) def create_node(self, uri, user, collection): diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py 
index 7864ac80..16a7d3ed 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -66,7 +66,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.create_indexes(session) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) def create_indexes(self, session): diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index 3db712fb..f7b2d947 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -66,7 +66,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.create_indexes(session) # Register for config push notifications - self.register_config_handler(self.on_collection_config) + self.register_config_handler(self.on_collection_config, types=["collection"]) def create_indexes(self, session): From f0c9039b7606823132b3b23ed11190702e3da6ec Mon Sep 17 00:00:00 2001 From: Sreeram Venkatasubramanian Date: Tue, 7 Apr 2026 14:19:59 +0530 Subject: [PATCH 33/37] fix: reduce consumer poll timeout from 2000ms to 100ms --- .../test_consumer_concurrency.py | 35 +++++++++++++++++++ trustgraph-base/trustgraph/base/consumer.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 3869aaf3..03244b73 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -266,6 +266,41 @@ class TestMetricsIntegration: mock_metrics.rate_limit.assert_called_once() +# --------------------------------------------------------------------------- +# Poll timeout +# --------------------------------------------------------------------------- + +class TestPollTimeout: + + @pytest.mark.asyncio + async def test_poll_timeout_is_100ms(self): + """Consumer receive timeout should be 100ms, not the original 2000ms. + + A 2000ms poll timeout means every service adds up to 2s of idle + blocking between message bursts. With many sequential hops in a + query pipeline, this compounds into seconds of unnecessary latency. + 100ms keeps responsiveness high without significant CPU overhead. 
+ """ + consumer = _make_consumer() + + # Wire up a mock Pulsar consumer that records the receive kwargs + mock_pulsar_consumer = MagicMock() + received_kwargs = {} + + def capture_receive(**kwargs): + received_kwargs.update(kwargs) + # Stop after one call + consumer.running = False + raise type('Timeout', (Exception,), {})("timeout") + + mock_pulsar_consumer.receive = capture_receive + consumer.consumer = mock_pulsar_consumer + + await consumer.consume_from_queue() + + assert received_kwargs.get("timeout_millis") == 100 + + # --------------------------------------------------------------------------- # Stop / running flag # --------------------------------------------------------------------------- diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 9ae35d49..4f8c9de5 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -165,7 +165,7 @@ class Consumer: try: msg = await asyncio.to_thread( consumer.receive, - timeout_millis=2000 + timeout_millis=100 ) except Exception as e: # Handle timeout from any backend From 2f8d6a3ffba2c003174e9c67dc6f1374a0921368 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 7 Apr 2026 12:11:12 +0100 Subject: [PATCH 34/37] Fix agent config handler registration, remove debug prints, disable RabbitMQ heartbeats (#764) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix agent react and orchestrator services appending bare methods to config_handlers instead of using register_config_handler() — caused 'method object is not subscriptable' on config notify - Add exc_info to config fetch retry logging for proper tracebacks - Remove debug print statements from collection management dispatcher and translator - Disable RabbitMQ heartbeats (heartbeat=0) to prevent broker closing idle producer connections that can't process heartbeat frames from BlockingConnection --- trustgraph-base/trustgraph/base/async_processor.py | 3 ++- trustgraph-base/trustgraph/base/rabbitmq_backend.py | 1 + .../trustgraph/messaging/translators/collection.py | 3 --- trustgraph-flow/trustgraph/agent/orchestrator/service.py | 4 +++- trustgraph-flow/trustgraph/agent/react/service.py | 4 +++- .../trustgraph/gateway/dispatch/collection_management.py | 2 -- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index d59bff59..c805bffa 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -172,7 +172,8 @@ class AsyncProcessor: except Exception as e: logger.warning( - f"Config fetch failed: {e}, retrying in 2s..." 
+ f"Config fetch failed: {e}, retrying in 2s...", + exc_info=True ) await asyncio.sleep(2) diff --git a/trustgraph-base/trustgraph/base/rabbitmq_backend.py b/trustgraph-base/trustgraph/base/rabbitmq_backend.py index a80efbaf..b9afe741 100644 --- a/trustgraph-base/trustgraph/base/rabbitmq_backend.py +++ b/trustgraph-base/trustgraph/base/rabbitmq_backend.py @@ -288,6 +288,7 @@ class RabbitMQBackend: port=port, virtual_host=vhost, credentials=pika.PlainCredentials(username, password), + heartbeat=0, ) logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}") diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index c6fd1500..2e39e8c2 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -79,7 +79,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): def encode(self, obj: CollectionManagementResponse) -> Dict[str, Any]: result = {} - print("COLLECTIONMGMT", obj, flush=True) if obj.error is not None: result["error"] = { @@ -99,6 +98,4 @@ class CollectionManagementResponseTranslator(MessageTranslator): "tags": list(coll.tags) if coll.tags else [] }) - print("RESULT IS", result, flush=True) - return result diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index ea0afd60..5bf8e2fd 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -93,7 +93,9 @@ class Processor(AgentService): # Meta-router (initialised on first config load) self.meta_router = None - self.config_handlers.append(self.on_tools_config) + self.register_config_handler( + self.on_tools_config, types=["tool", "tool-service"] + ) self.register_specification( TextCompletionClientSpec( diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 0e783349..40857313 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -81,7 +81,9 @@ class Processor(AgentService): # Track active tool service clients for cleanup self.tool_service_clients = {} - self.config_handlers.append(self.on_tools_config) + self.register_config_handler( + self.on_tools_config, types=["tool", "tool-service"] + ) self.register_specification( TextCompletionClientSpec( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py index 544a412d..934f1c99 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py @@ -28,9 +28,7 @@ class CollectionManagementRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("collection-management") def to_request(self, body): - print("REQUEST", body, flush=True) return self.request_translator.decode(body) def from_response(self, message): - print("RESPONSE", message, flush=True) return self.response_translator.encode_with_completion(message) From ddd4bd77903237e8048ab10ca9bf9f81db92b8e9 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 7 Apr 2026 12:19:05 +0100 Subject: [PATCH 35/37] Deliver explainability triples inline in retrieval response stream (#763) Provenance triples are now included directly in explain messages 
from GraphRAG, DocumentRAG, and Agent services, eliminating the need for follow-up knowledge graph queries to retrieve explainability details. Each explain message in the response stream now carries: - explain_id: root URI for this provenance step (unchanged) - explain_graph: named graph where triples are stored (unchanged) - explain_triples: the actual provenance triples for this step (new) Changes across the stack: - Schema: added explain_triples field to GraphRagResponse, DocumentRagResponse, and AgentResponse - Services: all explain message call sites pass triples through (graph_rag, document_rag, agent react, agent orchestrator) - Translators: encode explain_triples via TripleTranslator for gateway wire format - Python SDK: ProvenanceEvent now includes parsed ExplainEntity and raw triples; expanded event_type detection - CLI: invoke_graph_rag, invoke_agent, invoke_document_rag use inline entity when available, fall back to graph query - Tech specs updated Additional explainability test --- docs/tech-specs/agent-explainability.md | 4 +- docs/tech-specs/query-time-explainability.md | 11 +- .../unit/test_gateway/test_explain_triples.py | 359 ++++++++++++++++++ .../trustgraph/api/socket_client.py | 46 ++- trustgraph-base/trustgraph/api/types.py | 26 +- .../trustgraph/messaging/translators/agent.py | 15 +- .../messaging/translators/retrieval.py | 21 + .../trustgraph/schema/services/agent.py | 7 +- .../trustgraph/schema/services/retrieval.py | 14 +- trustgraph-cli/trustgraph/cli/invoke_agent.py | 16 +- .../trustgraph/cli/invoke_document_rag.py | 16 +- .../trustgraph/cli/invoke_graph_rag.py | 16 +- .../agent/orchestrator/pattern_base.py | 9 + .../trustgraph/agent/react/service.py | 4 + .../trustgraph/retrieval/document_rag/rag.py | 3 +- .../trustgraph/retrieval/graph_rag/rag.py | 3 +- 16 files changed, 521 insertions(+), 49 deletions(-) create mode 100644 tests/unit/test_gateway/test_explain_triples.py diff --git a/docs/tech-specs/agent-explainability.md b/docs/tech-specs/agent-explainability.md index f02a95b1..3dee0ac2 100644 --- a/docs/tech-specs/agent-explainability.md +++ b/docs/tech-specs/agent-explainability.md @@ -219,8 +219,8 @@ TG_ANSWER = TG + "answer" | `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders | | `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators | | `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions | -| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id and explain_graph to DocumentRagResponse | -| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields | +| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse | +| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples | | `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic | | `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples | | `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback | diff --git a/docs/tech-specs/query-time-explainability.md b/docs/tech-specs/query-time-explainability.md index 
e696745c..69cb45ac 100644 --- a/docs/tech-specs/query-time-explainability.md +++ b/docs/tech-specs/query-time-explainability.md @@ -63,7 +63,11 @@ Explainability events stream to client as the query executes: 3. Edges selected with reasoning → event emitted 4. Answer synthesized → event emitted -Client receives `explain_id` and `explain_collection` to fetch full details. +Client receives `explain_id`, `explain_graph`, and `explain_triples` inline +in each explain message. The triples contain the full provenance data for +that step — no follow-up graph query needed. The `explain_id` serves as +the root entity URI within the triples. Data is also written to the +knowledge graph for later audit/analysis. ## URI Structure @@ -144,7 +148,8 @@ class GraphRagResponse: response: str = "" end_of_stream: bool = False explain_id: str | None = None - explain_collection: str | None = None + explain_graph: str | None = None + explain_triples: list[Triple] = field(default_factory=list) message_type: str = "" # "chunk" or "explain" end_of_session: bool = False ``` @@ -154,7 +159,7 @@ class GraphRagResponse: | message_type | Purpose | |--------------|---------| | `chunk` | Response text (streaming or final) | -| `explain` | Explainability event with IRI reference | +| `explain` | Explainability event with inline provenance triples | ### Session Lifecycle diff --git a/tests/unit/test_gateway/test_explain_triples.py b/tests/unit/test_gateway/test_explain_triples.py new file mode 100644 index 00000000..24e77410 --- /dev/null +++ b/tests/unit/test_gateway/test_explain_triples.py @@ -0,0 +1,359 @@ +""" +Tests for inline explainability triples in response translators +and ProvenanceEvent parsing. +""" + +import pytest +from trustgraph.schema import ( + GraphRagResponse, DocumentRagResponse, AgentResponse, + Term, Triple, IRI, LITERAL, Error, +) +from trustgraph.messaging.translators.retrieval import ( + GraphRagResponseTranslator, + DocumentRagResponseTranslator, +) +from trustgraph.messaging.translators.agent import ( + AgentResponseTranslator, +) +from trustgraph.api.types import ProvenanceEvent + + +# --- Helpers --- + +def make_triple(s_iri, p_iri, o_value, o_type=LITERAL): + """Create a Triple with IRI subject/predicate and typed object.""" + o = Term(type=IRI, iri=o_value) if o_type == IRI else Term(type=LITERAL, value=o_value) + return Triple( + s=Term(type=IRI, iri=s_iri), + p=Term(type=IRI, iri=p_iri), + o=o, + ) + + +def sample_triples(): + """A few provenance triples for a question entity.""" + return [ + make_triple( + "urn:trustgraph:question:abc123", + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + "https://trustgraph.ai/ns/GraphRagQuestion", + o_type=IRI, + ), + make_triple( + "urn:trustgraph:question:abc123", + "https://trustgraph.ai/ns/query", + "What is the internet?", + ), + make_triple( + "urn:trustgraph:question:abc123", + "http://www.w3.org/ns/prov#startedAtTime", + "2026-04-07T09:00:00Z", + ), + ] + + +# --- GraphRag Translator --- + +class TestGraphRagExplainTriples: + + def test_explain_triples_encoded(self): + translator = GraphRagResponseTranslator() + triples = sample_triples() + + response = GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc123", + explain_graph="urn:graph:retrieval", + explain_triples=triples, + ) + + result = translator.encode(response) + + assert "explain_triples" in result + assert len(result["explain_triples"]) == 3 + + # Check first triple is properly encoded + t = result["explain_triples"][0] + assert t["s"]["t"] == "i" 
+ assert t["s"]["i"] == "urn:trustgraph:question:abc123" + assert t["p"]["t"] == "i" + + def test_explain_triples_empty_not_included(self): + translator = GraphRagResponseTranslator() + + response = GraphRagResponse( + message_type="chunk", + response="Some answer text", + ) + + result = translator.encode(response) + + assert "explain_triples" not in result + + def test_explain_with_completion_returns_not_final(self): + translator = GraphRagResponseTranslator() + + response = GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc123", + explain_triples=sample_triples(), + end_of_session=False, + ) + + result, is_final = translator.encode_with_completion(response) + assert is_final is False + + def test_explain_id_and_graph_included(self): + translator = GraphRagResponseTranslator() + + response = GraphRagResponse( + message_type="explain", + explain_id="urn:trustgraph:question:abc123", + explain_graph="urn:graph:retrieval", + explain_triples=sample_triples(), + ) + + result = translator.encode(response) + assert result["explain_id"] == "urn:trustgraph:question:abc123" + assert result["explain_graph"] == "urn:graph:retrieval" + + +# --- DocumentRag Translator --- + +class TestDocumentRagExplainTriples: + + def test_explain_triples_encoded(self): + translator = DocumentRagResponseTranslator() + + response = DocumentRagResponse( + response=None, + message_type="explain", + explain_id="urn:trustgraph:docrag:abc123", + explain_graph="urn:graph:retrieval", + explain_triples=sample_triples(), + ) + + result = translator.encode(response) + + assert "explain_triples" in result + assert len(result["explain_triples"]) == 3 + + def test_explain_triples_empty_not_included(self): + translator = DocumentRagResponseTranslator() + + response = DocumentRagResponse( + response="Answer text", + message_type="chunk", + ) + + result = translator.encode(response) + assert "explain_triples" not in result + + +# --- Agent Translator --- + +class TestAgentExplainTriples: + + def test_explain_triples_encoded(self): + translator = AgentResponseTranslator() + + response = AgentResponse( + chunk_type="explain", + content="", + explain_id="urn:trustgraph:agent:session:abc123", + explain_graph="urn:graph:retrieval", + explain_triples=sample_triples(), + ) + + result = translator.encode(response) + + assert "explain_triples" in result + assert len(result["explain_triples"]) == 3 + + t = result["explain_triples"][1] + assert t["p"]["i"] == "https://trustgraph.ai/ns/query" + assert t["o"]["t"] == "l" + assert t["o"]["v"] == "What is the internet?" 
+ + def test_explain_triples_empty_not_included(self): + translator = AgentResponseTranslator() + + response = AgentResponse( + chunk_type="thought", + content="I need to think...", + ) + + result = translator.encode(response) + assert "explain_triples" not in result + + def test_explain_with_completion_not_final(self): + translator = AgentResponseTranslator() + + response = AgentResponse( + chunk_type="explain", + explain_id="urn:trustgraph:agent:session:abc123", + explain_triples=sample_triples(), + end_of_dialog=False, + ) + + result, is_final = translator.encode_with_completion(response) + assert is_final is False + + def test_explain_with_completion_final(self): + translator = AgentResponseTranslator() + + response = AgentResponse( + chunk_type="answer", + content="The answer is...", + end_of_dialog=True, + ) + + result, is_final = translator.encode_with_completion(response) + assert is_final is True + + +# --- ProvenanceEvent --- + +class TestProvenanceEvent: + + def test_question_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:question:abc123", + ) + assert event.event_type == "question" + + def test_exploration_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:exploration:abc123", + ) + assert event.event_type == "exploration" + + def test_focus_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:focus:abc123", + ) + assert event.event_type == "focus" + + def test_synthesis_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:synthesis:abc123", + ) + assert event.event_type == "synthesis" + + def test_grounding_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:grounding:abc123", + ) + assert event.event_type == "grounding" + + def test_session_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:session:abc123", + ) + assert event.event_type == "session" + + def test_iteration_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:iteration:abc123:1", + ) + assert event.event_type == "iteration" + + def test_observation_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:observation:abc123:1", + ) + assert event.event_type == "observation" + + def test_conclusion_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:conclusion:abc123", + ) + assert event.event_type == "conclusion" + + def test_decomposition_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:decomposition:abc123", + ) + assert event.event_type == "decomposition" + + def test_finding_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:finding:abc123:0", + ) + assert event.event_type == "finding" + + def test_plan_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:plan:abc123", + ) + assert event.event_type == "plan" + + def test_step_result_event_type(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:agent:step-result:abc123:0", + ) + assert event.event_type == "step-result" + + def test_defaults(self): + event = ProvenanceEvent( + explain_id="urn:trustgraph:question:abc123", + ) + assert event.entity is None + assert event.triples == [] + assert event.explain_graph == "" + + def test_with_triples(self): + raw = [{"s": {"t": "i", "i": "urn:x"}, "p": {"t": "i", "i": "urn:y"}, "o": {"t": "l", "v": "z"}}] + event = ProvenanceEvent( + explain_id="urn:trustgraph:question:abc123", + triples=raw, + ) + 
assert len(event.triples) == 1 + + +# --- Build ProvenanceEvent with entity parsing --- + +class TestBuildProvenanceEvent: + + def _make_client(self): + """Create a minimal WebSocketClient-like object with _build_provenance_event.""" + from trustgraph.api.socket_client import WebSocketClient + # We can't instantiate WebSocketClient easily, so test the method logic directly + return None + + def test_entity_parsed_from_wire_triples(self): + """Test that wire-format triples are parsed into an ExplainEntity.""" + from trustgraph.api.explainability import ExplainEntity + + wire_triples = [ + { + "s": {"t": "i", "i": "urn:trustgraph:question:abc123"}, + "p": {"t": "i", "i": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}, + "o": {"t": "i", "i": "https://trustgraph.ai/ns/GraphRagQuestion"}, + }, + { + "s": {"t": "i", "i": "urn:trustgraph:question:abc123"}, + "p": {"t": "i", "i": "https://trustgraph.ai/ns/query"}, + "o": {"t": "l", "v": "What is the internet?"}, + }, + ] + + # Parse triples the same way _build_provenance_event does + parsed = [] + for t in wire_triples: + s = t.get("s", {}).get("i", "") + p = t.get("p", {}).get("i", "") + o_term = t.get("o", {}) + if o_term.get("t") == "i": + o = o_term.get("i", "") + else: + o = o_term.get("v", "") + parsed.append((s, p, o)) + + entity = ExplainEntity.from_triples( + "urn:trustgraph:question:abc123", parsed + ) + + assert entity.entity_type == "question" + assert entity.query == "What is the internet?" + assert entity.question_type == "graph-rag" diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 9c37a9b1..b6ceba00 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -366,19 +366,13 @@ class SocketClient: # Handle GraphRAG/DocRAG message format with message_type if message_type == "explain": if include_provenance: - return ProvenanceEvent( - explain_id=resp.get("explain_id", ""), - explain_graph=resp.get("explain_graph", "") - ) + return self._build_provenance_event(resp) return None # Handle Agent message format with chunk_type="explain" if chunk_type == "explain": if include_provenance: - return ProvenanceEvent( - explain_id=resp.get("explain_id", ""), - explain_graph=resp.get("explain_graph", "") - ) + return self._build_provenance_event(resp) return None if chunk_type == "thought": @@ -413,6 +407,42 @@ class SocketClient: error=None ) + def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent: + """Build a ProvenanceEvent from a response dict, parsing inline triples + into an ExplainEntity if available.""" + explain_id = resp.get("explain_id", "") + explain_graph = resp.get("explain_graph", "") + raw_triples = resp.get("explain_triples", []) + + entity = None + if raw_triples: + try: + from .explainability import ExplainEntity + # Convert wire-format triple dicts to (s, p, o) tuples + parsed = [] + for t in raw_triples: + s = t.get("s", {}).get("i", "") if t.get("s") else "" + p = t.get("p", {}).get("i", "") if t.get("p") else "" + o_term = t.get("o", {}) + if o_term: + if o_term.get("t") == "i": + o = o_term.get("i", "") + else: + o = o_term.get("v", "") + else: + o = "" + parsed.append((s, p, o)) + entity = ExplainEntity.from_triples(explain_id, parsed) + except Exception: + pass + + return ProvenanceEvent( + explain_id=explain_id, + explain_graph=explain_graph, + entity=entity, + triples=raw_triples, + ) + def close(self) -> None: """Close the persistent WebSocket connection.""" if 
self._loop and not self._loop.is_closed(): diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 0715293b..55635584 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -213,25 +213,47 @@ class ProvenanceEvent: """ Provenance event for explainability. - Emitted during GraphRAG queries when explainable mode is enabled. + Emitted during retrieval queries when explainable mode is enabled. Each event represents a provenance node created during query processing. Attributes: explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123) explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval) - event_type: Type of provenance event (question, exploration, focus, synthesis) + event_type: Type of provenance event (question, exploration, focus, synthesis, etc.) + entity: Parsed ExplainEntity from inline triples (if available) + triples: Raw triples from the response (wire format dicts) """ explain_id: str explain_graph: str = "" event_type: str = "" # Derived from explain_id + entity: object = None # ExplainEntity (parsed from triples) + triples: list = dataclasses.field(default_factory=list) # Raw wire-format triple dicts def __post_init__(self): # Extract event type from explain_id if "question" in self.explain_id: self.event_type = "question" + elif "grounding" in self.explain_id: + self.event_type = "grounding" elif "exploration" in self.explain_id: self.event_type = "exploration" elif "focus" in self.explain_id: self.event_type = "focus" elif "synthesis" in self.explain_id: self.event_type = "synthesis" + elif "iteration" in self.explain_id: + self.event_type = "iteration" + elif "observation" in self.explain_id: + self.event_type = "observation" + elif "conclusion" in self.explain_id: + self.event_type = "conclusion" + elif "decomposition" in self.explain_id: + self.event_type = "decomposition" + elif "finding" in self.explain_id: + self.event_type = "finding" + elif "plan" in self.explain_id: + self.event_type = "plan" + elif "step-result" in self.explain_id: + self.event_type = "step-result" + elif "session" in self.explain_id: + self.event_type = "session" diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index c2c00ac2..8cf525f5 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple from ...schema import AgentRequest, AgentResponse from .base import MessageTranslator +from .primitives import TripleTranslator class AgentRequestTranslator(MessageTranslator): @@ -49,10 +50,13 @@ class AgentRequestTranslator(MessageTranslator): class AgentResponseTranslator(MessageTranslator): """Translator for AgentResponse schema objects""" - + + def __init__(self): + self.triple_translator = TripleTranslator() + def decode(self, data: Dict[str, Any]) -> AgentResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def encode(self, obj: AgentResponse) -> Dict[str, Any]: result = {} @@ -75,6 +79,13 @@ class AgentResponseTranslator(MessageTranslator): if explain_graph is not None: result["explain_graph"] = explain_graph + # Include explain_triples for explain messages + explain_triples = getattr(obj, "explain_triples", []) + if explain_triples: + result["explain_triples"] = [ + self.triple_translator.encode(t) for t in 
explain_triples + ] + # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "code": obj.error.code} diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 7e2abfa1..849bee94 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse from .base import MessageTranslator +from .primitives import TripleTranslator class DocumentRagRequestTranslator(MessageTranslator): @@ -28,6 +29,9 @@ class DocumentRagRequestTranslator(MessageTranslator): class DocumentRagResponseTranslator(MessageTranslator): """Translator for DocumentRagResponse schema objects""" + def __init__(self): + self.triple_translator = TripleTranslator() + def decode(self, data: Dict[str, Any]) -> DocumentRagResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") @@ -53,6 +57,13 @@ class DocumentRagResponseTranslator(MessageTranslator): if explain_graph is not None: result["explain_graph"] = explain_graph + # Include explain_triples for explain messages + explain_triples = getattr(obj, "explain_triples", []) + if explain_triples: + result["explain_triples"] = [ + self.triple_translator.encode(t) for t in explain_triples + ] + # Include end_of_stream flag (LLM stream complete) result["end_of_stream"] = getattr(obj, "end_of_stream", False) @@ -107,6 +118,9 @@ class GraphRagRequestTranslator(MessageTranslator): class GraphRagResponseTranslator(MessageTranslator): """Translator for GraphRagResponse schema objects""" + def __init__(self): + self.triple_translator = TripleTranslator() + def decode(self, data: Dict[str, Any]) -> GraphRagResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") @@ -132,6 +146,13 @@ class GraphRagResponseTranslator(MessageTranslator): if explain_graph is not None: result["explain_graph"] = explain_graph + # Include explain_triples for explain messages + explain_triples = getattr(obj, "explain_triples", []) + if explain_triples: + result["explain_triples"] = [ + self.triple_translator.encode(t) for t in explain_triples + ] + # Include end_of_stream flag (LLM stream complete) result["end_of_stream"] = getattr(obj, "end_of_stream", False) diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 2a966dd4..fbc0101c 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Optional -from ..core.primitives import Error +from ..core.primitives import Error, Triple ############################################################################ @@ -57,8 +57,9 @@ class AgentResponse: end_of_dialog: bool = False # Entire agent dialog is complete # Explainability fields - explain_id: str | None = None # Provenance URI (announced as created) - explain_graph: str | None = None # Named graph where explain was stored + explain_id: str | None = None # Root URI for this explain step + explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval) + explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step # 
Orchestration fields message_id: str = "" # Unique ID for this response message diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index a4621549..4b17733d 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass -from ..core.primitives import Error, Term +from dataclasses import dataclass, field +from ..core.primitives import Error, Term, Triple ############################################################################ @@ -24,8 +24,9 @@ class GraphRagResponse: error: Error | None = None response: str = "" end_of_stream: bool = False # LLM response stream complete - explain_id: str | None = None # Single explain URI (announced as created) - explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval) + explain_id: str | None = None # Root URI for this explain step + explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval) + explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step message_type: str = "" # "chunk" or "explain" end_of_session: bool = False # Entire session complete @@ -46,7 +47,8 @@ class DocumentRagResponse: error: Error | None = None response: str | None = "" end_of_stream: bool = False # LLM response stream complete - explain_id: str | None = None # Single explain URI (announced as created) - explain_graph: str | None = None # Named graph where explain was stored (e.g., urn:graph:retrieval) + explain_id: str | None = None # Root URI for this explain step + explain_graph: str | None = None # Named graph (e.g., urn:graph:retrieval) + explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step message_type: str = "" # "chunk" or "explain" end_of_session: bool = False # Entire session complete diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 1c4b757b..026286d0 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -182,16 +182,18 @@ def question_explainable( print(item.content, end="", flush=True) elif isinstance(item, ProvenanceEvent): - # Process provenance event immediately + # Use inline entity if available, otherwise fetch from graph prov_id = item.explain_id explain_graph = item.explain_graph or "urn:graph:retrieval" - entity = explain_client.fetch_entity( - prov_id, - graph=explain_graph, - user=user, - collection=collection - ) + entity = item.entity + if entity is None: + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) if entity is None: if debug: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 7da9d779..066b92f4 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -45,16 +45,18 @@ def question_explainable( print(item.content, end="", flush=True) elif isinstance(item, ProvenanceEvent): - # Process provenance event immediately + # Use inline entity if available, otherwise fetch from graph prov_id = item.explain_id explain_graph = item.explain_graph or "urn:graph:retrieval" - entity = explain_client.fetch_entity( - prov_id, - graph=explain_graph, - user=user, - collection=collection - ) + entity = 
item.entity + if entity is None: + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) if entity is None: if debug: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 76b8b158..230cc54b 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -667,16 +667,18 @@ def _question_explainable_api( print(item.content, end="", flush=True) elif isinstance(item, ProvenanceEvent): - # Process provenance event immediately + # Use inline entity if available, otherwise fetch from graph prov_id = item.explain_id explain_graph = item.explain_graph or "urn:graph:retrieval" - entity = explain_client.fetch_entity( - prov_id, - graph=explain_graph, - user=user, - collection=collection - ) + entity = item.entity + if entity is None: + entity = explain_client.fetch_entity( + prov_id, + graph=explain_graph, + user=user, + collection=collection + ) if entity is None: if debug: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index 8849a206..c18c5bac 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -243,6 +243,7 @@ class PatternBase: content="", explain_id=session_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) async def emit_iteration_triples(self, flow, session_id, iteration_num, @@ -305,6 +306,7 @@ class PatternBase: content="", explain_id=iteration_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=iter_triples, )) async def emit_observation_triples(self, flow, session_id, iteration_num, @@ -360,6 +362,7 @@ class PatternBase: content="", explain_id=observation_entity_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=obs_triples, )) async def emit_final_triples(self, flow, session_id, iteration_num, @@ -416,6 +419,7 @@ class PatternBase: content="", explain_id=final_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=final_triples, )) # ---- Orchestrator provenance helpers ------------------------------------ @@ -437,6 +441,7 @@ class PatternBase: await respond(AgentResponse( chunk_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) async def emit_finding_triples( @@ -475,6 +480,7 @@ class PatternBase: await respond(AgentResponse( chunk_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) async def emit_plan_triples( @@ -494,6 +500,7 @@ class PatternBase: await respond(AgentResponse( chunk_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) async def emit_step_result_triples( @@ -526,6 +533,7 @@ class PatternBase: await respond(AgentResponse( chunk_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) async def emit_synthesis_triples( @@ -557,6 +565,7 @@ class PatternBase: await respond(AgentResponse( chunk_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) # ---- Response helpers --------------------------------------------------- diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 40857313..2c7423d8 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ 
b/trustgraph-flow/trustgraph/agent/react/service.py @@ -473,6 +473,7 @@ class Processor(AgentService): content="", explain_id=session_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, )) logger.info(f"Question: {request.question}") @@ -640,6 +641,7 @@ class Processor(AgentService): content="", explain_id=iter_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=iter_triples, )) user_context = UserAwareContext(flow, request.user) @@ -717,6 +719,7 @@ class Processor(AgentService): content="", explain_id=final_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=final_triples, )) if streaming: @@ -793,6 +796,7 @@ class Processor(AgentService): content="", explain_id=observation_entity_uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=obs_triples, )) history.append(act) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index c0e55d84..3b281fe3 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -162,12 +162,13 @@ class Processor(FlowProcessor): triples=triples, )) - # Send explain ID and graph to response queue + # Send explain data to response queue await flow("response").send( DocumentRagResponse( response=None, explain_id=explain_id, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, message_type="explain", ), properties={"id": id} diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 85a7491e..abf10e90 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -253,12 +253,13 @@ class Processor(FlowProcessor): triples=triples, )) - # Send explain ID and graph to response queue + # Send explain data to response queue await flow("response").send( GraphRagResponse( message_type="explain", explain_id=explain_id, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, ), properties={"id": id} ) From c20e6540ecb624d24486981bc2c189e3db2a6176 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 7 Apr 2026 14:51:14 +0100 Subject: [PATCH 36/37] Subscriber resilience and RabbitMQ fixes (#765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Subscriber resilience: recreate consumer after connection failure - Move consumer creation from Subscriber.start() into the run() loop, matching the pattern used by Consumer. If the connection drops and the consumer is closed in the finally block, the loop now recreates it on the next iteration instead of spinning forever on a None consumer. 
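As an illustrative sketch of this recreate-on-failure pattern (simplified, not the actual Subscriber code; backend.create_consumer(), consumer.receive() and consumer.close() here are stand-in calls, and receive() is assumed to block until a message arrives or raise when the connection drops):

    import asyncio

    class ResilientSubscriber:
        """Sketch: the consumer is rebuilt inside run().

        A failed connection is closed in the finally block; the next loop
        iteration sees consumer is None and creates a fresh one, instead
        of polling a dead (or None) consumer forever.
        """

        def __init__(self, backend, topic):
            self.backend = backend   # stand-in: exposes create_consumer()
            self.topic = topic
            self.consumer = None
            self.running = True

        async def run(self):
            while self.running:
                try:
                    if self.consumer is None:
                        # First run, or the previous consumer was torn down
                        # after a failure: create a new one.
                        self.consumer = await asyncio.to_thread(
                            self.backend.create_consumer, topic=self.topic
                        )

                    while self.running:
                        # Stand-in receive(): blocks until a message arrives,
                        # raises if the connection is lost.
                        msg = await asyncio.to_thread(self.consumer.receive)
                        await self.handle(msg)

                except Exception:
                    # Fall through to cleanup; the consumer is recreated on
                    # the next iteration of the outer loop.
                    pass

                finally:
                    if self.consumer is not None:
                        try:
                            self.consumer.close()
                        except Exception:
                            pass
                        self.consumer = None

                if self.running:
                    await asyncio.sleep(1)   # back off before reconnecting

        async def handle(self, msg):
            ...   # application-specific message handling
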
Consumer thread safety: - Dedicated ThreadPoolExecutor per consumer so all pika operations (create, receive, acknowledge, negative_acknowledge) run on the same thread — pika BlockingConnection is not thread-safe - Applies to both Consumer and Subscriber classes Config handler type audit — fix four mismatched type registrations: - librarian: was ["librarian"] (non-existent type), now ["flow", "active-flow"] (matches config["flow"] that the handler reads) - cores/service: was ["kg-core"], now ["flow"] (reads config["flow"]) - metering/counter: was ["token-costs"], now ["token-cost"] (singular) - agent/mcp_tool: was ["mcp-tool"], now ["mcp"] (reads config["mcp"]) Update tests --- .../test_subscriber_graceful_shutdown.py | 26 +++---- .../test_consumer_concurrency.py | 8 +-- .../trustgraph/base/async_processor.py | 1 - trustgraph-base/trustgraph/base/consumer.py | 70 +++++++++++++------ trustgraph-base/trustgraph/base/subscriber.py | 46 +++++++----- .../trustgraph/agent/mcp_tool/service.py | 2 +- trustgraph-flow/trustgraph/cores/service.py | 2 +- .../trustgraph/librarian/service.py | 5 +- .../trustgraph/metering/counter.py | 2 +- 9 files changed, 96 insertions(+), 66 deletions(-) diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py b/tests/unit/test_base/test_subscriber_graceful_shutdown.py index 0587e3d6..ec14f66b 100644 --- a/tests/unit/test_base/test_subscriber_graceful_shutdown.py +++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py @@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success(): max_size=10, backpressure_strategy="block" ) - - # Start subscriber to initialize consumer - await subscriber.start() - + subscriber.consumer = mock_consumer + # Create queue for subscription queue = await subscriber.subscribe("test-queue") - + # Create mock message with matching queue name msg = create_mock_message("test-queue", {"data": "test"}) - + # Process message await subscriber._process_message(msg) - + # Should acknowledge successful delivery mock_consumer.acknowledge.assert_called_once_with(msg) mock_consumer.negative_acknowledge.assert_not_called() - + # Message should be in queue assert not queue.empty() received_msg = await queue.get() @@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks(): max_size=1, # Very small queue backpressure_strategy="drop_new" ) - - # Start subscriber to initialize consumer - await subscriber.start() + subscriber.consumer = mock_consumer # Create queue and fill it queue = await subscriber.subscribe("test-queue") @@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks(): max_size=10, backpressure_strategy="block" ) - - # Start subscriber to initialize consumer - await subscriber.start() + subscriber.consumer = mock_consumer # Don't create any queues - message will be orphaned # This simulates a response arriving after the waiter has unsubscribed @@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies(): max_size=2, backpressure_strategy="drop_oldest" ) - - # Start subscriber to initialize consumer - await subscriber.start() + subscriber.consumer = mock_consumer queue = await subscriber.subscribe("test-queue") diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py b/tests/unit/test_concurrency/test_consumer_concurrency.py index 03244b73..59c7f2b5 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -81,9 +81,8 @@ class TestTaskGroupConcurrency: # Track how many 
consume_from_queue calls are made call_count = 0 - original_running = True - async def mock_consume(backend_consumer): + async def mock_consume(backend_consumer, executor=None): nonlocal call_count call_count += 1 # Wait a bit to let all tasks start, then signal stop @@ -107,7 +106,7 @@ class TestTaskGroupConcurrency: consumer = _make_consumer(concurrency=1) call_count = 0 - async def mock_consume(backend_consumer): + async def mock_consume(backend_consumer, executor=None): nonlocal call_count call_count += 1 await asyncio.sleep(0.01) @@ -294,9 +293,8 @@ class TestPollTimeout: raise type('Timeout', (Exception,), {})("timeout") mock_pulsar_consumer.receive = capture_receive - consumer.consumer = mock_pulsar_consumer - await consumer.consume_from_queue() + await consumer.consume_from_queue(mock_pulsar_consumer) assert received_kwargs.get("timeout_millis") == 100 diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index c805bffa..4f04df16 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -94,7 +94,6 @@ class AsyncProcessor: metrics = config_consumer_metrics, start_of_messages = False, - consumer_type = 'exclusive', ) self.running = True diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 4f8c9de5..b6c28bbe 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -12,6 +12,7 @@ import asyncio import time import logging +from concurrent.futures import ThreadPoolExecutor from .. exceptions import TooManyRequests @@ -110,29 +111,37 @@ class Consumer: logger.info(f"Starting {self.concurrency} receiver threads") # Create one backend consumer per concurrent task. - # Each gets its own connection — required for backends - # like RabbitMQ where connections are not thread-safe. + # Each gets its own connection and dedicated thread — + # required for backends like RabbitMQ where connections + # are not thread-safe (pika BlockingConnection must be + # used from a single thread). 
consumers = [] + executors = [] for i in range(self.concurrency): try: logger.info(f"Subscribing to topic: {self.topic} (worker {i})") - c = await asyncio.to_thread( - self.backend.create_consumer, - topic = self.topic, - subscription = self.subscriber, - schema = self.schema, - initial_position = initial_pos, - consumer_type = self.consumer_type, + executor = ThreadPoolExecutor(max_workers=1) + loop = asyncio.get_event_loop() + c = await loop.run_in_executor( + executor, + lambda: self.backend.create_consumer( + topic = self.topic, + subscription = self.subscriber, + schema = self.schema, + initial_position = initial_pos, + consumer_type = self.consumer_type, + ), ) consumers.append(c) + executors.append(executor) logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})") except Exception as e: logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True) raise async with asyncio.TaskGroup() as tg: - for c in consumers: - tg.create_task(self.consume_from_queue(c)) + for c, ex in zip(consumers, executors): + tg.create_task(self.consume_from_queue(c, ex)) if self.metrics: self.metrics.state("stopped") @@ -146,7 +155,10 @@ class Consumer: c.close() except Exception: pass + for ex in executors: + ex.shutdown(wait=False) consumers = [] + executors = [] await asyncio.sleep(self.reconnect_time) continue @@ -157,15 +169,18 @@ class Consumer: c.close() except Exception: pass + for ex in executors: + ex.shutdown(wait=False) - async def consume_from_queue(self, consumer): + async def consume_from_queue(self, consumer, executor=None): + loop = asyncio.get_event_loop() while self.running: try: - msg = await asyncio.to_thread( - consumer.receive, - timeout_millis=100 + msg = await loop.run_in_executor( + executor, + lambda: consumer.receive(timeout_millis=100), ) except Exception as e: # Handle timeout from any backend @@ -173,10 +188,11 @@ class Consumer: continue raise e - await self.handle_one_from_queue(msg, consumer) + await self.handle_one_from_queue(msg, consumer, executor) - async def handle_one_from_queue(self, msg, consumer): + async def handle_one_from_queue(self, msg, consumer, executor=None): + loop = asyncio.get_event_loop() expiry = time.time() + self.rate_limit_timeout # This loop is for retry on rate-limit / resource limits @@ -187,8 +203,11 @@ class Consumer: logger.warning("Gave up waiting for rate-limit retry") # Message failed to be processed, this causes it to - # be retried - consumer.negative_acknowledge(msg) + # be retried. Ack on the consumer's dedicated thread + # (pika is not thread-safe). + await loop.run_in_executor( + executor, lambda: consumer.negative_acknowledge(msg) + ) if self.metrics: self.metrics.process("error") @@ -210,8 +229,11 @@ class Consumer: logger.debug("Message processed successfully") - # Acknowledge successful processing of the message - consumer.acknowledge(msg) + # Acknowledge on the consumer's dedicated thread + # (pika is not thread-safe) + await loop.run_in_executor( + executor, lambda: consumer.acknowledge(msg) + ) if self.metrics: self.metrics.process("success") @@ -237,8 +259,10 @@ class Consumer: logger.error(f"Message processing exception: {e}", exc_info=True) # Message failed to be processed, this causes it to - # be retried - consumer.negative_acknowledge(msg) + # be retried. Ack on the consumer's dedicated thread. 
+ await loop.run_in_executor( + executor, lambda: consumer.negative_acknowledge(msg) + ) if self.metrics: self.metrics.process("error") diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 36948131..6cb234b1 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -7,6 +7,7 @@ import asyncio import time import logging import uuid +from concurrent.futures import ThreadPoolExecutor # Module logger logger = logging.getLogger(__name__) @@ -38,6 +39,7 @@ class Subscriber: self.pending_acks = {} # Track messages awaiting delivery self.consumer = None + self.executor = None def __del__(self): @@ -45,15 +47,6 @@ class Subscriber: async def start(self): - # Create consumer via backend - self.consumer = await asyncio.to_thread( - self.backend.create_consumer, - topic=self.topic, - subscription=self.subscription, - schema=self.schema, - consumer_type='exclusive', - ) - self.task = asyncio.create_task(self.run()) async def stop(self): @@ -80,6 +73,21 @@ class Subscriber: try: + # Create consumer and dedicated thread if needed + # (first run or after failure) + if self.consumer is None: + self.executor = ThreadPoolExecutor(max_workers=1) + loop = asyncio.get_event_loop() + self.consumer = await loop.run_in_executor( + self.executor, + lambda: self.backend.create_consumer( + topic=self.topic, + subscription=self.subscription, + schema=self.schema, + consumer_type='exclusive', + ), + ) + if self.metrics: self.metrics.state("running") @@ -128,9 +136,12 @@ class Subscriber: # Process messages only if not draining if not self.draining: try: - msg = await asyncio.to_thread( - self.consumer.receive, - timeout_millis=250 + loop = asyncio.get_event_loop() + msg = await loop.run_in_executor( + self.executor, + lambda: self.consumer.receive( + timeout_millis=250 + ), ) except Exception as e: # Handle timeout from any backend @@ -172,15 +183,18 @@ class Subscriber: except Exception: pass # Already closed or error self.consumer = None - - + + if self.executor: + self.executor.shutdown(wait=False) + self.executor = None + if self.metrics: self.metrics.state("stopped") if not self.running and not self.draining: return - - # If handler drops out, sleep a retry + + # Sleep before retry await asyncio.sleep(1) async def subscribe(self, id): diff --git a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py index 0bc5d7e3..c793f9ca 100755 --- a/trustgraph-flow/trustgraph/agent/mcp_tool/service.py +++ b/trustgraph-flow/trustgraph/agent/mcp_tool/service.py @@ -24,7 +24,7 @@ class Service(ToolService): **params ) - self.register_config_handler(self.on_mcp_config, types=["mcp-tool"]) + self.register_config_handler(self.on_mcp_config, types=["mcp"]) self.mcp_services = {} diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 3bc4e9b6..d6390805 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -108,7 +108,7 @@ class Processor(AsyncProcessor): flow_config = self, ) - self.register_config_handler(self.on_knowledge_config, types=["kg-core"]) + self.register_config_handler(self.on_knowledge_config, types=["flow"]) self.flows = {} diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index 15cc97fa..c735a550 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ 
b/trustgraph-flow/trustgraph/librarian/service.py @@ -246,7 +246,10 @@ class Processor(AsyncProcessor): taskgroup = self.taskgroup, ) - self.register_config_handler(self.on_librarian_config, types=["librarian"]) + self.register_config_handler( + self.on_librarian_config, + types=["flow", "active-flow"], + ) self.flows = {} diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index f120a812..3e0b610c 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -40,7 +40,7 @@ class Processor(FlowProcessor): } ) - self.register_config_handler(self.on_cost_config, types=["token-costs"]) + self.register_config_handler(self.on_cost_config, types=["token-cost"]) self.register_specification( ConsumerSpec( From e899370d988487d916af16de43d34e4079aee8eb Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 7 Apr 2026 22:24:59 +0100 Subject: [PATCH 37/37] Update docs for 2.2 release (#766) - Update protocol specs - Update protocol docs - Update API specs --- docs/api-gateway-changes-v1.8-to-v2.1.md | 108 --- docs/api.html | 176 ++++- docs/cli-changes-v1.8-to-v2.1.md | 112 --- docs/python-api.md | 683 +++--------------- docs/websocket.html | 142 +++- .../schemas/agent/AgentResponse.yaml | 6 + .../schemas/rag/DocumentRagResponse.yaml | 5 + .../schemas/rag/GraphRagResponse.yaml | 5 + specs/api/openapi.yaml | 6 +- specs/api/paths/flow/agent.yaml | 17 + specs/api/paths/flow/document-rag.yaml | 22 +- specs/api/paths/flow/graph-rag.yaml | 22 +- specs/websocket/asyncapi.yaml | 4 +- .../components/messages/ServiceRequest.yaml | 1 + .../messages/requests/SparqlQueryRequest.yaml | 46 ++ 15 files changed, 488 insertions(+), 867 deletions(-) delete mode 100644 docs/api-gateway-changes-v1.8-to-v2.1.md delete mode 100644 docs/cli-changes-v1.8-to-v2.1.md create mode 100644 specs/websocket/components/messages/requests/SparqlQueryRequest.yaml diff --git a/docs/api-gateway-changes-v1.8-to-v2.1.md b/docs/api-gateway-changes-v1.8-to-v2.1.md deleted file mode 100644 index 099dadb0..00000000 --- a/docs/api-gateway-changes-v1.8-to-v2.1.md +++ /dev/null @@ -1,108 +0,0 @@ -# API Gateway Changes: v1.8 to v2.1 - -## Summary - -The API gateway gained new WebSocket service dispatchers for embeddings -queries, a new REST streaming endpoint for document content, and underwent -a significant wire format change from `Value` to `Term`. The "objects" -service was renamed to "rows". - ---- - -## New WebSocket Service Dispatchers - -These are new request/response services available through the WebSocket -multiplexer at `/api/v1/socket` (flow-scoped): - -| Service Key | Description | -|-------------|-------------| -| `document-embeddings` | Queries document chunks by text similarity. Request/response uses `DocumentEmbeddingsRequest`/`DocumentEmbeddingsResponse` schemas. | -| `row-embeddings` | Queries structured data rows by text similarity on indexed fields. Request/response uses `RowEmbeddingsRequest`/`RowEmbeddingsResponse` schemas. | - -These join the existing `graph-embeddings` dispatcher (which was already -present in v1.8 but may have been updated). 
- -### Full list of WebSocket flow service dispatchers (v2.1) - -Request/response services (via `/api/v1/flow/{flow}/service/{kind}` or -WebSocket mux): - -- `agent`, `text-completion`, `prompt`, `mcp-tool` -- `graph-rag`, `document-rag` -- `embeddings`, `graph-embeddings`, `document-embeddings` -- `triples`, `rows`, `nlp-query`, `structured-query`, `structured-diag` -- `row-embeddings` - ---- - -## New REST Endpoint - -| Method | Path | Description | -|--------|------|-------------| -| `GET` | `/api/v1/document-stream` | Streams document content from the library as raw bytes. Query parameters: `user` (required), `document-id` (required), `chunk-size` (optional, default 1MB). Returns the document content in chunked transfer encoding, decoded from base64 internally. | - ---- - -## Renamed Service: "objects" to "rows" - -| v1.8 | v2.1 | Notes | -|------|------|-------| -| `objects_query.py` / `ObjectsQueryRequestor` | `rows_query.py` / `RowsQueryRequestor` | Schema changed from `ObjectsQueryRequest`/`ObjectsQueryResponse` to `RowsQueryRequest`/`RowsQueryResponse`. | -| `objects_import.py` / `ObjectsImport` | `rows_import.py` / `RowsImport` | Import dispatcher for structured data. | - -The WebSocket service key changed from `"objects"` to `"rows"`, and the -import dispatcher key similarly changed from `"objects"` to `"rows"`. - ---- - -## Wire Format Change: Value to Term - -The serialization layer (`serialize.py`) was rewritten to use the new `Term` -type instead of the old `Value` type. - -### Old format (v1.8 — `Value`) - -```json -{"v": "http://example.org/entity", "e": true} -``` - -- `v`: the value (string) -- `e`: boolean flag indicating whether the value is a URI - -### New format (v2.1 — `Term`) - -IRIs: -```json -{"t": "i", "i": "http://example.org/entity"} -``` - -Literals: -```json -{"t": "l", "v": "some text", "d": "datatype-uri", "l": "en"} -``` - -Quoted triples (RDF-star): -```json -{"t": "r", "r": {"s": {...}, "p": {...}, "o": {...}}} -``` - -- `t`: type discriminator — `"i"` (IRI), `"l"` (literal), `"r"` (quoted triple), `"b"` (blank node) -- Serialization now delegates to `TermTranslator` and `TripleTranslator` from `trustgraph.messaging.translators.primitives` - -### Other serialization changes - -| Field | v1.8 | v2.1 | -|-------|------|------| -| Metadata | `metadata.metadata` (subgraph) | `metadata.root` (simple value) | -| Graph embeddings entity | `entity.vectors` (plural) | `entity.vector` (singular) | -| Document embeddings chunk | `chunk.vectors` + `chunk.chunk` (text) | `chunk.vector` + `chunk.chunk_id` (ID reference) | - ---- - -## Breaking Changes - -- **`Value` to `Term` wire format**: All clients sending/receiving triples, embeddings, or entity contexts through the gateway must update to the new Term format. -- **`objects` to `rows` rename**: WebSocket service key and import key changed. -- **Metadata field change**: `metadata.metadata` (a serialized subgraph) replaced by `metadata.root` (a simple value). -- **Embeddings field changes**: `vectors` (plural) became `vector` (singular); document embeddings now reference `chunk_id` instead of inline `chunk` text. -- **New `/api/v1/document-stream` endpoint**: Additive, not breaking. diff --git a/docs/api.html b/docs/api.html index 7cbddd32..2a03a38b 100644 --- a/docs/api.html +++ b/docs/api.html @@ -422,7 +422,7 @@ data-styled.g138[id="sc-iJQrDi"]{content:"gtHWGb,"}/*!sc*/ -

TrustGraph API Gateway (2.1)

Download OpenAPI specification:

TrustGraph API Gateway (2.2)

Download OpenAPI specification:

REST API for TrustGraph - an AI-powered knowledge graph and RAG system.

Overview

  • AI services: agent, text-completion, prompt, RAG (document/graph)
  • Embeddings: embeddings, graph-embeddings, document-embeddings
  • - Query: triples, rows, nlp-query, structured-query, row-embeddings
  • + Query: triples, rows, nlp-query, structured-query, sparql-query, row-embeddings
  • Data loading: text-load, document-load
  • Utilities: mcp-tool, structured-diag
  • @@ -784,11 +784,26 @@ for processing and handled asynchronously.

    Stop ongoing library document processing.

    list-processing

    List current processing tasks and their status.

    -
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "add-document" "remove-document" "list-documents" "start-processing" "stop-processing" "list-processing"
    Authorizations:
    bearerAuth
    Request Body schema: application/json
    required
    operation
    required
    string
    Enum: "add-document" "remove-document" "list-documents" "get-document-metadata" "get-document-content" "stream-document" "add-child-document" "list-children" "begin-upload" "upload-chunk" "complete-upload" "abort-upload" "get-upload-status" "list-uploads" "start-processing" "stop-processing" "list-processing"
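
For orientation, here is a minimal sketch of request bodies for two of the operations in the enum above. Only the operation field comes from the schema shown here; the other field names (user, document-id) are illustrative assumptions rather than documented parameters.

```python
# Illustrative librarian request bodies. Only "operation" is taken from the
# documented enum; "user" and "document-id" are assumed field names.
list_documents_request = {
    "operation": "list-documents",
    "user": "alice",
}

remove_document_request = {
    "operation": "remove-document",
    "user": "alice",
    "document-id": "doc-123",
}
```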

    Error response

    Request samples

    Content type
    application/json
    Example
    {
    • "question": "What is the capital of France?",
    • "user": "alice"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "chunk-type": "thought",
    • "content": "I need to search for information about quantum computing",
    • "end-of-message": false,
    • "end-of-dialog": false
    }
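
As a rough illustration of streamed agent output, the sequence below is assembled from the documented fields (chunk-type, content, end-of-message, end-of-dialog). Only "thought" appears as a chunk-type in this spec; the "observation" and "answer" values are assumptions used for illustration.

```python
# Sketch of a streamed agent dialog using the documented field names.
# chunk-type values other than "thought" are assumed, not documented here.
agent_chunks = [
    {"chunk-type": "thought",
     "content": "I need to search for information about quantum computing",
     "end-of-message": False, "end-of-dialog": False},
    {"chunk-type": "observation",
     "content": "Found several relevant passages on quantum computing",
     "end-of-message": True, "end-of-dialog": False},
    {"chunk-type": "answer",
     "content": "Quantum computing uses qubits to perform computation...",
     "end-of-message": True, "end-of-dialog": True},
]
```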

    http://localhost:8088/api/v1/flow/{flow}/service/agent

    Document RAG - retrieve and generate from documents

    Streaming

    Enable streaming: true to receive the answer as it's generated:

    • - Multiple messages with response content
    • + Multiple chunk messages with response content
    • + explain messages with inline provenance triples (explain_triples)
    • Final message with end-of-stream: true
    • + Session ends with end_of_session: true

    Explain events carry explain_id, explain_graph, and explain_triples inline in the stream, so no follow-up knowledge graph query is needed.

    Without streaming, returns complete answer in single response.
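
To make the event shapes concrete, here is a sketch of one possible streamed sequence, assembled from the field names mentioned above (response, end-of-stream, explain_id, explain_graph, explain_triples, end_of_session). The one-JSON-object-per-event framing and the layout of the triples inside explain_triples are assumptions, not part of this specification.

```python
# Sketch of a streamed document-rag session. Field names come from the
# description above; framing and the s/p/o triple layout are assumed.
streamed_events = [
    {"response": "The research papers present ", "end-of-stream": False},
    {"response": "three key findings...", "end-of-stream": False},
    {"explain_id": "exp-1",
     "explain_graph": "research",
     "explain_triples": [
         {"s": "http://example.com/paper-1",
          "p": "http://example.com/reports",
          "o": "quantum entanglement"},
     ]},
    {"response": "", "end-of-stream": True},
    {"end_of_session": True},
]
```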

    Parameters

      @@ -1216,7 +1256,7 @@ Each step has: thought, action, arguments, observation.

      Error response

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "What are the key findings in the research papers?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "The research papers present three key findings:\n1. Quantum entanglement exhibits non-local correlations\n2. Bell's inequality is violated in experimental tests\n3. Applications in quantum cryptography are promising\n",
    • "end-of-stream": false
    }

    http://localhost:8088/api/v1/flow/{flow}/service/document-rag

    Graph RAG - retrieve and generate from knowledge graph

    Streaming

    Enable streaming: true to receive the answer as it's generated:

    • - Multiple messages with response content
    • + Multiple chunk messages with response content
    • + explain messages with inline provenance triples (explain_triples)
    • Final message with end-of-stream: true
    • + Session ends with end_of_session: true

    Explain events carry explain_id, explain_graph, and explain_triples inline in the stream, so no follow-up knowledge graph query is needed.

    Without streaming, returns complete answer in single response.

    Parameters

    Control retrieval scope with multiple knobs:

    @@ -1332,7 +1380,7 @@ Each step has: thought, action, arguments, observation.

    Error response

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "What connections exist between quantum physics and computer science?",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    Example
    {
    • "response": "Quantum physics and computer science intersect primarily through quantum computing.\nThe knowledge graph shows connections through:\n- Quantum algorithms (Shor's algorithm, Grover's algorithm)\n- Quantum information theory\n- Computational complexity theory\n",
    • "end-of-stream": false
    }

    http://localhost:8088/api/v1/flow/{flow}/service/graph-rag

    Text completion - direct LLM generation

    Text Load Overview

    Fire-and-forget document loading:

    • - Input: Text content (base64 encoded)
    • + Input: Text content (raw UTF-8 or base64 encoded)
    • Process: Chunk, embed, store
    • Output: None (202 Accepted)

    Pipeline runs asynchronously after request returns.

    Text Format

    - Text must be base64 encoded:

    + Text may be sent as raw UTF-8 text:
    +
    + {
    +   "text": "Cancer survival: 2.74× higher hazard ratio"
    + }

    Older clients may still send base64 encoded text:

    text_content = "This is the document..."
     encoded = base64.b64encode(text_content.encode('utf-8'))
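
Putting the two accepted formats together, here is a small sketch of text-load request bodies — the raw UTF-8 form alongside the older base64 form. The surrounding fields (id, user, collection) follow the request schema below.

```python
import base64

# Raw UTF-8 form
raw_request = {
    "text": "This is the document text...",
    "id": "doc-123",
    "user": "alice",
    "collection": "research",
}

# Base64 form, still accepted for older clients
text_content = "This is the document text..."
b64_request = {
    "text": base64.b64encode(text_content.encode("utf-8")).decode("ascii"),
    "id": "doc-123",
    "user": "alice",
    "collection": "research",
}
```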
     
    @@ -2792,8 +2850,8 @@ encoded = base64
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    -
    Request Body schema: application/json
    required
    text
    required
    string <byte>

    Text content (base64 encoded)

    +
    Request Body schema: application/json
    required
    text
    required
    string

    Text content, either raw text or base64 encoded for compatibility with older clients

    id
    string

    Document identifier

    user
    string
    Default: "trustgraph"

    Error response

    Request samples

    Content type
    application/json
    Example
    {
    • "text": "VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==",
    • "id": "doc-123",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    Document Load - load binary documents (PDF, etc.)

    http://localhost:8088/api/v1/flow/{flow}/service/text-load

    Request samples

    Content type
    application/json
    Example
    {
    • "text": "This is the document text...",
    • "id": "doc-123",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    Document Load - load binary documents (PDF, etc.)


    Error response

    Request samples

    Content type
    application/json
    Example
    {
    • "data": "JVBERi0xLjQKJeLjz9MKMSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PmVuZG9iagoyIDAgb2JqCjw8L1R5cGUvUGFnZXMvS2lkc1szIDAgUl0vQ291bnQgMT4+ZW5kb2JqCg==",
    • "id": "doc-789",
    • "user": "alice",
    • "collection": "research"
    }

    Response samples

    Content type
    application/json
    { }

    http://localhost:8088/api/v1/flow/{flow}/service/document-load

    SPARQL query - execute SPARQL 1.1 queries against the knowledge graph

    Execute a SPARQL 1.1 query against the knowledge graph.

    +

    Supported Query Types

    +
      +
    • SELECT: Returns variable bindings as a table of results
    • +
    • ASK: Returns true/false for existence checks
    • +
    • CONSTRUCT: Returns a set of triples built from a template
    • +
    • DESCRIBE: Returns triples describing matched resources
    • +
    +

    SPARQL Features

    +

    Supports standard SPARQL 1.1 features including:

    +
      +
    • Basic Graph Patterns (BGPs) with triple pattern matching
    • +
    • OPTIONAL, UNION, FILTER
    • +
    • BIND, VALUES
    • +
    • ORDER BY, LIMIT, OFFSET, DISTINCT
    • +
    • GROUP BY with aggregates (COUNT, SUM, AVG, MIN, MAX, GROUP_CONCAT)
    • +
    • Built-in functions (isIRI, STR, REGEX, CONTAINS, etc.)
    • +
    +

    Query Examples

    +

    Find all entities of a type:

    +
    SELECT ?s ?label WHERE {
    +  ?s <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://example.com/Person> .
    +  ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label .
    +}
    +LIMIT 10
    +
    +

    Check if an entity exists:

    +
    ASK { <http://example.com/alice> ?p ?o }
    +
    +
    Authorizations:
    bearerAuth
    path Parameters
    flow
    required
    string
    Example: my-flow

    Flow instance ID

    +
    Request Body schema: application/json
    required
    query
    required
    string

    SPARQL 1.1 query string

    +
    user
    string
    Default: "trustgraph"

    User/keyspace identifier

    +
    collection
    string
    Default: "default"

    Collection identifier

    +
    limit
    integer
    Default: 10000

    Safety limit on number of results

    +

    Responses

    Request samples

    Content type
    application/json
    Example
    {
    • "query": "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10",
    • "user": "trustgraph",
    • "collection": "default"
    }

    Response samples

    Content type
    application/json
    Example
    {}
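
Since CONSTRUCT and DESCRIBE are listed above but not demonstrated, here is a sketch of submitting a CONSTRUCT query with a FILTER from Python. The request body fields (query, user, collection, limit) follow the schema above; the /service/sparql-query path is inferred from the flow-service URL pattern used elsewhere in this document and should be treated as an assumption.

```python
import requests

# Body fields follow the documented schema; the endpoint path is assumed
# from the other flow service URLs in this document.
body = {
    "query": """
        CONSTRUCT { ?person <http://www.w3.org/2000/01/rdf-schema#label> ?label }
        WHERE {
          ?person <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://example.com/Person> .
          ?person <http://www.w3.org/2000/01/rdf-schema#label> ?label .
          FILTER(CONTAINS(?label, "Ada"))
        }
    """,
    "user": "trustgraph",
    "collection": "default",
    "limit": 100,
}

resp = requests.post(
    "http://localhost:8088/api/v1/flow/my-flow/service/sparql-query",
    json=body,
)
print(resp.json())
```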

    Import/Export

    Bulk data import and export

    Stream document content from library

    Local development server

    http://localhost:8088/api/metrics/{path}

    Response samples

    Content type
    application/json
    {
    • "error": "Unauthorized"
    }