Compare commits


38 commits

Author SHA1 Message Date
Cyber MacGeddon
dc72ed3cca Merge branch 'release/v2.2' 2026-04-07 22:29:55 +01:00
cybermaggedon
e899370d98
Update docs for 2.2 release (#766)
- Update protocol specs
- Update protocol docs
- Update API specs
2026-04-07 22:24:59 +01:00
cybermaggedon
c20e6540ec
Subscriber resilience and RabbitMQ fixes (#765)
Subscriber resilience: recreate consumer after connection failure

- Move consumer creation from Subscriber.start() into the run() loop,
  matching the pattern used by Consumer. If the connection drops and the
  consumer is closed in the finally block, the loop now recreates it on
  the next iteration instead of spinning forever on a None consumer.

Consumer thread safety:
- Dedicated ThreadPoolExecutor per consumer so all pika operations
  (create, receive, acknowledge, negative_acknowledge) run on the
  same thread — pika BlockingConnection is not thread-safe
- Applies to both Consumer and Subscriber classes
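The single-thread confinement described above can be sketched like this (class and method names hypothetical, not the actual TrustGraph API): a one-worker executor guarantees that every broker operation runs on the same thread, which is what pika's BlockingConnection requires.

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

class SingleThreadConsumer:
    """Sketch: all broker operations funnel through one dedicated
    thread, since pika's BlockingConnection is not thread-safe."""

    def __init__(self):
        # max_workers=1 means create/receive/ack all share one thread
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def _call(self, fn, *args):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, fn, *args)

    async def receive(self):
        return await self._call(self._blocking_receive)

    def _blocking_receive(self):
        # Stand-in for a blocking pika call; returns the thread it ran on
        return threading.current_thread().name
```

Every awaited call lands on the same worker thread, so the executor itself serialises access without explicit locking.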

Config handler type audit — fix four mismatched type registrations:
- librarian: was ["librarian"] (non-existent type), now ["flow",
  "active-flow"] (matches config["flow"] that the handler reads)
- cores/service: was ["kg-core"], now ["flow"] (reads
  config["flow"])
- metering/counter: was ["token-costs"], now ["token-cost"]
  (singular)
- agent/mcp_tool: was ["mcp-tool"], now ["mcp"] (reads
  config["mcp"])

Update tests
2026-04-07 14:51:14 +01:00
cybermaggedon
ddd4bd7790
Deliver explainability triples inline in retrieval response stream (#763)
Provenance triples are now included directly in explain messages from
GraphRAG, DocumentRAG, and Agent services, eliminating the need for
follow-up knowledge graph queries to retrieve explainability details.

Each explain message in the response stream now carries:
- explain_id: root URI for this provenance step (unchanged)
- explain_graph: named graph where triples are stored (unchanged)
- explain_triples: the actual provenance triples for this step (new)

Changes across the stack:
- Schema: added explain_triples field to GraphRagResponse,
  DocumentRagResponse, and AgentResponse
- Services: all explain message call sites pass triples through
  (graph_rag, document_rag, agent react, agent orchestrator)
- Translators: encode explain_triples via TripleTranslator for
  gateway wire format
- Python SDK: ProvenanceEvent now includes parsed ExplainEntity
  and raw triples; expanded event_type detection
- CLI: invoke_graph_rag, invoke_agent, invoke_document_rag use
  inline entity when available, fall back to graph query
- Tech specs updated

Additional explainability test
2026-04-07 12:19:05 +01:00
cybermaggedon
2f8d6a3ffb
Fix agent config handler registration, remove debug prints, disable RabbitMQ heartbeats (#764)
- Fix agent react and orchestrator services appending bare methods
  to config_handlers instead of using register_config_handler() —
  caused 'method object is not subscriptable' on config notify
- Add exc_info to config fetch retry logging for proper tracebacks
- Remove debug print statements from collection management
  dispatcher and translator
- Disable RabbitMQ heartbeats (heartbeat=0) to prevent broker
  closing idle producer connections that can't process heartbeat
  frames from BlockingConnection
2026-04-07 12:11:12 +01:00
Sreeram Venkatasubramanian
f0c9039b76 fix: reduce consumer poll timeout from 2000ms to 100ms 2026-04-07 12:02:27 +01:00
cybermaggedon
4acd853023
Config push notify pattern: replace stateful pub/sub with signal + fetch (#760)
Replace the config push mechanism that broadcast the full config
blob on a 'state' class pub/sub queue with a lightweight notify
signal containing only the version number and affected config
types. Processors fetch the full config via request/response from
the config service when notified.

This eliminates the need for the pub/sub 'state' queue class and
stateful pub/sub services entirely. The config push queue moves
from 'state' to 'flow' class — a simple transient signal rather
than a retained message.  This solves the RabbitMQ
late-subscriber problem where restarting processes never received
the current config because their fresh queue had no historical
messages.

Key changes:
- ConfigPush schema: config dict replaced with types list
- Subscribe-then-fetch startup with retry: processors subscribe
  to notify queue, fetch config via request/response, then
  process buffered notifies with version comparison to avoid race
  conditions
- register_config_handler() accepts optional types parameter so
  handlers only fire when their config types change
- Short-lived config request/response clients to avoid subscriber
  contention on non-persistent response topics
- Config service passes affected types through put/delete/flow
  operations
- Gateway ConfigReceiver rewritten with same notify pattern and
  retry loop
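The notify-then-fetch logic can be sketched as follows (class and method names are illustrative, modelled on the behaviour described above, not copied from the codebase): a notify carries only a version and the affected types, and handlers fire only when relevant, with version comparison guarding against stale or duplicate notifies.

```python
class ConfigNotifyHandler:
    """Sketch of the notify-then-fetch pattern: the notify signal
    carries only (version, types); full config is fetched on demand."""

    def __init__(self, fetch):
        self.fetch = fetch       # callable returning (config, version)
        self.version = -1
        self.handlers = []       # list of (callback, types-or-None)

    def register_config_handler(self, callback, types=None):
        # types=None means the handler fires on every config change
        self.handlers.append((callback, types))

    def on_config_notify(self, version, types):
        if version <= self.version:
            return               # old or duplicate notify: skip
        self.version = version   # record even if no handler fires
        relevant = [
            cb for cb, wanted in self.handlers
            if wanted is None or not types or set(wanted) & set(types)
        ]
        if not relevant:
            return
        config, fetched_version = self.fetch()
        for cb in relevant:
            cb(config, fetched_version)
```

Note that the version is updated even when all handlers are filtered out, so a later fetch is not triggered redundantly for an already-seen change.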

Tests updated

New tests:
- register_config_handler: without types, with types, multiple
  types, multiple handlers
- on_config_notify: old/same version skipped, irrelevant types
  skipped (version still updated), relevant type triggers fetch,
  handler without types always called, mixed handler filtering,
  empty types invokes all, fetch failure handled gracefully
- fetch_config: returns config+version, raises on error response,
  stops client even on exception
- fetch_and_apply_config: applies to all handlers on startup,
  retries on failure
2026-04-06 16:57:27 +01:00
V.Sreeram
d4723566cb fix: prevent duplicate dispatcher creation race condition in invoke_global_service (#715)
* fix: prevent duplicate dispatcher creation race condition in invoke_global_service

Concurrent coroutines could all pass the `if key in self.dispatchers` check
before any of them wrote the result back, because `await dispatcher.start()`
yields to the event loop. This caused multiple Pulsar consumers to be created
on the same shared subscription, distributing responses round-robin and
dropping ~2/3 of them — manifesting as a permanent spinner in the Workbench UI.

Apply a double-checked asyncio.Lock in both `invoke_global_service` and
`invoke_flow_service` so only one dispatcher is ever created per service key.

* test: add concurrent-dispatch tests for race condition fix

Add asyncio.gather-based tests that verify invoke_global_service and
invoke_flow_service create exactly one dispatcher under concurrent calls,
preventing the duplicate Pulsar consumer bug.
2026-04-06 11:14:32 +01:00
Alex Jenkins
10a931f04c Feat: Auto-pull missing Ollama models (#757)
* fix deadlink in readme

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>

* feat: Auto-pull Ollama models

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>

* fix: Restore namespace __init__.py files for package resolution

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>

* fix CI

Signed-off-by: Jenkins, Kenneth Alexander <kjenkins60@gatech.edu>
2026-04-06 11:10:53 +01:00
cybermaggedon
ee65d90fdd
SPARQL service supports batching/streaming (#755) 2026-04-02 17:54:07 +01:00
cybermaggedon
d9dc4cbab5
SPARQL query service (#754)
SPARQL 1.1 query service wrapping pub/sub triples interface

Add a backend-agnostic SPARQL query service that parses SPARQL
queries using rdflib, decomposes them into triple pattern lookups
via the existing TriplesClient pub/sub interface, and performs
in-memory joins, filters, and projections.
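The decompose-and-join core can be sketched in plain Python (this is an illustrative re-derivation of the technique, not the service's actual evaluator, which uses rdflib's algebra): each triple pattern is matched against the store under the bindings accumulated so far, and joining is just left-to-right binding extension.

```python
def match_pattern(store, pattern, binding):
    """Match one triple pattern (variables start with '?') against a
    list of (s, p, o) triples, extending an existing binding."""
    def resolve(term):
        return binding.get(term, term) if term.startswith("?") else term
    s, p, o = (resolve(t) for t in pattern)
    for ts, tp, to in store:
        new = dict(binding)
        ok = True
        for want, got in ((s, ts), (p, tp), (o, to)):
            if want.startswith("?"):
                new[want] = got       # bind the variable
            elif want != got:
                ok = False            # constant mismatch
                break
        if ok:
            yield new

def evaluate_bgp(store, patterns):
    """Join patterns left to right, as in a basic graph pattern."""
    bindings = [{}]
    for pattern in patterns:
        bindings = [b2 for b in bindings
                    for b2 in match_pattern(store, pattern, b)]
    return bindings
```

In the real service each `match_pattern` lookup would be a TriplesClient pub/sub request rather than an in-memory scan.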

Includes:
- SPARQL parser, algebra evaluator, expression evaluator, solution
  sequence operations (BGP, JOIN, OPTIONAL, UNION, FILTER, BIND,
  VALUES, GROUP BY, ORDER BY, LIMIT/OFFSET, DISTINCT, aggregates)
- FlowProcessor service with TriplesClientSpec
- Gateway dispatcher, request/response translators, API spec
- Python SDK method (FlowInstance.sparql_query)
- CLI command (tg-invoke-sparql-query)
- Tech spec (docs/tech-specs/sparql-query.md)

New unit tests for SPARQL query
2026-04-02 17:21:39 +01:00
cybermaggedon
62c30a3a50
Skip Pulsar check in tg-verify-system-status (#753) 2026-04-02 13:20:39 +01:00
cybermaggedon
24f0190ce7
RabbitMQ pub/sub backend with topic exchange architecture (#752)
Adds a RabbitMQ backend as an alternative to Pulsar, selectable via
PUBSUB_BACKEND=rabbitmq. Both backends implement the same PubSubBackend
protocol — no application code changes needed to switch.

RabbitMQ topology:
- Single topic exchange per topicspace (e.g. 'tg')
- Routing key derived from queue class and topic name
- Shared consumers: named queue bound to exchange (competing, round-robin)
- Exclusive consumers: anonymous auto-delete queue (broadcast, each gets
  every message). Used by Subscriber and config push consumer.
- Thread-local producer connections (pika is not thread-safe)
- Push-based consumption via basic_consume with process_data_events
  for heartbeat processing
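A pure-function sketch of this topology planning (field names and helper are hypothetical; the real backend issues the corresponding pika declarations): shared consumers get a named competing queue, exclusive consumers an anonymous auto-delete queue, both bound to the per-topicspace topic exchange.

```python
def plan_binding(queue_name, consumer_type="shared"):
    """Sketch: derive RabbitMQ topology for a CLASS:TOPICSPACE:TOPIC
    queue name under the single-topic-exchange design."""
    qclass, space, topic = queue_name.split(":", 2)
    routing_key = f"{qclass}.{topic}"
    if consumer_type == "shared":
        # Named queue: consumers compete, messages round-robin
        return {"exchange": space, "routing_key": routing_key,
                "queue": queue_name,
                "exclusive": False, "auto_delete": False}
    # Exclusive: anonymous auto-delete queue, each consumer gets
    # every message (broadcast) — used by Subscriber and config push
    return {"exchange": space, "routing_key": routing_key,
            "queue": "",
            "exclusive": True, "auto_delete": True}
```

The exact routing-key format is an assumption; the point is that exchange and key are derived from the queue class and topic, so application code never names RabbitMQ primitives directly.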

Consumer model changes:
- Consumer class creates one backend consumer per concurrent task
  (required for pika thread safety, harmless for Pulsar)
- Consumer class accepts consumer_type parameter
- Subscriber passes consumer_type='exclusive' for broadcast semantics
- Config push consumer uses consumer_type='exclusive' so every
  processor instance receives config updates
- handle_one_from_queue receives consumer as parameter for correct
  per-connection ack/nack

LibrarianClient:
- New shared client class replacing duplicated librarian request-response
  code across 6+ services (chunking, decoders, RAG, etc.)
- Uses stream-document instead of get-document-content for fetching
  document content in 1MB chunks (avoids broker message size limits)
- Standalone object (self.librarian = LibrarianClient(...)) not a mixin
- get-document-content marked deprecated in schema and OpenAPI spec

Serialisation:
- Extracted dataclass_to_dict/dict_to_dataclass to shared
  serialization.py (used by both Pulsar and RabbitMQ backends)

Librarian queues:
- Changed from flow class (persistent) back to request/response class
  now that stream-document eliminates large single messages
- API upload chunk size reduced from 5MB to 3MB to stay under broker
  limits after base64 encoding

Factory and CLI:
- get_pubsub() handles 'rabbitmq' backend with RabbitMQ connection params
- add_pubsub_args() includes RabbitMQ options (host, port, credentials)
- add_pubsub_args(standalone=True) defaults to localhost for CLI tools
- init_trustgraph skips Pulsar admin setup for non-Pulsar backends
- tg-dump-queues and tg-monitor-prompts use backend abstraction
- BaseClient and ConfigClient accept generic pubsub config
2026-04-02 12:47:16 +01:00
cybermaggedon
4fb0b4d8e8
Pub/sub abstraction: decouple from Pulsar (#751)
Remove Pulsar-specific concepts from application code so that
the pub/sub backend is swappable via configuration.

Rename translators:
- to_pulsar/from_pulsar → decode/encode across all translator
  classes, dispatch handlers, and tests (55+ files)
- from_response_with_completion → encode_with_completion
- Remove pulsar.schema.Record from translator base class

Queue naming (CLASS:TOPICSPACE:TOPIC):
- Replace topic() helper with queue() using new format:
  flow:tg:name, request:tg:name, response:tg:name, state:tg:name
- Queue class implies persistence/TTL (no QoS in names)
- Update Pulsar backend map_topic() to parse new format
- Librarian queues use flow class (persistent, for chunking)
- Config push uses state class (persistent, last-value)
- Remove 15 dead topic imports from schema files
- Update init_trustgraph.py namespace: config → state
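The naming scheme above can be sketched as a small helper (the class-to-persistence mapping here is illustrative, inferred from the bullet points): the class segment alone implies persistence and TTL, so no QoS detail ever appears in the name.

```python
# Illustrative mapping: queue class -> persistent? (inferred, not
# copied from the codebase)
QUEUE_CLASSES = {
    "flow": True,       # persistent (e.g. chunking pipelines)
    "request": False,
    "response": False,
    "state": True,      # persistent, last-value
}

def queue(qclass, topic, topicspace="tg"):
    """Build a CLASS:TOPICSPACE:TOPIC queue name."""
    if qclass not in QUEUE_CLASSES:
        raise ValueError(f"unknown queue class: {qclass}")
    return f"{qclass}:{topicspace}:{topic}"
```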

Confine Pulsar to pulsar_backend.py:
- Delete legacy PulsarClient class from pubsub.py
- Move add_args to add_pubsub_args() with standalone flag
  for CLI tools (defaults to localhost)
- PulsarBackendConsumer.receive() catches _pulsar.Timeout,
  raises standard TimeoutError
- Remove Pulsar imports from: async_processor, flow_processor,
  log_level, all 11 client files, 4 storage writers, gateway
  service, gateway config receiver
- Remove log_level/LoggerLevel from client API
- Rewrite tg-monitor-prompts to use backend abstraction
- Update tg-dump-queues to use add_pubsub_args

Also: pubsub-abstraction.md tech spec covering problem statement,
design goals, as-is requirements, candidate broker assessment,
approach, and implementation order.
2026-04-01 20:16:53 +01:00
cybermaggedon
dbf8daa74a Additional agent DAG tests (#750)
- test_agent_provenance.py (+8): test_session_parent_uri,
  test_session_no_parent_uri, and 6 synthesis tests (types,
  single/multiple parents, document, label)
- test_on_action_callback.py (3 tests): fires before tool, skipped
  for Final, works when None
- test_callback_message_id.py (7 tests): message_id on think/observe/
  answer callbacks (streaming + non-streaming) and
  send_final_response
- test_parse_chunk_message_id.py (5 tests): _parse_chunk propagates
  message_id for thought, observation, answer; handles missing
  gracefully
- test_explainability_parsing.py (+1):
  test_dispatches_analysis_with_tooluse - Analysis+ToolUse mixin
  still dispatches to Analysis
- test_explainability.py (+1):
  test_observation_found_via_subtrace_synthesis - chain walker
  follows from sub-trace Synthesis to find Observation and
  Conclusion in correct order
2026-04-01 13:59:58 +01:00
cybermaggedon
3ba6a3238f
Misc test harnesses (#749)
Some misc test harnesses for a few features
2026-04-01 13:52:28 +01:00
cybermaggedon
2bcf375103
Wire message_id on all answer chunks, fix DAG structure (#748)
message_id:
- Add message_id to AgentAnswer dataclass and propagate in
  socket_client._parse_chunk
- Wire message_id into answer callbacks and send_final_response
  for all three patterns (react, plan-then-execute, supervisor)
- Supervisor decomposition thought and synthesis answer chunks
  now carry message_id

DAG structure fixes:
- Observation derives from sub-trace Synthesis (not Analysis)
  when a tool produces a sub-trace; tracked via
  last_sub_explain_uri on context
- Subagent sessions derive from parent's Decomposition via
  parent_uri on agent_session_triples
- Findings derive from subagent Conclusions (not Decomposition)
- Synthesis derives from all findings (multiple wasDerivedFrom)
  ensuring single terminal node
- agent_synthesis_triples accepts list of parent URIs
- Explainability chain walker follows from sub-trace terminal
  to find downstream Observation

Emit Analysis before tool execution:
- Add on_action callback to react() in agent_manager.py, called
  after reason() but before tool invocation
- Orchestrator and old service emit Analysis+ToolUse triples via
  on_action so sub-traces appear after their parent in the stream
2026-04-01 13:27:41 +01:00
cybermaggedon
153ae9ad30
Split Analysis into Analysis+ToolUse and Observation, add message_id (#747)
Refactor agent provenance so that the decision (thought + tool
selection) and the result (observation) are separate DAG entities:

  Question ← Analysis+ToolUse ← Observation ← ... ← Conclusion

Analysis gains tg:ToolUse as a mixin RDF type and is emitted
before tool execution via an on_action callback in react().
This ensures sub-traces (e.g. GraphRAG) appear after their
parent Analysis in the streaming event order.

Observation becomes a standalone prov:Entity with tg:Observation
type, emitted after tool execution. The linear DAG chain runs
through Observation — subsequent iterations and the Conclusion
derive from it, not from the Analysis.

message_id is populated on streaming AgentResponse for thought
and observation chunks, using the provenance URI of the entity
being built. This lets clients group streamed chunks by entity.
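Client-side grouping of streamed chunks by entity can be sketched as below (chunk field names assumed from the description, not taken from the SDK): text accumulates per message_id, so interleaved entities reassemble cleanly.

```python
from collections import defaultdict

def group_chunks(chunks):
    """Sketch: accumulate streamed chunk text per provenance entity,
    keyed by message_id (the provenance URI of the entity)."""
    grouped = defaultdict(list)
    for chunk in chunks:
        grouped[chunk["message_id"]].append(chunk["text"])
    return {mid: "".join(parts) for mid, parts in grouped.items()}
```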

Wire changes:
- provenance/agent.py: Add ToolUse type, new
  agent_observation_triples(), remove observation from iteration
- agent_manager.py: Add on_action callback between reason() and
  tool execution
- orchestrator/pattern_base.py: Split emit, wire message_id,
  chain through observation URIs
- orchestrator/react_pattern.py: Emit Analysis via on_action
  before tool runs
- agent/react/service.py: Same for non-orchestrator path
- api/explainability.py: New Observation class, updated dispatch
  and chain walker
- api/types.py: Add message_id to AgentThought/AgentObservation
- cli: Render Observation separately, [analysis: tool] labels
2026-03-31 17:51:22 +01:00
cybermaggedon
89e13a756a
Minor agent-orchestrator updates (#746)
Tidy agent-orchestrator logs

Added CLI support for selecting the pattern...

  tg-invoke-agent -q "What is the document about?" -p supervisor -v
  tg-invoke-agent -q "What is the document about?" -p plan-then-execute -v
  tg-invoke-agent -q "What is the document about?" -p react -v

Added new event types to tg-show-explain-trace
2026-03-31 13:29:04 +01:00
cybermaggedon
816a8cfcf6
Update tests for agent-orchestrator (#745)
Add 96 tests covering the orchestrator's aggregation, provenance,
routing, and explainability parsing. These verify the supervisor
fan-out/fan-in lifecycle, the new RDF provenance types
(Decomposition, Finding, Plan, StepResult, Synthesis), and their
round-trip through the wire format.

Unit tests (84):
- Aggregator: register, record completion, peek, build synthesis,
  cleanup
- Provenance triple builders: types, provenance links,
  goals/steps, labels
- Explainability parsing: from_triples dispatch, field extraction
  for all new entity types, precedence over existing types
- PatternBase: is_subagent detection, emit_subagent_completion
  message shape
- Completion dispatch: detection logic, full aggregator
  integration flow, synthesis request not re-intercepted as
  completion
- MetaRouter: task type identification, pattern selection,
  valid_patterns constraints, fallback on LLM error or unknown
  response

Contract tests (12):
- Orchestration fields on AgentRequest round-trip correctly
- subagent-completion and synthesise step types in request
  history
- Plan steps with status and dependencies
- Provenance triple builder → wire format → from_triples
  round-trip for all five new entity types
2026-03-31 13:12:26 +01:00
cybermaggedon
7b734148b3
agent-orchestrator: add explainability provenance for all patterns (#744)

Extend the provenance/explainability system to provide
human-readable reasoning traces for the orchestrator's three
agent patterns. Previously only ReAct emitted provenance
(session, iteration, conclusion). Now each pattern records its
cognitive steps as typed RDF entities in the knowledge graph,
using composable mixin types (e.g. Finding + Answer).

New provenance chains:
- Supervisor: Question → Decomposition → Finding ×N → Synthesis
- Plan-then-Execute: Question → Plan → StepResult ×N → Synthesis
- ReAct: Question → Analysis ×N → Conclusion (unchanged)

New RDF types: Decomposition, Finding, Plan, StepResult.
New predicates: tg:subagentGoal, tg:planStep.
Reuses existing Synthesis + Answer mixin for final answers.
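One link of the supervisor chain can be sketched as a triple builder (the builder's shape is assumed; the namespaces match those used elsewhere in this changeset, and tg:subagentGoal is one of the new predicates named above):

```python
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
PROV = "http://www.w3.org/ns/prov#"
TG = "https://trustgraph.ai/ns/"

def finding_triples(uri, parent_uri, goal):
    """Sketch: a Finding entity derived from the supervisor's
    Decomposition, one link in the chain
    Question <- Decomposition <- Finding xN <- Synthesis."""
    return [
        (uri, RDF + "type", TG + "Finding"),
        (uri, PROV + "wasDerivedFrom", parent_uri),
        (uri, TG + "subagentGoal", goal),
    ]
```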

Provenance library (trustgraph-base):
- Triple builders, URI generators, vocabulary labels for new types
- Client dataclasses with from_triples() dispatch
- fetch_agent_trace() follows branching provenance chains
- API exports updated

Orchestrator (trustgraph-flow):
- PatternBase emit methods for decomposition, finding, plan, step result, and synthesis
- SupervisorPattern emits decomposition during fan-out
- PlanThenExecutePattern emits plan and step results
- Service emits finding triples on subagent completion
- Synthesis provenance replaces generic final triples

CLI (trustgraph-cli):
- invoke_agent -x displays new entity types inline
2026-03-31 12:54:51 +01:00
cybermaggedon
e65ea217a2
agent-orchestrator improvements (#743)
agent-orchestrator improvements:
- Improve agent trace
- Improve queue dumping
- Fix supervisor pattern
- Fix synthesis step to remove loop

Minor dev environment improvements:
- Improve queue dump output for JSON
- Reduce dev container rebuild
2026-03-31 11:24:30 +01:00
cybermaggedon
81ca7bbc11
Change monitor default to prompts-rag (#742) 2026-03-31 09:35:58 +01:00
cybermaggedon
0781d3e6a7
Remove unnecessary prompt-client logging (#740) 2026-03-31 09:12:33 +01:00
cybermaggedon
849987f0e6
Add multi-pattern orchestrator with plan-then-execute and supervisor (#739)
Introduce an agent orchestrator service that supports three
execution patterns (ReAct, plan-then-execute, supervisor) with
LLM-based meta-routing to select the appropriate pattern and task
type per request. Update the agent schema to support
orchestration fields (correlation, sub-agents, plan steps) and
remove legacy response fields (answer, thought, observation).
2026-03-31 00:32:49 +01:00
CommitHu502Craft
7af1d60db8 fix(gateway): accept raw utf-8 text in text-load (#729)
Co-authored-by: nanqinhu <139929317+nanqinhu@users.noreply.github.com>
2026-03-30 17:00:10 +01:00
cybermaggedon
5a9db2da50
Add tg-monitor-prompts CLI tool for prompt queue monitoring (#737)
Subscribes to prompt request/response Pulsar queues, correlates
messages by ID, and logs a summary with template name, truncated
terms, and elapsed time. Streaming responses are accumulated and
shown at completion. Supports prompt and prompt-rag queue types.
2026-03-30 16:08:46 +01:00
cybermaggedon
687a9e08fe
master -> release/v2.2 (#732) 2026-03-29 20:26:26 +01:00
cybermaggedon
413f917676 Add missing pdf extra to unstructured dependency (#728)
* Fix PDF processing dependencies so that PDF processing works
2026-03-29 20:22:45 +01:00
cybermaggedon
20204d87c3
Fix OpenAI compatibility issues for newer models and Azure config (#727)
Use max_completion_tokens for OpenAI and Azure OpenAI providers:
The OpenAI API deprecated max_tokens in favor of
max_completion_tokens for chat completions. Newer models
(gpt-4o, o1, o3) reject the old parameter with a 400 error.

AZURE_API_VERSION env var now overrides the default API version:
(falls back to 2024-12-01-preview).
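The two fixes can be sketched as small helpers (function names are illustrative): newer models reject `max_tokens`, so the limit is passed as `max_completion_tokens`, and the Azure API version falls back to the stated default only when the environment variable is unset or empty.

```python
import os

def openai_kwargs(max_completion_tokens, temperature=0.0):
    """Sketch: newer OpenAI models (gpt-4o, o1, o3) reject max_tokens,
    so the token limit goes in max_completion_tokens."""
    return {
        "max_completion_tokens": max_completion_tokens,
        "temperature": temperature,
    }

def azure_api_version(default="2024-12-01-preview"):
    """AZURE_API_VERSION overrides the default when set and non-empty."""
    return os.environ.get("AZURE_API_VERSION") or default
```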

Update tests to test for expected structures
2026-03-28 11:19:45 +00:00
cybermaggedon
a634520509
Fix websocket error responses in Mux dispatcher (#726)
Error responses from the websocket multiplexer were missing the
request ID and using a bare string format instead of the structured
error protocol. This caused clients to hang when a request failed
(e.g. unsupported service for a flow) because the error could not
be routed to the waiting caller.

Include request ID in all error paths, use structured error format
({message, type}) with complete flag, and extract the ID early in
receive() so even malformed requests get a routable error when
possible.
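A sketch of the corrected error shape (wire field names are assumptions based on the description above, not the exact gateway protocol): the request ID lets the muxer route the error to the waiting caller, and the complete flag tells the client to stop waiting.

```python
import json

def error_response(request_id, message, err_type):
    """Sketch: structured websocket error carrying the request ID and
    a complete flag, instead of a bare string with no ID."""
    return json.dumps({
        "id": request_id,
        "error": {"message": message, "type": err_type},
        "complete": True,
    })
```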

Updated tests - tests were coded against invalid protocol messages
2026-03-28 10:58:28 +00:00
cybermaggedon
ea33620fb2
Fix missing auth header in verify_system_status (#724)
Fix missing auth header in verify_system_status processor check

The check_processors function received the token parameter but
did not include it in the Authorization header when calling the
metrics endpoint, causing 401 errors when gateway auth is enabled.
2026-03-26 16:58:30 +00:00
cybermaggedon
9c55a0a0ff
Persistent websocket connections for socket clients and CLI tools (#723)
Replace per-request websocket connections in SocketClient and
AsyncSocketClient with a single persistent connection that
multiplexes requests by ID via a background reader task. This
eliminates repeated TCP+WS handshakes which caused significant
latency over proxies.
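The multiplexing core can be sketched like this (class and method names hypothetical): requests register a future keyed by ID, and a single reader resolves each future as its response arrives on the shared connection.

```python
import asyncio
import itertools

class Multiplexer:
    """Sketch: one persistent connection; pending requests are
    futures keyed by the `id` echoed back in each response."""

    def __init__(self, send):
        self.send = send              # coroutine sending one message
        self.pending = {}             # request id -> Future
        self.ids = itertools.count()

    async def request(self, payload):
        rid = str(next(self.ids))
        fut = asyncio.get_running_loop().create_future()
        self.pending[rid] = fut
        await self.send({"id": rid, **payload})
        return await fut              # resolved by on_message

    def on_message(self, msg):
        # Called by the background reader task for each inbound message
        fut = self.pending.pop(msg["id"], None)
        if fut is not None and not fut.done():
            fut.set_result(msg)
```

Because the connection is set up once, each request costs one message round trip instead of a fresh TCP + websocket handshake.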

Convert show_flows, show_flow_blueprints, and
show_parameter_types CLI tools from sequential HTTP requests to
concurrent websocket requests using AsyncSocketClient, reducing
round trips from O(N) sequential to a small number of parallel
batches.

Also fix describe_interfaces bug in show_flows where response
queue was reading the request field instead of the response
field.
2026-03-26 16:46:28 +00:00
cybermaggedon
1ec081f42f
Update CLA notice in repo (#722) 2026-03-26 14:18:13 +00:00
Cyber MacGeddon
f02bbdb442 New CLA workflow: uses a GitHub action in
trustgraph-ai/contributor-license-agreement

This blocks a PR until the committer responds with a message
of agreement with the CLA terms.
2026-03-26 14:09:07 +00:00
cybermaggedon
4164ef1c47
Add GATEWAY_SECRET support for MCP server to API gateway auth (#721)
Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).
2026-03-26 10:49:28 +00:00
cybermaggedon
97f5645ea0
CLA (#716)
Explanatory text for the CLA process
2026-03-26 09:08:09 +00:00
cybermaggedon
1f67fc2312
master -> release/v2.2 (#713)
Merge doc updates from master into release branch
2026-03-25 17:53:20 +00:00
260 changed files with 16757 additions and 4051 deletions


@@ -77,8 +77,8 @@ some-containers:
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.unstructured \
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.unstructured \
# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.mcp \

dev-tools/library_client.py (new file)

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Client utility for browsing and loading documents from the TrustGraph
public document library.
Usage:
python library_client.py list
python library_client.py search <text>
python library_client.py load-all
python library_client.py load-doc <id>
python library_client.py load-match <text>
"""
import json
import urllib.request
import sys
import os
import argparse
from trustgraph.api import Api
from trustgraph.api.types import Uri, Literal, Triple
BUCKET_URL = "https://storage.googleapis.com/trustgraph-library"
INDEX_URL = f"{BUCKET_URL}/index.json"
default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
default_user = "trustgraph"
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def fetch_index():
with urllib.request.urlopen(INDEX_URL) as resp:
return json.loads(resp.read())
def fetch_document_metadata(doc_id):
url = f"{BUCKET_URL}/{doc_id}.json"
with urllib.request.urlopen(url) as resp:
return json.loads(resp.read())
def fetch_document_content(doc_id):
url = f"{BUCKET_URL}/{doc_id}.epub"
with urllib.request.urlopen(url) as resp:
return resp.read()
def search_index(index, query):
query = query.lower()
results = []
for doc in index:
title = doc.get("title", "").lower()
comments = doc.get("comments", "").lower()
tags = [t.lower() for t in doc.get("tags", [])]
if (query in title or query in comments or
any(query in t for t in tags)):
results.append(doc)
return results
def print_index(index):
if not index:
return
# Calculate column widths
id_width = max(len(str(doc.get("id", ""))) for doc in index)
title_width = max(len(doc.get("title", "")) for doc in index)
# Cap title width for readability
title_width = min(title_width, 60)
id_width = max(id_width, 2)
try:
term_width = os.get_terminal_size().columns
except OSError:
term_width = 120
tags_width = max(term_width - id_width - title_width - 6, 20)
header = f"{'ID':<{id_width}} {'Title':<{title_width}} {'Tags':<{tags_width}}"
print(header)
print("-" * len(header))
for doc in index:
eid = str(doc.get("id", ""))
title = doc.get("title", "")
if len(title) > title_width:
title = title[:title_width - 3] + "..."
tags = ", ".join(doc.get("tags", []))
if len(tags) > tags_width:
tags = tags[:tags_width - 3] + "..."
print(f"{eid:<{id_width}} {title:<{title_width}} {tags}")
def convert_value(v):
"""Convert a JSON triple value to a Uri or Literal."""
if v["type"] == "uri":
return Uri(v["value"])
else:
return Literal(v["value"])
def convert_metadata(metadata_json):
"""Convert JSON metadata triples to Triple objects."""
triples = []
for t in metadata_json:
triples.append(Triple(
s=convert_value(t["s"]),
p=convert_value(t["p"]),
o=convert_value(t["o"]),
))
return triples
def load_document(api, user, doc_entry):
"""Fetch metadata and content for a document, then load into TrustGraph."""
doc_id = doc_entry["id"]
title = doc_entry["title"]
print(f" [{doc_id}] {title}")
print(f" fetching metadata...")
doc_json = fetch_document_metadata(doc_id)
doc = doc_json[0]
print(f" fetching content...")
content = fetch_document_content(doc_id)
print(f" loading into TrustGraph ({len(content) // 1024}KB)...")
metadata = convert_metadata(doc["metadata"])
api.add_document(
id=doc["id"],
metadata=metadata,
user=user,
kind=doc["kind"],
title=doc["title"],
comments=doc["comments"],
tags=doc["tags"],
document=content,
)
print(f" done.")
def load_documents(api, user, docs):
"""Load a list of documents."""
print(f"Loading {len(docs)} document(s)...\n")
for doc in docs:
try:
load_document(api, user, doc)
except Exception as e:
print(f" FAILED: {e}", file=sys.stderr)
print()
print("Complete.")
def main():
parser = argparse.ArgumentParser(
description="Browse and load documents from the TrustGraph public document library.",
)
parser.add_argument(
"-u", "--url", default=default_url,
help=f"TrustGraph API URL (default: {default_url})",
)
parser.add_argument(
"-U", "--user", default=default_user,
help=f"User ID (default: {default_user})",
)
parser.add_argument(
"-t", "--token", default=default_token,
help="Authentication token (default: $TRUSTGRAPH_TOKEN)",
)
sub = parser.add_subparsers(dest="command")
sub.add_parser("list", help="List all documents")
search_parser = sub.add_parser("search", help="Search documents")
search_parser.add_argument("query", help="Text to search for")
sub.add_parser("load-all", help="Load all documents into TrustGraph")
load_doc_parser = sub.add_parser("load-doc", help="Load a document by ID")
load_doc_parser.add_argument("id", help="Document ID (ebook number)")
load_match_parser = sub.add_parser(
"load-match", help="Load all documents matching a search term",
)
load_match_parser.add_argument("query", help="Text to search for")
args = parser.parse_args()
if args.command is None:
parser.print_help()
sys.exit(1)
index = fetch_index()
if args.command in ("list", "search"):
if args.command == "list":
print_index(index)
else:
results = search_index(index, args.query)
if results:
print_index(results)
else:
print("No matches found.", file=sys.stderr)
sys.exit(1)
return
# Load commands need the API
api = Api(args.url, token=args.token).library()
if args.command == "load-all":
load_documents(api, args.user, index)
elif args.command == "load-doc":
matches = [d for d in index if str(d.get("id")) == args.id]
if not matches:
print(f"No document with ID '{args.id}' found.", file=sys.stderr)
sys.exit(1)
load_documents(api, args.user, matches)
elif args.command == "load-match":
results = search_index(index, args.query)
if results:
load_documents(api, args.user, results)
else:
print("No matches found.", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

analyse_trace.py (new file)

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Analyse a captured agent trace JSON file and check DAG integrity.

Usage:
    python analyse_trace.py react.json
    python analyse_trace.py -u http://localhost:8088/ react.json
"""

import argparse
import asyncio
import json
import os
import sys

import websockets

DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
DEFAULT_USER = "trustgraph"
DEFAULT_COLLECTION = "default"
DEFAULT_FLOW = "default"
GRAPH = "urn:graph:retrieval"

# Namespace prefixes
PROV = "http://www.w3.org/ns/prov#"
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
RDFS = "http://www.w3.org/2000/01/rdf-schema#"
TG = "https://trustgraph.ai/ns/"

PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
RDF_TYPE = RDF + "type"
TG_ANALYSIS = TG + "Analysis"
TG_TOOL_USE = TG + "ToolUse"
TG_OBSERVATION_TYPE = TG + "Observation"
TG_CONCLUSION = TG + "Conclusion"
TG_SYNTHESIS = TG + "Synthesis"
TG_QUESTION = TG + "Question"

def shorten(uri):
    """Shorten a URI for display."""
    for prefix, short in [
        (PROV, "prov:"), (RDF, "rdf:"), (RDFS, "rdfs:"), (TG, "tg:"),
    ]:
        if isinstance(uri, str) and uri.startswith(prefix):
            return short + uri[len(prefix):]
    return str(uri)

async def fetch_triples(ws, flow, subject, user, collection, request_counter):
    """Query triples for a given subject URI."""
    request_counter[0] += 1
    req_id = f"q-{request_counter[0]}"
    msg = {
        "id": req_id,
        "service": "triples",
        "flow": flow,
        "request": {
            "s": {"t": "i", "i": subject},
            "g": GRAPH,
            "user": user,
            "collection": collection,
            "limit": 100,
        },
    }
    await ws.send(json.dumps(msg))
    while True:
        raw = await ws.recv()
        resp = json.loads(raw)
        if resp.get("id") == req_id:
            inner = resp.get("response", {})
            if isinstance(inner, dict):
                return inner.get("response", [])
            return inner

def extract_term(term):
    """Extract value from wire-format term."""
    if not term:
        return ""
    t = term.get("t", "")
    if t == "i":
        return term.get("i", "")
    elif t == "l":
        return term.get("v", "")
    elif t == "t":
        tr = term.get("tr", {})
        return {
            "s": extract_term(tr.get("s", {})),
            "p": extract_term(tr.get("p", {})),
            "o": extract_term(tr.get("o", {})),
        }
    return str(term)

def parse_triples(wire_triples):
    """Convert wire triples to (s, p, o) tuples."""
    result = []
    for t in wire_triples:
        s = extract_term(t.get("s", {}))
        p = extract_term(t.get("p", {}))
        o = extract_term(t.get("o", {}))
        result.append((s, p, o))
    return result

def get_types(tuples):
    """Get rdf:type values from parsed triples."""
    return {o for s, p, o in tuples if p == RDF_TYPE}

def get_derived_from(tuples):
    """Get prov:wasDerivedFrom targets from parsed triples."""
    return [o for s, p, o in tuples if p == PROV_WAS_DERIVED_FROM]

async def analyse(path, url, flow, user, collection):
    with open(path) as f:
        messages = json.load(f)

    print(f"Total messages: {len(messages)}")
    print()

    # ---- Pass 1: collect explain IDs and check streaming chunks ----
    explain_ids = []
    errors = []
    for i, msg in enumerate(messages):
        resp = msg.get("response", {})
        chunk_type = resp.get("chunk_type", "?")
        if chunk_type == "explain":
            explain_id = resp.get("explain_id", "")
            explain_ids.append(explain_id)
            print(f" {i:3d} {chunk_type} {explain_id}")
        else:
            print(f" {i:3d} {chunk_type}")
        # Rule 7: message_id on content chunks
        if chunk_type in ("thought", "observation", "answer"):
            mid = resp.get("message_id", "")
            if not mid:
                errors.append(
                    f"[msg {i}] {chunk_type} chunk missing message_id"
                )

    print()
    print(f"Explain IDs ({len(explain_ids)}):")
    for eid in explain_ids:
        print(f" {eid}")

    # ---- Pass 2: fetch triples for each explain ID ----
    ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
    ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
    request_counter = [0]

    # entity_id -> parsed triples [(s, p, o), ...]
    entities = {}

    print()
    print("Fetching triples...")
    print()

    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as ws:
        for eid in explain_ids:
            wire = await fetch_triples(
                ws, flow, eid, user, collection, request_counter,
            )
            tuples = parse_triples(wire) if isinstance(wire, list) else []
            entities[eid] = tuples
            print(f" {eid}")
            for s, p, o in tuples:
                o_short = str(o)
                if len(o_short) > 80:
                    o_short = o_short[:77] + "..."
                print(f" {shorten(p)} = {o_short}")
            print()

    # ---- Pass 3: check rules ----
    all_ids = set(entities.keys())

    # Collect entity metadata
    roots = []         # entities with no wasDerivedFrom
    conclusions = []   # tg:Conclusion entities
    analyses = []      # tg:Analysis entities
    observations = []  # tg:Observation entities

    for eid, tuples in entities.items():
        types = get_types(tuples)
        parents = get_derived_from(tuples)
        if not tuples:
            errors.append(f"[{eid}] entity has no triples in store")
        if not parents:
            roots.append(eid)
        if TG_CONCLUSION in types:
            conclusions.append(eid)
        if TG_ANALYSIS in types:
            analyses.append(eid)
        if TG_OBSERVATION_TYPE in types:
            observations.append(eid)
        # Rule 4: every non-root entity has wasDerivedFrom
        if parents:
            for parent in parents:
                # Rule 5: parent exists in known entities
                if parent not in all_ids:
                    errors.append(
                        f"[{eid}] wasDerivedFrom target not in explain set: "
                        f"{parent}"
                    )
        # Rule 6: Analysis entities must have ToolUse type
        if TG_ANALYSIS in types and TG_TOOL_USE not in types:
            errors.append(
                f"[{eid}] Analysis entity missing tg:ToolUse type"
            )

    # Rule 1: exactly one root
    if len(roots) == 0:
        errors.append("No root entity found (all have wasDerivedFrom)")
    elif len(roots) > 1:
        errors.append(
            f"Multiple roots ({len(roots)}) — expected exactly 1:"
        )
        for r in roots:
            types = get_types(entities[r])
            type_labels = ", ".join(shorten(t) for t in types)
            errors.append(f" root: {r} [{type_labels}]")

    # Rule 2: exactly one terminal node (nothing derives from it)
    # Build set of entities that are parents of something
    has_children = set()
    for eid, tuples in entities.items():
        for parent in get_derived_from(tuples):
            has_children.add(parent)
    terminals = [eid for eid in all_ids if eid not in has_children]
    if len(terminals) == 0:
        errors.append("No terminal entity found (cycle?)")
    elif len(terminals) > 1:
        errors.append(
            f"Multiple terminal entities ({len(terminals)}) — expected exactly 1:"
        )
        for t in terminals:
            types = get_types(entities[t])
            type_labels = ", ".join(shorten(ty) for ty in types)
            errors.append(f" terminal: {t} [{type_labels}]")

    # Rule 8: Observation should not derive from Analysis if a sub-trace
    # exists as a sibling. Check: if an Analysis has both a Question child
    # and an Observation child, the Observation should derive from the
    # sub-trace's Synthesis, not from the Analysis.
    for obs_id in observations:
        obs_parents = get_derived_from(entities[obs_id])
        for parent in obs_parents:
            if parent in entities:
                parent_types = get_types(entities[parent])
                if TG_ANALYSIS in parent_types:
                    # Check if this Analysis also has a Question child
                    # (i.e. a sub-trace exists)
                    has_subtrace = False
                    for other_id, other_tuples in entities.items():
                        if other_id == obs_id:
                            continue
                        other_parents = get_derived_from(other_tuples)
                        other_types = get_types(other_tuples)
                        if (parent in other_parents
                                and TG_QUESTION in other_types):
                            has_subtrace = True
                            break
                    if has_subtrace:
                        errors.append(
                            f"[{obs_id}] Observation derives from Analysis "
                            f"{parent} which has a sub-trace — should derive "
                            f"from the sub-trace's Synthesis instead"
                        )

    # ---- Report ----
    print()
    print("=" * 60)
    if errors:
        print(f"ERRORS ({len(errors)}):")
        print()
        for err in errors:
            print(f" !! {err}")
    else:
        print("ALL CHECKS PASSED")
    print("=" * 60)

def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input", help="JSON trace file")
    parser.add_argument("-u", "--url", default=DEFAULT_URL)
    parser.add_argument("-f", "--flow", default=DEFAULT_FLOW)
    parser.add_argument("-U", "--user", default=DEFAULT_USER)
    parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION)
    args = parser.parse_args()
    asyncio.run(analyse(
        args.input, args.url, args.flow,
        args.user, args.collection,
    ))

if __name__ == "__main__":
    main()

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
"""
Connect to TrustGraph websocket, run an agent query, capture all
response messages to a JSON file.

Usage:
    python ws_capture.py -q "What is the document about?" -o trace.json
    python ws_capture.py -q "..." -u http://localhost:8088/ -o out.json
"""

import argparse
import asyncio
import json
import os

import websockets

DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
DEFAULT_USER = "trustgraph"
DEFAULT_COLLECTION = "default"
DEFAULT_FLOW = "default"

async def capture(url, flow, question, user, collection, output):
    # Convert to ws URL
    ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
    ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"

    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=120) as ws:
        request = {
            "id": "capture",
            "service": "agent",
            "flow": flow,
            "request": {
                "question": question,
                "user": user,
                "collection": collection,
                "streaming": True,
            },
        }
        await ws.send(json.dumps(request))
        messages = []
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("id") != "capture":
                continue
            messages.append(msg)
            if msg.get("complete"):
                break

    with open(output, "w") as f:
        json.dump(messages, f, indent=2)
    print(f"Captured {len(messages)} messages to {output}")

def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-q", "--question", required=True)
    parser.add_argument("-o", "--output", default="trace.json")
    parser.add_argument("-u", "--url", default=DEFAULT_URL)
    parser.add_argument("-f", "--flow", default=DEFAULT_FLOW)
    parser.add_argument("-U", "--user", default=DEFAULT_USER)
    parser.add_argument("-C", "--collection", default=DEFAULT_COLLECTION)
    args = parser.parse_args()
    asyncio.run(capture(
        args.url, args.flow, args.question,
        args.user, args.collection, args.output,
    ))

if __name__ == "__main__":
    main()

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Minimal example: download a text document in tiny chunks via websocket API
"""

import asyncio
import json
import base64

import websockets

async def main():
    url = "ws://localhost:8088/api/v1/socket"
    document_id = "test-chunked-doc-001"
    chunk_size = 10  # Tiny chunks!
    request_id = 0

    async def send_request(ws, request):
        nonlocal request_id
        request_id += 1
        msg = {
            "id": f"req-{request_id}",
            "service": "librarian",
            "request": request
        }
        await ws.send(json.dumps(msg))
        response = json.loads(await ws.recv())
        if "error" in response:
            raise Exception(response["error"])
        return response.get("response", {})

    async with websockets.connect(url) as ws:
        print(f"Fetching document: {document_id}")
        print(f"Chunk size: {chunk_size} bytes")
        print()

        chunk_index = 0
        all_content = b""
        while True:
            resp = await send_request(ws, {
                "operation": "stream-document",
                "user": "trustgraph",
                "document-id": document_id,
                "chunk-index": chunk_index,
                "chunk-size": chunk_size,
            })
            chunk_data = base64.b64decode(resp["content"])
            total_chunks = resp["total-chunks"]
            total_bytes = resp["total-bytes"]
            print(f"Chunk {chunk_index}: {chunk_data}")
            all_content += chunk_data
            chunk_index += 1
            if chunk_index >= total_chunks:
                break

        print()
        print(f"Complete: {all_content}")

if __name__ == "__main__":
    asyncio.run(main())

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
"""
Minimal example: upload a small text document via websocket API
"""

import asyncio
import json
import base64
import time

import websockets

async def main():
    url = "ws://localhost:8088/api/v1/socket"

    # Small text content
    content = b"AAAAAAAAAABBBBBBBBBBCCCCCCCCCC"
    request_id = 0

    async def send_request(ws, request):
        nonlocal request_id
        request_id += 1
        msg = {
            "id": f"req-{request_id}",
            "service": "librarian",
            "request": request
        }
        await ws.send(json.dumps(msg))
        response = json.loads(await ws.recv())
        if "error" in response:
            raise Exception(response["error"])
        return response.get("response", {})

    async with websockets.connect(url) as ws:
        print(f"Uploading {len(content)} bytes...")
        resp = await send_request(ws, {
            "operation": "add-document",
            "document-metadata": {
                "id": "test-chunked-doc-001",
                "time": int(time.time()),
                "kind": "text/plain",
                "title": "My Test Document",
                "comments": "Small doc for chunk testing",
                "user": "trustgraph",
                "tags": ["test"],
                "metadata": [],
            },
            "content": base64.b64encode(content).decode("utf-8"),
        })
        print("Done!")

if __name__ == "__main__":
    asyncio.run(main())

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
WebSocket Test Client

A simple client to test the reverse gateway through the relay.
Connects to the relay's /in endpoint and allows sending test messages.

Usage:
    python test_client.py [--uri URI] [--interactive]
"""

import asyncio
import json
import logging
import argparse
import uuid

from aiohttp import ClientSession, WSMsgType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("test_client")

class TestClient:
    """Simple WebSocket test client"""

    def __init__(self, uri: str):
        self.uri = uri
        self.session = None
        self.ws = None
        self.running = False
        self.message_counter = 0
        self.client_id = str(uuid.uuid4())[:8]

    async def connect(self):
        """Connect to the WebSocket"""
        self.session = ClientSession()
        logger.info(f"Connecting to {self.uri}")
        self.ws = await self.session.ws_connect(self.uri)
        logger.info("Connected successfully")

    async def disconnect(self):
        """Disconnect from WebSocket"""
        if self.ws and not self.ws.closed:
            await self.ws.close()
        if self.session and not self.session.closed:
            await self.session.close()
        logger.info("Disconnected")

    async def send_message(self, service: str, request_data: dict, flow: str = "default"):
        """Send a properly formatted TrustGraph message"""
        self.message_counter += 1
        message = {
            "id": f"{self.client_id}-{self.message_counter}",
            "service": service,
            "request": request_data,
            "flow": flow
        }
        message_json = json.dumps(message, indent=2)
        logger.info(f"Sending message:\n{message_json}")
        await self.ws.send_str(json.dumps(message))

    async def listen_for_responses(self):
        """Listen for incoming messages"""
        logger.info("Listening for responses...")
        async for msg in self.ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    response = json.loads(msg.data)
                    logger.info(f"Received response:\n{json.dumps(response, indent=2)}")
                except json.JSONDecodeError:
                    logger.info(f"Received text: {msg.data}")
            elif msg.type == WSMsgType.BINARY:
                logger.info(f"Received binary data: {len(msg.data)} bytes")
            elif msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket error: {self.ws.exception()}")
                break
            else:
                logger.info(f"Connection closed: {msg.type}")
                break

    async def interactive_mode(self):
        """Interactive mode for manual testing"""
        print("\n=== Interactive Test Client ===")
        print("Available commands:")
        print(" text-completion - Test text completion service")
        print(" agent - Test agent service")
        print(" embeddings - Test embeddings service")
        print(" custom - Send custom message")
        print(" quit - Exit")
        print()

        # Start response listener
        listen_task = asyncio.create_task(self.listen_for_responses())

        try:
            while True:
                try:
                    command = input("Command> ").strip().lower()
                    if command == "quit":
                        break
                    elif command == "text-completion":
                        await self.send_message("text-completion", {
                            "system": "You are a helpful assistant.",
                            "prompt": "What is 2+2?"
                        })
                    elif command == "agent":
                        await self.send_message("agent", {
                            "question": "What is the capital of France?"
                        })
                    elif command == "embeddings":
                        await self.send_message("embeddings", {
                            "text": "Hello world"
                        })
                    elif command == "custom":
                        service = input("Service name> ").strip()
                        request_json = input("Request JSON> ").strip()
                        try:
                            request_data = json.loads(request_json)
                            await self.send_message(service, request_data)
                        except json.JSONDecodeError as e:
                            print(f"Invalid JSON: {e}")
                    elif command == "":
                        continue
                    else:
                        print(f"Unknown command: {command}")
                except KeyboardInterrupt:
                    break
                except EOFError:
                    break
                except Exception as e:
                    logger.error(f"Error in interactive mode: {e}")
        finally:
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass

    async def run_predefined_tests(self):
        """Run a series of predefined tests"""
        print("\n=== Running Predefined Tests ===")

        # Start response listener
        listen_task = asyncio.create_task(self.listen_for_responses())

        try:
            # Test 1: Text completion
            print("\n1. Testing text-completion service...")
            await self.send_message("text-completion", {
                "system": "You are a helpful assistant.",
                "prompt": "What is 2+2?"
            })
            await asyncio.sleep(2)

            # Test 2: Agent
            print("\n2. Testing agent service...")
            await self.send_message("agent", {
                "question": "What is the capital of France?"
            })
            await asyncio.sleep(2)

            # Test 3: Embeddings
            print("\n3. Testing embeddings service...")
            await self.send_message("embeddings", {
                "text": "Hello world"
            })
            await asyncio.sleep(2)

            # Test 4: Invalid service
            print("\n4. Testing invalid service...")
            await self.send_message("nonexistent-service", {
                "test": "data"
            })
            await asyncio.sleep(2)

            print("\nTests completed. Waiting for any remaining responses...")
            await asyncio.sleep(3)
        finally:
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass

async def main():
    parser = argparse.ArgumentParser(
        description="WebSocket Test Client for Reverse Gateway"
    )
    parser.add_argument(
        '--uri',
        default='ws://localhost:8080/in',
        help='WebSocket URI to connect to (default: ws://localhost:8080/in)'
    )
    parser.add_argument(
        '--interactive', '-i',
        action='store_true',
        help='Run in interactive mode'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    client = TestClient(args.uri)
    try:
        await client.connect()
        if args.interactive:
            await client.interactive_mode()
        else:
            await client.run_predefined_tests()
    except KeyboardInterrupt:
        print("\nShutdown requested by user")
    except Exception as e:
        logger.error(f"Client error: {e}")
    finally:
        await client.disconnect()

if __name__ == "__main__":
    asyncio.run(main())

@@ -0,0 +1,210 @@
#!/usr/bin/env python3
"""
WebSocket Relay Test Harness

This script creates a relay server with two WebSocket endpoints:
- /in - for test clients to connect to
- /out - for reverse gateway to connect to

Messages are bidirectionally relayed between the two connections.

Usage:
    python websocket_relay.py [--port PORT] [--host HOST]
"""

import asyncio
import logging
import argparse
import weakref
from typing import Optional, Set

from aiohttp import web, WSMsgType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("websocket_relay")

class WebSocketRelay:
    """WebSocket relay that forwards messages between 'in' and 'out' connections"""

    def __init__(self):
        self.in_connections: Set = weakref.WeakSet()
        self.out_connections: Set = weakref.WeakSet()

    async def handle_in_connection(self, request):
        """Handle incoming connections on /in endpoint"""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.in_connections.add(ws)
        logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    data = msg.data
                    logger.info(f"IN → OUT: {data}")
                    await self._forward_to_out(data)
                elif msg.type == WSMsgType.BINARY:
                    data = msg.data
                    logger.info(f"IN → OUT: {len(data)} bytes (binary)")
                    await self._forward_to_out(data, binary=True)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f"WebSocket error on 'in' connection: {ws.exception()}")
                    break
                else:
                    break
        except Exception as e:
            logger.error(f"Error in 'in' connection handler: {e}")
        finally:
            logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
        return ws

    async def handle_out_connection(self, request):
        """Handle outgoing connections on /out endpoint"""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.out_connections.add(ws)
        logger.info(f"New 'out' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}")
        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    data = msg.data
                    logger.info(f"OUT → IN: {data}")
                    await self._forward_to_in(data)
                elif msg.type == WSMsgType.BINARY:
                    data = msg.data
                    logger.info(f"OUT → IN: {len(data)} bytes (binary)")
                    await self._forward_to_in(data, binary=True)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f"WebSocket error on 'out' connection: {ws.exception()}")
                    break
                else:
                    break
        except Exception as e:
            logger.error(f"Error in 'out' connection handler: {e}")
        finally:
            logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}")
        return ws

    async def _forward_to_out(self, data, binary=False):
        """Forward message from 'in' to all 'out' connections"""
        if not self.out_connections:
            logger.warning("No 'out' connections available to forward message")
            return
        closed_connections = []
        for ws in list(self.out_connections):
            try:
                if ws.closed:
                    closed_connections.append(ws)
                    continue
                if binary:
                    await ws.send_bytes(data)
                else:
                    await ws.send_str(data)
            except Exception as e:
                logger.error(f"Error forwarding to 'out' connection: {e}")
                closed_connections.append(ws)
        # Clean up closed connections
        for ws in closed_connections:
            if ws in self.out_connections:
                self.out_connections.discard(ws)

    async def _forward_to_in(self, data, binary=False):
        """Forward message from 'out' to all 'in' connections"""
        if not self.in_connections:
            logger.warning("No 'in' connections available to forward message")
            return
        closed_connections = []
        for ws in list(self.in_connections):
            try:
                if ws.closed:
                    closed_connections.append(ws)
                    continue
                if binary:
                    await ws.send_bytes(data)
                else:
                    await ws.send_str(data)
            except Exception as e:
                logger.error(f"Error forwarding to 'in' connection: {e}")
                closed_connections.append(ws)
        # Clean up closed connections
        for ws in closed_connections:
            if ws in self.in_connections:
                self.in_connections.discard(ws)

async def create_app(relay):
    """Create the web application with routes"""
    app = web.Application()

    # Add routes
    app.router.add_get('/in', relay.handle_in_connection)
    app.router.add_get('/out', relay.handle_out_connection)

    # Add a simple status endpoint
    async def status(request):
        status_info = {
            'in_connections': len(relay.in_connections),
            'out_connections': len(relay.out_connections),
            'status': 'running'
        }
        return web.json_response(status_info)

    app.router.add_get('/status', status)
    app.router.add_get('/', status)  # Root also shows status

    return app

def main():
    parser = argparse.ArgumentParser(
        description="WebSocket Relay Test Harness"
    )
    parser.add_argument(
        '--host',
        default='localhost',
        help='Host to bind to (default: localhost)'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=8080,
        help='Port to bind to (default: 8080)'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    relay = WebSocketRelay()

    print(f"Starting WebSocket Relay on {args.host}:{args.port}")
    print(f" 'in' endpoint: ws://{args.host}:{args.port}/in")
    print(f" 'out' endpoint: ws://{args.host}:{args.port}/out")
    print(f" Status: http://{args.host}:{args.port}/status")
    print()
    print("Usage:")
    print(f" Test client connects to: ws://{args.host}:{args.port}/in")
    print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out")

    web.run_app(create_app(relay), host=args.host, port=args.port)

if __name__ == "__main__":
    main()

@@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""
Load test triples into the triple store for testing tg-query-graph.

Tests all graph features:
- SPO with IRI objects
- SPO with literal objects
- Literals with XML datatypes
- Literals with language tags
- Quoted triples (RDF-star)
- Named graphs
"""

import asyncio
import json
import os

import websockets

# Configuration
API_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)
FLOW = "default"
USER = "trustgraph"
COLLECTION = "default"
DOCUMENT_ID = "test-triples-001"

# Namespaces
EX = "http://example.org/"
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
RDFS = "http://www.w3.org/2000/01/rdf-schema#"
XSD = "http://www.w3.org/2001/XMLSchema#"
TG = "https://trustgraph.ai/ns/"

def iri(value):
    """Build IRI term."""
    return {"t": "i", "i": value}

def literal(value, datatype=None, language=None):
    """Build literal term with optional datatype or language."""
    term = {"t": "l", "v": value}
    if datatype:
        term["dt"] = datatype
    if language:
        term["ln"] = language
    return term

def quoted_triple(s, p, o):
    """Build quoted triple term (RDF-star)."""
    return {
        "t": "t",
        "tr": {"s": s, "p": p, "o": o}
    }

def triple(s, p, o, g=None):
    """Build a complete triple dict."""
    t = {"s": s, "p": p, "o": o}
    if g:
        t["g"] = g
    return t

# Test triples covering all features
TEST_TRIPLES = [
    # 1. Basic SPO with IRI object
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{RDF}type"),
        iri(f"{EX}Scientist")
    ),
    # 2. SPO with IRI object (relationship)
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{EX}discovered"),
        iri(f"{EX}radium")
    ),
    # 3. Simple literal (no datatype/language)
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{RDFS}label"),
        literal("Marie Curie")
    ),
    # 4. Literal with language tag (English)
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{RDFS}label"),
        literal("Marie Curie", language="en")
    ),
    # 5. Literal with language tag (French)
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{RDFS}label"),
        literal("Marie Curie", language="fr")
    ),
    # 6. Literal with language tag (Polish)
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{RDFS}label"),
        literal("Maria Sk\u0142odowska-Curie", language="pl")
    ),
    # 7. Literal with xsd:integer datatype
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{EX}birthYear"),
        literal("1867", datatype=f"{XSD}integer")
    ),
    # 8. Literal with xsd:date datatype
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{EX}birthDate"),
        literal("1867-11-07", datatype=f"{XSD}date")
    ),
    # 9. Literal with xsd:boolean datatype
    triple(
        iri(f"{EX}marie-curie"),
        iri(f"{EX}nobelLaureate"),
        literal("true", datatype=f"{XSD}boolean")
    ),
    # 10. Quoted triple in object position (RDF 1.2 style)
    # "Wikipedia asserts that Marie Curie discovered radium"
    triple(
        iri(f"{EX}wikipedia"),
        iri(f"{TG}asserts"),
        quoted_triple(
            iri(f"{EX}marie-curie"),
            iri(f"{EX}discovered"),
            iri(f"{EX}radium")
        )
    ),
    # 11. Quoted triple with literal inside (object position)
    # "NLP-v1.0 extracted that Marie Curie has label Marie Curie"
    triple(
        iri(f"{EX}nlp-v1"),
        iri(f"{TG}extracted"),
        quoted_triple(
            iri(f"{EX}marie-curie"),
            iri(f"{RDFS}label"),
            literal("Marie Curie")
        )
    ),
    # 12. Triple in a named graph (g is plain string, not Term)
    triple(
        iri(f"{EX}radium"),
        iri(f"{RDF}type"),
        iri(f"{EX}Element"),
        g=f"{EX}chemistry-graph"
    ),
    # 13. Another triple in the same named graph
    triple(
        iri(f"{EX}radium"),
        iri(f"{EX}atomicNumber"),
        literal("88", datatype=f"{XSD}integer"),
        g=f"{EX}chemistry-graph"
    ),
    # 14. Triple in a different named graph
    triple(
        iri(f"{EX}pierre-curie"),
        iri(f"{EX}spouseOf"),
        iri(f"{EX}marie-curie"),
        g=f"{EX}biography-graph"
    ),
]

async def load_triples():
    """Load test triples via WebSocket bulk import."""
    # Convert HTTP URL to WebSocket URL
    ws_url = API_URL.replace("http://", "ws://").replace("https://", "wss://")
    ws_url = f"{ws_url.rstrip('/')}/api/v1/flow/{FLOW}/import/triples"
    if TOKEN:
        ws_url = f"{ws_url}?token={TOKEN}"

    metadata = {
        "id": DOCUMENT_ID,
        "metadata": [],
        "user": USER,
        "collection": COLLECTION
    }

    print(f"Connecting to {ws_url}...")
    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as websocket:
        message = {
            "metadata": metadata,
            "triples": TEST_TRIPLES
        }
        print(f"Sending {len(TEST_TRIPLES)} test triples...")
        await websocket.send(json.dumps(message))
        print("Triples sent successfully!")

    print("\nTest triples loaded:")
    print("  - 2 basic IRI triples (type, relationship)")
    print("  - 4 literal triples (plain + 3 languages: en, fr, pl)")
    print("  - 3 typed literal triples (xsd:integer, xsd:date, xsd:boolean)")
    print("  - 2 quoted triples (RDF-star provenance)")
    print("  - 3 triples in named graphs (chemistry-graph, biography-graph)")
    print(f"\nTotal: {len(TEST_TRIPLES)} triples")
    print(f"User: {USER}, Collection: {COLLECTION}")

def main():
    print("Loading test triples for tg-query-graph testing\n")
    asyncio.run(load_triples())
    print("\nDone! Now test with:")
    print("  tg-query-graph -s http://example.org/marie-curie")
    print("  tg-query-graph -p http://www.w3.org/2000/01/rdf-schema#label")
    print("  tg-query-graph -o 'Marie Curie' --object-language en")
    print("  tg-query-graph --format json | jq .")

if __name__ == "__main__":
    main()

@@ -1,108 +0,0 @@
# API Gateway Changes: v1.8 to v2.1
## Summary
The API gateway gained new WebSocket service dispatchers for embeddings
queries and a new REST streaming endpoint for document content, and its
wire format changed significantly from `Value` to `Term`. The "objects"
service was renamed to "rows".
---
## New WebSocket Service Dispatchers
These are new request/response services available through the WebSocket
multiplexer at `/api/v1/socket` (flow-scoped):
| Service Key | Description |
|-------------|-------------|
| `document-embeddings` | Queries document chunks by text similarity. Request/response uses `DocumentEmbeddingsRequest`/`DocumentEmbeddingsResponse` schemas. |
| `row-embeddings` | Queries structured data rows by text similarity on indexed fields. Request/response uses `RowEmbeddingsRequest`/`RowEmbeddingsResponse` schemas. |
These join the existing `graph-embeddings` dispatcher (which was already
present in v1.8 but may have been updated).
### Full list of WebSocket flow service dispatchers (v2.1)
Request/response services (via `/api/v1/flow/{flow}/service/{kind}` or
WebSocket mux):
- `agent`, `text-completion`, `prompt`, `mcp-tool`
- `graph-rag`, `document-rag`
- `embeddings`, `graph-embeddings`, `document-embeddings`
- `triples`, `rows`, `nlp-query`, `structured-query`, `structured-diag`
- `row-embeddings`
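All of these services share the same request envelope on the WebSocket mux. As a minimal sketch, the envelope fields are taken from the capture scripts added in this changeset; the `document-embeddings` request body shown is a placeholder, not the real schema:

```python
import json

def mux_envelope(req_id, service, flow, request):
    # Flow-scoped WebSocket mux envelope, as used by the capture
    # scripts elsewhere in this changeset
    return {"id": req_id, "service": service, "flow": flow,
            "request": request}

msg = mux_envelope("q-1", "document-embeddings", "default", {
    "text": "radioactive elements",  # hypothetical request fields
    "limit": 10,
})
print(json.dumps(msg))
```

Responses are correlated back to the request by the `id` field, which is why each script keeps its own request counter.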
---
## New REST Endpoint
| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/api/v1/document-stream` | Streams document content from the library as raw bytes. Query parameters: `user` (required), `document-id` (required), `chunk-size` (optional, default 1MB). Returns the document content in chunked transfer encoding, decoded from base64 internally. |
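Any HTTP client can consume this. A stdlib-only sketch, with the base URL and document ID as placeholders and the query parameter names taken from the table above:

```python
from urllib.parse import urlencode
from urllib.request import urlopen

def document_stream_url(base, user, document_id, chunk_size=1024 * 1024):
    # Build the streaming URL; parameter names per the table above
    qs = urlencode({"user": user, "document-id": document_id,
                    "chunk-size": chunk_size})
    return f"{base.rstrip('/')}/api/v1/document-stream?{qs}"

def stream_document(base, user, document_id, chunk_size=1024 * 1024):
    # Read the chunked transfer-encoded response incrementally
    # rather than buffering the whole document in memory
    with urlopen(document_stream_url(base, user, document_id,
                                     chunk_size)) as resp:
        while True:
            chunk = resp.read(chunk_size)
            if not chunk:
                break
            yield chunk
```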
---
## Renamed Service: "objects" to "rows"
| v1.8 | v2.1 | Notes |
|------|------|-------|
| `objects_query.py` / `ObjectsQueryRequestor` | `rows_query.py` / `RowsQueryRequestor` | Schema changed from `ObjectsQueryRequest`/`ObjectsQueryResponse` to `RowsQueryRequest`/`RowsQueryResponse`. |
| `objects_import.py` / `ObjectsImport` | `rows_import.py` / `RowsImport` | Import dispatcher for structured data. |
The WebSocket service key changed from `"objects"` to `"rows"`, and the
import dispatcher key similarly changed from `"objects"` to `"rows"`.
---
## Wire Format Change: Value to Term
The serialization layer (`serialize.py`) was rewritten to use the new `Term`
type instead of the old `Value` type.
### Old format (v1.8 — `Value`)
```json
{"v": "http://example.org/entity", "e": true}
```
- `v`: the value (string)
- `e`: boolean flag indicating whether the value is a URI
### New format (v2.1 — `Term`)
IRIs:
```json
{"t": "i", "i": "http://example.org/entity"}
```
Literals:
```json
{"t": "l", "v": "some text", "d": "datatype-uri", "l": "en"}
```
Quoted triples (RDF-star):
```json
{"t": "r", "r": {"s": {...}, "p": {...}, "o": {...}}}
```
- `t`: type discriminator — `"i"` (IRI), `"l"` (literal), `"r"` (quoted triple), `"b"` (blank node)
- Serialization now delegates to `TermTranslator` and `TripleTranslator` from `trustgraph.messaging.translators.primitives`
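For the two common cases, a client-side migration shim can be small. This is a sketch only, covering IRIs and plain literals per the examples above; note that some test scripts in this repository spell the literal qualifiers and quoted-triple keys as `dt`/`ln`/`tr` rather than `d`/`l`/`r`, so check the protocol spec before relying on either spelling:

```python
def value_to_term(value):
    # v1.8 Value: {"v": ..., "e": bool} -> v2.1 Term
    # (IRI when the "e" flag is set, plain literal otherwise)
    if value.get("e"):
        return {"t": "i", "i": value["v"]}
    return {"t": "l", "v": value["v"]}
```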
### Other serialization changes
| Field | v1.8 | v2.1 |
|-------|------|------|
| Metadata | `metadata.metadata` (subgraph) | `metadata.root` (simple value) |
| Graph embeddings entity | `entity.vectors` (plural) | `entity.vector` (singular) |
| Document embeddings chunk | `chunk.vectors` + `chunk.chunk` (text) | `chunk.vector` + `chunk.chunk_id` (ID reference) |
---
## Breaking Changes
- **`Value` to `Term` wire format**: All clients sending/receiving triples, embeddings, or entity contexts through the gateway must update to the new Term format.
- **`objects` to `rows` rename**: WebSocket service key and import key changed.
- **Metadata field change**: `metadata.metadata` (a serialized subgraph) replaced by `metadata.root` (a simple value).
- **Embeddings field changes**: `vectors` (plural) became `vector` (singular); document embeddings now reference `chunk_id` instead of inline `chunk` text.
- **New `/api/v1/document-stream` endpoint**: Additive, not breaking.

File diff suppressed because one or more lines are too long

@@ -1,112 +0,0 @@
# CLI Changes: v1.8 to v2.1
## Summary
The CLI (`trustgraph-cli`) has significant additions focused on three themes:
**explainability/provenance**, **embeddings access**, and **graph querying**.
Two legacy tools were removed, one was renamed, and several existing tools
gained new capabilities.
---
## New CLI Tools
### Explainability & Provenance
| Command | Description |
|---------|-------------|
| `tg-list-explain-traces` | Lists all explainability sessions (GraphRAG and Agent) in a collection, showing session IDs, type, question text, and timestamps. |
| `tg-show-explain-trace` | Displays the full explainability trace for a session. For GraphRAG: Question, Exploration, Focus, Synthesis stages. For Agent: Session, Iterations (thought/action/observation), Final Answer. Auto-detects trace type. Supports `--show-provenance` to trace edges back to source documents. |
| `tg-show-extraction-provenance` | Given a document ID, traverses the provenance chain: Document -> Pages -> Chunks -> Edges, using `prov:wasDerivedFrom` relationships. Supports `--show-content` and `--max-content` options. |
### Embeddings
| Command | Description |
|---------|-------------|
| `tg-invoke-embeddings` | Converts text to a vector embedding via the embeddings service. Accepts one or more text inputs, returns vectors as lists of floats. |
| `tg-invoke-graph-embeddings` | Queries graph entities by text similarity using vector embeddings. Returns matching entities with similarity scores. |
| `tg-invoke-document-embeddings` | Queries document chunks by text similarity using vector embeddings. Returns matching chunk IDs with similarity scores. |
| `tg-invoke-row-embeddings` | Queries structured data rows by text similarity on indexed fields. Returns matching rows with index values and scores. Requires `--schema-name` and supports `--index-name`. |
### Graph Querying
| Command | Description |
|---------|-------------|
| `tg-query-graph` | Pattern-based triple store query. Unlike `tg-show-graph` (which dumps everything), this allows selective queries by any combination of subject, predicate, object, and graph. Auto-detects value types: IRIs (`http://...`, `urn:...`, `<...>`), quoted triples (`<<s p o>>`), and literals. |
| `tg-get-document-content` | Retrieves document content from the library by document ID. Can output to file or stdout, handles both text and binary content. |
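The value auto-detection described for `tg-query-graph` can be approximated as follows. This is a simplified sketch, not the tool's actual implementation, and `detect_term` is a hypothetical name; it emits dicts in the new Term wire format:

```python
# Simplified sketch of tg-query-graph-style value auto-detection; the
# real CLI's rules may differ. Emits dicts in the new Term wire format.
def detect_term(text: str) -> dict:
    if text.startswith("<<") and text.endswith(">>"):
        # Quoted triple (RDF-star), e.g. "<<s p o>>"
        s, p, o = text[2:-2].split()
        return {"t": "r", "r": {"s": detect_term(s),
                                "p": detect_term(p),
                                "o": detect_term(o)}}
    if text.startswith("<") and text.endswith(">"):
        return {"t": "i", "i": text[1:-1]}   # angle-bracketed IRI
    if text.startswith(("http://", "https://", "urn:")):
        return {"t": "i", "i": text}         # bare IRI
    return {"t": "l", "v": text}             # everything else: literal

print(detect_term("<http://example.org/x>"))  # {'t': 'i', 'i': 'http://example.org/x'}
```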
---
## Removed CLI Tools
| Command | Notes |
|---------|-------|
| `tg-load-pdf` | Removed. Document loading is now handled through the library/processing pipeline. |
| `tg-load-text` | Removed. Document loading is now handled through the library/processing pipeline. |
---
## Renamed CLI Tools
| Old Name | New Name | Notes |
|----------|----------|-------|
| `tg-invoke-objects-query` | `tg-invoke-rows-query` | Reflects the terminology rename from "objects" to "rows" for structured data. |
---
## Significant Changes to Existing Tools
### `tg-invoke-graph-rag`
- **Explainability support**: Now supports a 4-stage explainability pipeline (Question, Grounding/Exploration, Focus, Synthesis) with inline provenance event display.
- **Streaming**: Uses WebSocket streaming for real-time output.
- **Provenance tracing**: Can trace selected edges back to source documents via reification and `prov:wasDerivedFrom` chains.
- Grew from ~30 lines to ~760 lines to accommodate the full explainability pipeline.
### `tg-invoke-document-rag`
- **Explainability support**: Added `question_explainable()` mode that streams Document RAG responses with inline provenance events (Question, Grounding, Exploration, Synthesis stages).
### `tg-invoke-agent`
- **Explainability support**: Added `question_explainable()` mode showing provenance events inline during agent execution (Question, Analysis, Conclusion, AgentThought, AgentObservation, AgentAnswer).
- Verbose mode shows thought/observation streams with emoji prefixes.
### `tg-show-graph`
- **Streaming mode**: Now uses `triples_query_stream()` with configurable batch sizes for lower time-to-first-result and reduced memory overhead.
- **Named graph support**: New `--graph` filter option. Recognises named graphs:
- Default graph (empty): Core knowledge facts
- `urn:graph:source`: Extraction provenance
- `urn:graph:retrieval`: Query-time explainability
- **Show graph column**: New `--show-graph` flag to display the named graph for each triple.
- **Configurable limits**: New `--limit` and `--batch-size` options.
### `tg-graph-to-turtle`
- **RDF-star support**: Now handles quoted triples (RDF-star reification).
- **Streaming mode**: Uses streaming for lower time-to-first-processing.
- **Wire format handling**: Updated to use the new term wire format (`{"t": "i", "i": uri}` for IRIs, `{"t": "l", "v": value}` for literals, `{"t": "r", "r": {...}}` for quoted triples).
- **Named graph support**: New `--graph` filter option.
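To illustrate the wire-format handling, here is a rough sketch of rendering a new-format term as a Turtle/RDF-star token. This is not `tg-graph-to-turtle`'s actual code, and datatype/language handling is deliberately minimal:

```python
# Illustrative sketch of rendering a new-format wire term as a Turtle
# token; not tg-graph-to-turtle's actual code.
def term_to_turtle(term: dict) -> str:
    t = term["t"]
    if t == "i":
        return "<" + term["i"] + ">"
    if t == "l":
        lit = '"' + term["v"].replace('"', '\\"') + '"'
        if term.get("l"):
            return lit + "@" + term["l"]
        if term.get("d"):
            return lit + "^^<" + term["d"] + ">"
        return lit
    if t == "r":  # RDF-star quoted triple
        q = term["r"]
        return "<< " + " ".join(term_to_turtle(q[k]) for k in ("s", "p", "o")) + " >>"
    raise ValueError("unsupported term type: " + t)

print(term_to_turtle({"t": "l", "v": "some text", "l": "en"}))  # "some text"@en
```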
### `tg-set-tool`
- **New tool type**: `row-embeddings-query` for semantic search on structured data indexes.
- **New options**: `--schema-name`, `--index-name`, `--limit` for configuring row embeddings query tools.
### `tg-show-tools`
- Displays the new `row-embeddings-query` tool type with its `schema-name`, `index-name`, and `limit` fields.
### `tg-load-knowledge`
- **Progress reporting**: Now counts and reports triples and entity contexts loaded per file and in total.
- **Term format update**: Entity contexts now use the new Term format (`{"t": "i", "i": uri}`) instead of the old Value format (`{"v": entity, "e": True}`).
---
## Breaking Changes
- **Terminology rename**: The `Value` schema was renamed to `Term` across the system (PR #622). This affects the wire format used by CLI tools that interact with the graph store. The new format uses `{"t": "i", "i": uri}` for IRIs and `{"t": "l", "v": value}` for literals, replacing the old `{"v": ..., "e": ...}` format.
- **`tg-invoke-objects-query` renamed** to `tg-invoke-rows-query`.
- **`tg-load-pdf` and `tg-load-text` removed**.

### `graph_rag(self, query, user='trustgraph', collection='default', entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2, edge_score_limit=30, edge_limit=25)`
Execute graph-based Retrieval-Augmented Generation (RAG) query.
- `triple_limit`: Maximum triples per entity (default: 30)
- `max_subgraph_size`: Maximum total triples in subgraph (default: 150)
- `max_path_length`: Maximum traversal depth (default: 2)
- `edge_score_limit`: Max edges for semantic pre-filter (default: 30)
- `edge_limit`: Max edges after LLM scoring (default: 25)
**Returns:** str: Generated response incorporating graph context
**Returns:** dict with schema_matches array and metadata
### `sparql_query(self, query, user='trustgraph', collection='default', limit=10000)`
Execute a SPARQL query against the knowledge graph.
**Arguments:**
- `query`: SPARQL 1.1 query string
- `user`: User/keyspace identifier (default: "trustgraph")
- `collection`: Collection identifier (default: "default")
- `limit`: Safety limit on results (default: 10000)
**Returns:** dict with query results. Structure depends on query type:
- SELECT: `{"query-type": "select", "variables": [...], "bindings": [...]}`
- ASK: `{"query-type": "ask", "ask-result": bool}`
- CONSTRUCT/DESCRIBE: `{"query-type": "construct", "triples": [...]}`
**Raises:**
- `ProtocolException`: If an error occurs
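Since the result shape varies by query type, callers typically dispatch on `query-type`. A sketch over the structures listed above; the `summarise_sparql` helper and sample data are illustrative only:

```python
# Sketch of dispatching on sparql_query's documented result shapes;
# helper name and sample data are illustrative, not part of the library.
def summarise_sparql(result: dict) -> str:
    qt = result["query-type"]
    if qt == "select":
        return f"{len(result['bindings'])} rows over variables {result['variables']}"
    if qt == "ask":
        return "yes" if result["ask-result"] else "no"
    if qt == "construct":
        return f"{len(result['triples'])} triples"
    raise ValueError("unknown query type: " + qt)

print(summarise_sparql({"query-type": "ask", "ask-result": True}))  # yes
```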
### `structured_query(self, question, user='trustgraph', collection='default')`
Execute a natural language question against structured data.
---
## `SocketClient`
```python
from trustgraph.api import SocketClient
```
Synchronous WebSocket client with persistent connection.
Maintains a single websocket connection and multiplexes requests
by ID via a background reader task. Provides synchronous generators
for streaming responses.
### Methods
### `__init__(self, url: str, timeout: int, token: str | None) -> None`
Initialize self. See help(type(self)) for accurate signature.
### `close(self) -> None`
Close the persistent WebSocket connection.
### `flow(self, flow_id: str) -> 'SocketFlowInstance'`
Get a flow instance for WebSocket streaming operations.
**Arguments:**
- `flow_id`: Flow identifier
**Returns:** SocketFlowInstance: Flow instance with streaming methods
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Stream agent responses
for chunk in flow.agent(question="Hello", user="trustgraph", streaming=True):
    print(chunk.content, end='', flush=True)
```
---
## `SocketFlowInstance`
```python
from trustgraph.api import SocketFlowInstance
```
Synchronous WebSocket flow instance for streaming operations.
Provides the same interface as REST FlowInstance but with WebSocket-based
streaming support for real-time responses.
streaming support for real-time responses.
### Methods
### `__init__(self, client: trustgraph.api.socket_client.SocketClient, flow_id: str) -> None`
Initialize self. See help(type(self)) for accurate signature.
### `agent(self, question: str, user: str, state: Dict[str, Any] | None = None, group: str | None = None, history: List[Dict[str, Any]] | None = None, streaming: bool = False, **kwargs: Any) -> Dict[str, Any] | Iterator[trustgraph.api.types.StreamingChunk]`
Execute an agent operation with streaming support.
Agents can perform multi-step reasoning with tool use. This method always
returns streaming chunks (thoughts, observations, answers) even when
streaming=False, to show the agent's reasoning process.
**Arguments:**
- `question`: User question or instruction
- `user`: User identifier
- `state`: Optional state dictionary for stateful conversations
- `group`: Optional group identifier for multi-user contexts
- `history`: Optional conversation history as list of message dicts
- `streaming`: Enable streaming mode (default: False)
- `**kwargs`: Additional parameters passed to the agent service
**Returns:** Iterator[StreamingChunk]: Stream of agent thoughts, observations, and answers
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Stream agent reasoning
for chunk in flow.agent(
    question="What is quantum computing?",
    user="trustgraph",
    streaming=True
):
    if isinstance(chunk, AgentThought):
        print(f"[Thinking] {chunk.content}")
    elif isinstance(chunk, AgentObservation):
        print(f"[Observation] {chunk.content}")
    elif isinstance(chunk, AgentAnswer):
        print(f"[Answer] {chunk.content}")
```
### `agent_explain(self, question: str, user: str, collection: str, state: Dict[str, Any] | None = None, group: str | None = None, history: List[Dict[str, Any]] | None = None, **kwargs: Any) -> Iterator[trustgraph.api.types.StreamingChunk | trustgraph.api.types.ProvenanceEvent]`
Execute an agent operation with explainability support.
Streams both content chunks (AgentThought, AgentObservation, AgentAnswer)
and provenance events (ProvenanceEvent). Provenance events contain URIs
that can be fetched using ExplainabilityClient to get detailed information
about the agent's reasoning process.
Agent trace consists of:
- Session: The initial question and session metadata
- Iterations: Each thought/action/observation cycle
- Conclusion: The final answer
**Arguments:**
- `question`: User question or instruction
- `user`: User identifier
- `collection`: Collection identifier for provenance storage
- `state`: Optional state dictionary for stateful conversations
- `group`: Optional group identifier for multi-user contexts
- `history`: Optional conversation history as list of message dicts
- `**kwargs`: Additional parameters passed to the agent service
- `Yields`:
- `Union[StreamingChunk, ProvenanceEvent]`: Agent chunks and provenance events
**Example:**
```python
from trustgraph.api import Api, ExplainabilityClient, ProvenanceEvent
from trustgraph.api import AgentThought, AgentObservation, AgentAnswer
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
provenance_ids = []
for item in flow.agent_explain(
    question="What is the capital of France?",
    user="trustgraph",
    collection="default"
):
    if isinstance(item, AgentThought):
        print(f"[Thought] {item.content}")
    elif isinstance(item, AgentObservation):
        print(f"[Observation] {item.content}")
    elif isinstance(item, AgentAnswer):
        print(f"[Answer] {item.content}")
    elif isinstance(item, ProvenanceEvent):
        provenance_ids.append(item.explain_id)

# Fetch session trace after completion
if provenance_ids:
    trace = explain_client.fetch_agent_trace(
        provenance_ids[0],  # Session URI is first
        graph="urn:graph:retrieval",
        user="trustgraph",
        collection="default"
    )
```
### `document_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any) -> Dict[str, Any]`
Query document chunks using semantic similarity.
**Arguments:**
- `text`: Query text for semantic search
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `limit`: Maximum number of results (default: 10)
- `**kwargs`: Additional parameters passed to the service
**Returns:** dict: Query results with chunk_ids of matching document chunks
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
results = flow.document_embeddings_query(
    text="machine learning algorithms",
    user="trustgraph",
    collection="research-papers",
    limit=5
)
# results contains {"chunks": [{"chunk_id": "...", "score": 0.95}, ...]}
```
### `document_rag(self, query: str, user: str, collection: str, doc_limit: int = 10, streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
Execute document-based RAG query with optional streaming.
Uses vector embeddings to find relevant document chunks, then generates
a response using an LLM. Streaming mode delivers results incrementally.
**Arguments:**
- `query`: Natural language query
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `doc_limit`: Maximum document chunks to retrieve (default: 10)
- `streaming`: Enable streaming mode (default: False)
- `**kwargs`: Additional parameters passed to the service
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Streaming document RAG
for chunk in flow.document_rag(
    query="Summarize the key findings",
    user="trustgraph",
    collection="research-papers",
    doc_limit=5,
    streaming=True
):
    print(chunk, end='', flush=True)
```
### `document_rag_explain(self, query: str, user: str, collection: str, doc_limit: int = 10, **kwargs: Any) -> Iterator[trustgraph.api.types.RAGChunk | trustgraph.api.types.ProvenanceEvent]`
Execute document-based RAG query with explainability support.
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
Provenance events contain URIs that can be fetched using ExplainabilityClient
to get detailed information about how the response was generated.
Document RAG trace consists of:
- Question: The user's query
- Exploration: Chunks retrieved from document store (chunk_count)
- Synthesis: The generated answer
**Arguments:**
- `query`: Natural language query
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `doc_limit`: Maximum document chunks to retrieve (default: 10)
- `**kwargs`: Additional parameters passed to the service
- `Yields`:
- `Union[RAGChunk, ProvenanceEvent]`: Content chunks and provenance events
**Example:**
```python
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
import sys
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
for item in flow.document_rag_explain(
    query="Summarize the key findings",
    user="trustgraph",
    collection="research-papers",
    doc_limit=5
):
    if isinstance(item, RAGChunk):
        print(item.content, end='', flush=True)
    elif isinstance(item, ProvenanceEvent):
        # Fetch entity details
        entity = explain_client.fetch_entity(
            item.explain_id,
            graph=item.explain_graph,
            user="trustgraph",
            collection="research-papers"
        )
        print(f"Event: {entity}", file=sys.stderr)
```
### `embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]`
Generate vector embeddings for one or more texts.
**Arguments:**
- `texts`: List of input texts to embed
- `**kwargs`: Additional parameters passed to the service
**Returns:** dict: Response containing vectors (one set per input text)
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
result = flow.embeddings(["quantum computing"])
vectors = result.get("vectors", [])
```
### `graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any) -> Dict[str, Any]`
Query knowledge graph entities using semantic similarity.
**Arguments:**
- `text`: Query text for semantic search
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `limit`: Maximum number of results (default: 10)
- `**kwargs`: Additional parameters passed to the service
**Returns:** dict: Query results with similar entities
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
results = flow.graph_embeddings_query(
    text="physicist who discovered radioactivity",
    user="trustgraph",
    collection="scientists",
    limit=5
)
```
### `graph_rag(self, query: str, user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, max_subgraph_size: int = 1000, max_path_length: int = 2, edge_score_limit: int = 30, edge_limit: int = 25, streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
Execute graph-based RAG query with optional streaming.
Uses knowledge graph structure to find relevant context, then generates
a response using an LLM. Streaming mode delivers results incrementally.
**Arguments:**
- `query`: Natural language query
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `entity_limit`: Maximum entities to retrieve (default: 50)
- `triple_limit`: Maximum triples per entity (default: 30)
- `max_subgraph_size`: Maximum total triples in subgraph (default: 1000)
- `max_path_length`: Maximum traversal depth (default: 2)
- `edge_score_limit`: Max edges for semantic pre-filter (default: 30)
- `edge_limit`: Max edges after LLM scoring (default: 25)
- `streaming`: Enable streaming mode (default: False)
- `**kwargs`: Additional parameters passed to the service
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Streaming graph RAG
for chunk in flow.graph_rag(
    query="Tell me about Marie Curie",
    user="trustgraph",
    collection="scientists",
    streaming=True
):
    print(chunk, end='', flush=True)
print(chunk, end='', flush=True)
```
### `graph_rag_explain(self, query: str, user: str, collection: str, entity_limit: int = 50, triple_limit: int = 30, max_subgraph_size: int = 1000, max_path_length: int = 2, edge_score_limit: int = 30, edge_limit: int = 25, **kwargs: Any) -> Iterator[trustgraph.api.types.RAGChunk | trustgraph.api.types.ProvenanceEvent]`
Execute graph-based RAG query with explainability support.
Streams both content chunks (RAGChunk) and provenance events (ProvenanceEvent).
Provenance events contain URIs that can be fetched using ExplainabilityClient
to get detailed information about how the response was generated.
**Arguments:**
- `query`: Natural language query
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `entity_limit`: Maximum entities to retrieve (default: 50)
- `triple_limit`: Maximum triples per entity (default: 30)
- `max_subgraph_size`: Maximum total triples in subgraph (default: 1000)
- `max_path_length`: Maximum traversal depth (default: 2)
- `edge_score_limit`: Max edges for semantic pre-filter (default: 30)
- `edge_limit`: Max edges after LLM scoring (default: 25)
- `**kwargs`: Additional parameters passed to the service
- `Yields`:
- `Union[RAGChunk, ProvenanceEvent]`: Content chunks and provenance events
**Example:**
```python
from trustgraph.api import Api, ExplainabilityClient, RAGChunk, ProvenanceEvent
socket = api.socket()
flow = socket.flow("default")
explain_client = ExplainabilityClient(flow)
provenance_ids = []
response_text = ""
for item in flow.graph_rag_explain(
    query="Tell me about Marie Curie",
    user="trustgraph",
    collection="scientists"
):
    if isinstance(item, RAGChunk):
        response_text += item.content
        print(item.content, end='', flush=True)
    elif isinstance(item, ProvenanceEvent):
        provenance_ids.append(item.explain_id)

# Fetch explainability details
for prov_id in provenance_ids:
    entity = explain_client.fetch_entity(
        prov_id,
        graph="urn:graph:retrieval",
        user="trustgraph",
        collection="scientists"
    )
    print(f"Entity: {entity}")
```
### `mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]`
Execute a Model Context Protocol (MCP) tool.
**Arguments:**
- `name`: Tool name/identifier
- `parameters`: Tool parameters dictionary
- `**kwargs`: Additional parameters passed to the service
**Returns:** dict: Tool execution result
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
result = flow.mcp_tool(
    name="search-web",
    parameters={"query": "latest AI news", "limit": 5}
)
```
### `prompt(self, id: str, variables: Dict[str, str], streaming: bool = False, **kwargs: Any) -> str | Iterator[str]`
Execute a prompt template with optional streaming.
**Arguments:**
- `id`: Prompt template identifier
- `variables`: Dictionary of variable name to value mappings
- `streaming`: Enable streaming mode (default: False)
- `**kwargs`: Additional parameters passed to the service
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Streaming prompt execution
for chunk in flow.prompt(
    id="summarize-template",
    variables={"topic": "quantum computing", "length": "brief"},
    streaming=True
):
    print(chunk, end='', flush=True)
```
### `row_embeddings_query(self, text: str, schema_name: str, user: str = 'trustgraph', collection: str = 'default', index_name: str | None = None, limit: int = 10, **kwargs: Any) -> Dict[str, Any]`
Query row data using semantic similarity on indexed fields.
Finds rows whose indexed field values are semantically similar to the
input text, using vector embeddings. This enables fuzzy/semantic matching
on structured data.
**Arguments:**
- `text`: Query text for semantic search
- `schema_name`: Schema name to search within
- `user`: User/keyspace identifier (default: "trustgraph")
- `collection`: Collection identifier (default: "default")
- `index_name`: Optional index name to filter search to specific index
- `limit`: Maximum number of results (default: 10)
- `**kwargs`: Additional parameters passed to the service
**Returns:** dict: Query results with matches containing index_name, index_value, text, and score
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Search for customers by name similarity
results = flow.row_embeddings_query(
    text="John Smith",
    schema_name="customers",
    user="trustgraph",
    collection="sales",
    limit=5
)

# Filter to specific index
results = flow.row_embeddings_query(
    text="machine learning engineer",
    schema_name="employees",
    index_name="job_title",
    limit=10
)
```
### `rows_query(self, query: str, user: str, collection: str, variables: Dict[str, Any] | None = None, operation_name: str | None = None, **kwargs: Any) -> Dict[str, Any]`
Execute a GraphQL query against structured rows.
**Arguments:**
- `query`: GraphQL query string
- `user`: User/keyspace identifier
- `collection`: Collection identifier
- `variables`: Optional query variables dictionary
- `operation_name`: Optional operation name for multi-operation documents
- `**kwargs`: Additional parameters passed to the service
**Returns:** dict: GraphQL response with data, errors, and/or extensions
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
query = '''
{
    scientists(limit: 10) {
        name
        field
        discoveries
    }
}
'''
result = flow.rows_query(
    query=query,
    user="trustgraph",
    collection="scientists"
)
```
### `sparql_query_stream(self, query: str, user: str = 'trustgraph', collection: str = 'default', limit: int = 10000, batch_size: int = 20, **kwargs: Any) -> Iterator[Dict[str, Any]]`
Execute a SPARQL query with streaming batches.
### `text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> str | Iterator[str]`
Execute text completion with optional streaming.
**Arguments:**
- `system`: System prompt defining the assistant's behavior
- `prompt`: User prompt/question
- `streaming`: Enable streaming mode (default: False)
- `**kwargs`: Additional parameters passed to the service
**Returns:** Union[str, Iterator[str]]: Complete response or stream of text chunks
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Non-streaming
response = flow.text_completion(
    system="You are helpful",
    prompt="Explain quantum computing",
    streaming=False
)
print(response)

# Streaming
for chunk in flow.text_completion(
    system="You are helpful",
    prompt="Explain quantum computing",
    streaming=True
):
    print(chunk, end='', flush=True)
```
### `triples_query(self, s: str | Dict[str, Any] | None = None, p: str | Dict[str, Any] | None = None, o: str | Dict[str, Any] | None = None, g: str | None = None, user: str | None = None, collection: str | None = None, limit: int = 100, **kwargs: Any) -> List[Dict[str, Any]]`
Query knowledge graph triples using pattern matching.
**Arguments:**
- `s`: Subject filter - URI string, Term dict, or None for wildcard
- `p`: Predicate filter - URI string, Term dict, or None for wildcard
- `o`: Object filter - URI/literal string, Term dict, or None for wildcard
- `g`: Named graph filter - URI string or None for all graphs
- `user`: User/keyspace identifier (optional)
- `collection`: Collection identifier (optional)
- `limit`: Maximum results to return (default: 100)
- `**kwargs`: Additional parameters passed to the service
**Returns:** List[Dict]: List of matching triples in wire format
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
# Find all triples about a specific subject
triples = flow.triples_query(
    s="http://example.org/person/marie-curie",
    user="trustgraph",
    collection="scientists"
)

# Query with named graph filter
triples = flow.triples_query(
    s="urn:trustgraph:session:abc123",
    g="urn:graph:retrieval",
    user="trustgraph",
    collection="default"
)
```
### `triples_query_stream(self, s: str | Dict[str, Any] | None = None, p: str | Dict[str, Any] | None = None, o: str | Dict[str, Any] | None = None, g: str | None = None, user: str | None = None, collection: str | None = None, limit: int = 100, batch_size: int = 20, **kwargs: Any) -> Iterator[List[Dict[str, Any]]]`
Query knowledge graph triples with streaming batches.
Yields batches of triples as they arrive, reducing time-to-first-result
and memory overhead for large result sets.
**Arguments:**
- `s`: Subject filter - URI string, Term dict, or None for wildcard
- `p`: Predicate filter - URI string, Term dict, or None for wildcard
- `o`: Object filter - URI/literal string, Term dict, or None for wildcard
- `g`: Named graph filter - URI string or None for all graphs
- `user`: User/keyspace identifier (optional)
- `collection`: Collection identifier (optional)
- `limit`: Maximum results to return (default: 100)
- `batch_size`: Triples per batch (default: 20)
- `**kwargs`: Additional parameters passed to the service
- `Yields`:
- `List[Dict]`: Batches of triples in wire format
**Example:**
```python
socket = api.socket()
flow = socket.flow("default")
for batch in flow.triples_query_stream(
    user="trustgraph",
    collection="default"
):
    for triple in batch:
        print(triple["s"], triple["p"], triple["o"])
```
---
## `AsyncSocketClient`
```python
from trustgraph.api import AsyncSocketClient
```
Asynchronous WebSocket client with persistent connection.
Maintains a single websocket connection and multiplexes requests
by ID, routing responses via a background reader task.
Use as an async context manager for proper lifecycle management:
```python
async with AsyncSocketClient(url, timeout, token) as client:
    result = await client._send_request(...)
```
Or call connect()/aclose() manually.
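The multiplexing pattern described above can be illustrated with a toy sketch. This is not `AsyncSocketClient` itself, just the idea: a background reader routes each incoming message to a per-request queue keyed by id.

```python
import asyncio

# Toy model of the pattern described above: one background reader drains
# the shared connection and routes each message to the queue registered
# under its request id. Not AsyncSocketClient itself.
class Multiplexer:
    def __init__(self):
        self.pending = {}   # request id -> asyncio.Queue

    async def reader(self, messages):
        for msg in messages:
            queue = self.pending.get(msg["id"])
            if queue:
                await queue.put(msg)

    async def request(self, request_id):
        queue = asyncio.Queue()
        self.pending[request_id] = queue
        try:
            return await queue.get()          # wait for the routed response
        finally:
            del self.pending[request_id]

async def main():
    mux = Multiplexer()
    pending = asyncio.create_task(mux.request("a"))
    await asyncio.sleep(0)                    # let the request register its queue
    await mux.reader([{"id": "a", "body": "hello"}])
    print(await pending)                      # {'id': 'a', 'body': 'hello'}

asyncio.run(main())
```

This is why a single websocket connection can serve many concurrent requests: the reader task is the only consumer of the socket, and responses are demultiplexed by id.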
### Methods
### `__aenter__(self)`
### `__aexit__(self, exc_type, exc_val, exc_tb)`
### `__init__(self, url: str, timeout: int, token: str | None)`
Initialize self. See help(type(self)) for accurate signature.
### `aclose(self)`
Close the persistent WebSocket connection cleanly.
### `connect(self)`
Establish the persistent websocket connection.
### `flow(self, flow_id: str)`
Fetch the complete Agent trace starting from a session URI.
Follows the provenance chain for all patterns:
- ReAct: Question -> Analysis(s) -> Conclusion
- Supervisor: Question -> Decomposition -> Finding(s) -> Synthesis
- Plan-then-Execute: Question -> Plan -> StepResult(s) -> Synthesis
**Arguments:**
- `api`: TrustGraph Api instance for librarian access (optional)
- `max_content`: Maximum content length for conclusion
**Returns:** Dict with question, steps (mixed entity list), conclusion/synthesis
### `fetch_docrag_trace(self, question_uri: str, graph: str | None = None, user: str | None = None, collection: str | None = None, api: Any = None, max_content: int = 10000) -> Dict[str, Any]`
---
## `Analysis`
```python
from trustgraph.api import Analysis
```
Analysis+ToolUse entity - decision + tool call (Agent only).
**Fields:**
- `action`: <class 'str'>
- `arguments`: <class 'str'>
- `thought`: <class 'str'>
### Methods
### `__init__(self, uri: str, entity_type: str = '', action: str = '', arguments: str = '', thought: str = '') -> None`
Initialize self. See help(type(self)) for accurate signature.
---
## `Observation`
```python
from trustgraph.api import Observation
```
Observation entity - standalone tool result (Agent only).
**Fields:**
- `uri`: <class 'str'>
- `entity_type`: <class 'str'>
- `document`: <class 'str'>
### Methods
### `__init__(self, uri: str, entity_type: str = '', document: str = '') -> None`
Initialize self. See help(type(self)) for accurate signature.
@@ -3761,10 +3257,11 @@ These chunks show how the agent is thinking about the problem.
- `content`: <class 'str'>
- `end_of_message`: <class 'bool'>
- `chunk_type`: <class 'str'>
+- `message_id`: <class 'str'>
### Methods
-### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'thought') -> None`
+### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'thought', message_id: str = '') -> None`
Initialize self. See help(type(self)) for accurate signature.
@@ -3787,10 +3284,11 @@ These chunks show what the agent learned from using tools.
- `content`: <class 'str'>
- `end_of_message`: <class 'bool'>
- `chunk_type`: <class 'str'>
+- `message_id`: <class 'str'>
### Methods
-### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'observation') -> None`
+### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'observation', message_id: str = '') -> None`
Initialize self. See help(type(self)) for accurate signature.
@@ -3818,10 +3316,11 @@ its reasoning and tool use.
- `end_of_message`: <class 'bool'>
- `chunk_type`: <class 'str'>
- `end_of_dialog`: <class 'bool'>
+- `message_id`: <class 'str'>
### Methods
-### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'final-answer', end_of_dialog: bool = False) -> None`
+### `__init__(self, content: str, end_of_message: bool = False, chunk_type: str = 'final-answer', end_of_dialog: bool = False, message_id: str = '') -> None`
Initialize self. See help(type(self)) for accurate signature.
@@ -3864,7 +3363,7 @@ from trustgraph.api import ProvenanceEvent
Provenance event for explainability.
-Emitted during GraphRAG queries when explainable mode is enabled.
+Emitted during retrieval queries when explainable mode is enabled.
Each event represents a provenance node created during query processing.
**Fields:**
@@ -3872,10 +3371,12 @@ Each event represents a provenance node created during query processing.
- `explain_id`: <class 'str'>
- `explain_graph`: <class 'str'>
- `event_type`: <class 'str'>
+- `entity`: <class 'object'>
+- `triples`: <class 'list'>
### Methods
-### `__init__(self, explain_id: str, explain_graph: str = '', event_type: str = '') -> None`
+### `__init__(self, explain_id: str, explain_graph: str = '', event_type: str = '', entity: object = None, triples: list = <factory>) -> None`
Initialize self. See help(type(self)) for accurate signature.


@@ -219,8 +219,8 @@ TG_ANSWER = TG + "answer"
| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
-| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id and explain_graph to DocumentRagResponse |
-| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields |
+| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse |
+| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples |
| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |


@@ -0,0 +1,939 @@
# TrustGraph Agent Orchestration — Technical Specification
## Overview
This specification describes the extension of TrustGraph's agent architecture
from a single ReACT execution pattern to a multi-pattern orchestration
model. The existing Pulsar-based self-queuing loop is pattern-agnostic — the
same infrastructure supports ReACT, Plan-then-Execute, Supervisor/Subagent
fan-out, and other execution strategies without changes to the message
transport. The extension adds a routing layer that selects the appropriate
pattern for each task, a set of pattern implementations that share common
iteration infrastructure, and a fan-out/fan-in mechanism for multi-agent
coordination.
The central design principle is that
**trust and explainability are structural properties of the architecture**,
achieved by constraining LLM decisions to
graph-defined option sets and recording those constraints in the execution
trace.
---
## Background
### Existing Architecture
The current agent manager is built on the ReACT pattern (Reasoning + Acting)
with these properties:
- **Self-queuing loop**: Each iteration emits a new Pulsar message carrying
the accumulated history. The agent manager picks this up and runs the next
iteration.
- **Stateless agent manager**: No in-process state. All state lives in the
message payload.
- **Natural parallelism**: Multiple independent agent requests are handled
concurrently across Pulsar consumers.
- **Durability**: Crash recovery is inherent — the message survives process
failure.
- **Real-time feedback**: Streaming thought, action, observation and answer
chunks are emitted as iterations complete.
- **Tool calling and MCP invocation**: Tool calls into knowledge graphs,
external services, and MCP-connected systems.
- **Decision traces written to the knowledge graph**: Every iteration records
PROV-O triples — session, analysis, and conclusion entities — forming the
basis of explainability.
### Current Message Flow
```
AgentRequest arrives (question, history=[], state, group, session_id)
Filter tools by group/state
AgentManager.react() → LLM call → parse → Action or Final
│ │
│ [Action] │ [Final]
▼ ▼
Execute tool, capture observation Emit conclusion triples
Emit iteration triples Send AgentResponse
Append to history (end_of_dialog=True)
Emit new AgentRequest → "next" topic
└── (picked up again by consumer, loop continues)
```
The key insight is that this loop structure is not ReACT-specific. The
plumbing — receive message, do work, emit next message — is the same
regardless of what the "work" step does. The payload and the pattern logic
define the behaviour; the infrastructure remains constant.
### Current Limitations
- Only one execution pattern (ReACT) is available regardless of task
characteristics.
- No mechanism for one agent to spawn and coordinate subagents.
- Pattern selection is implicit — every task gets the same treatment.
- The provenance model assumes a linear iteration chain (analysis N derives
from analysis N-1), with no support for parallel branches.
---
## Design Goals
- **Pattern-agnostic iteration infrastructure**: The self-queuing loop, tool
filtering, provenance emission, and streaming feedback should be shared
across all patterns.
- **Graph-constrained pattern selection**: The LLM selects patterns from a
graph-defined set, not from unconstrained reasoning. This makes the
selection auditable and explainable.
- **Genuinely parallel fan-out**: Subagent tasks execute concurrently on the
Pulsar queue, not sequentially in a single process.
- **Stateless coordination**: Fan-in uses the knowledge graph as coordination
substrate. The agent manager remains stateless.
- **Additive change**: The existing ReACT flow continues to work
unchanged. New patterns are added alongside it, not in place of it.
---
## Patterns
### ReACT as One Pattern Among Many
ReACT is one point in a wider space of agent execution strategies:
| Pattern | Structure | Strengths |
|---|---|---|
| **ReACT** | Interleaved reasoning and action | Adaptive, good for open-ended tasks |
| **Plan-then-Execute** | Decompose into a step DAG, then execute | More predictable, auditable plan |
| **Reflexion** | ReACT + self-critique after each action | Agents improve within the episode |
| **Supervisor/Subagent** | One agent orchestrates others | Parallel decomposition, synthesis |
| **Debate/Ensemble** | Multiple agents reason independently | Diverse perspectives, reconciliation |
| **LLM-as-router** | No reasoning loop, pure dispatch | Fast classification and routing |
Not all of these need to be implemented at once. The architecture should
support them; the initial implementation delivers ReACT (already exists),
Plan-then-Execute, and Supervisor/Subagent.
### Pattern Storage
Patterns are stored as configuration items via the config API. They are
finite in number, mechanically well-defined, have enumerable properties,
and change slowly. Each pattern is a JSON object stored under the
`agent-pattern` config type.
```json
Config type: "agent-pattern"
Config key: "react"
Value: {
"name": "react",
"description": "ReACT — Reasoning + Acting",
"when_to_use": "Adaptive, good for open-ended tasks"
}
```
These are written at deployment time and change rarely. If the architecture
later benefits from graph-based pattern storage (e.g. for richer ontological
relationships), the config items can be migrated to graph nodes — the
meta-router's selection logic is the same regardless of backend.
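As a sketch of how a selection layer might consume these config items (the `Pattern` dataclass and `load_patterns` helper are illustrative, not part of the existing API; `items` mimics what listing the `agent-pattern` config type might return):

```python
from dataclasses import dataclass

@dataclass
class Pattern:
    name: str
    description: str
    when_to_use: str

def load_patterns(config_items: dict) -> dict:
    """Turn raw 'agent-pattern' config values into Pattern objects."""
    return {key: Pattern(**value) for key, value in config_items.items()}

# Mimics the result of listing the "agent-pattern" config type.
items = {
    "react": {
        "name": "react",
        "description": "ReACT — Reasoning + Acting",
        "when_to_use": "Adaptive, good for open-ended tasks",
    },
}
patterns = load_patterns(items)
```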
---
## Task Types
### What a Task Type Represents
A **task type** characterises the problem domain — what the agent is being
asked to accomplish, and how a domain expert would frame it analytically.
- Carries domain-specific methodology (e.g. "intelligence analysis always
applies structured analytic techniques")
- Pre-populates initial reasoning context via a framing prompt
- Constrains which patterns are valid for this class of problem
- Can define domain-specific termination criteria
### Identification
Task types are identified from plain-text task descriptions by the
LLM. Building a formal ontology over task descriptions is premature — natural
language is too varied and context-dependent. The LLM reads the description;
the graph provides the structure downstream.
### Task Type Storage
Task types are stored as configuration items via the config API under the
`agent-task-type` config type. Each task type is a JSON object that
references valid patterns by name.
```json
Config type: "agent-task-type"
Config key: "risk-assessment"
Value: {
"name": "risk-assessment",
"description": "Due Diligence / Risk Assessment",
"framing_prompt": "Analyse across financial, reputational, legal and operational dimensions using structured analytic techniques.",
"valid_patterns": ["supervisor", "plan-then-execute", "react"],
"when_to_use": "Multi-dimensional analysis requiring structured assessment"
}
```
The `valid_patterns` list defines the constrained decision space — the LLM
can only select patterns that the task type's configuration says are valid.
This is the many-to-many relationship between task types and patterns,
expressed as configuration rather than graph edges.
### Selection Flow
```
Task Description (plain text, from AgentRequest.question)
│ [LLM interprets, constrained by available task types from config]
Task Type (config item — domain framing and methodology)
│ [config lookup — valid_patterns list]
Pattern Candidates (config items)
│ [LLM selects within constrained set,
│ informed by task description signals:
│ complexity, urgency, scope]
Selected Pattern
```
The task description may carry modulating signals (complexity, urgency, scope)
that influence which pattern is selected within the constrained set. But the
raw description never directly selects a pattern — it always passes through
the task type layer first.
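The constraint itself is mechanical. A minimal sketch, assuming the LLM's pick arrives as a plain string (the fallback-to-`react` policy here is illustrative, not specified behaviour):

```python
def select_pattern(valid_patterns: list, llm_choice: str, default: str = "react") -> str:
    """Clamp the LLM's selection to the task type's valid_patterns.

    A choice outside the constrained set is rejected rather than
    trusted; falling back to a safe default is an illustrative policy.
    """
    return llm_choice if llm_choice in valid_patterns else default

valid = ["supervisor", "plan-then-execute", "react"]
```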
---
## Explainability Through Constrained Decision Spaces
A central principle of TrustGraph's explainability architecture is that
**explainability comes from constrained decision spaces**.
When a decision is made from an unconstrained space — a raw LLM call with no
guardrails — the reasoning is opaque even if the LLM produces a rationale,
because that rationale is post-hoc and unverifiable.
When a decision is made from a **constrained set defined in configuration**,
you can always answer:
- What valid options were available
- What criteria narrowed the set
- What signal made the final selection within that set
This principle already governs the existing decision trace architecture and
extends naturally to pattern selection. The routing decision — which task type
and which pattern — is itself recorded as a provenance node, making the first
decision in the execution trace auditable.
**Trust becomes a structural property of the architecture, not a claimed
property of the model.**
---
## Orchestration Architecture
### The Meta-Router
The meta-router is the entry point for all agent requests. It runs as a
pre-processing step before the pattern-specific iteration loop begins. Its
job is to determine the task type and select the execution pattern.
**When it runs**: On receipt of an `AgentRequest` with empty history (i.e. a
new task, not a continuation). Requests with non-empty history are already
mid-iteration and bypass the meta-router.
**What it does**:
1. Lists all available task types from the config API
(`config.list("agent-task-type")`).
2. Presents these to the LLM alongside the task description. The LLM
identifies which task type applies (or "general" as a fallback).
3. Reads the selected task type's configuration to get the `valid_patterns`
list.
4. Loads the candidate pattern definitions from config and presents them to
the LLM. The LLM selects one, influenced by signals in the task
description (complexity, number of independent dimensions, urgency).
5. Records the routing decision as a provenance node (see Provenance Model
below).
6. Populates the `AgentRequest` with the selected pattern, task type framing
prompt, and any pattern-specific configuration, then emits it onto the
queue.
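Sketched end to end, under the assumption that the two constrained LLM calls and the config lookups are injected as callables and dicts (`route`, `identify`, and `select` are illustrative names, not existing code):

```python
def route(request: dict, task_types: dict, patterns: dict, identify, select) -> dict:
    """Meta-router sketch: identify task type, then select a pattern
    from the task type's valid_patterns list."""
    name = identify(request["question"], list(task_types))
    tt = task_types.get(name, task_types["general"])   # "general" fallback
    candidates = {p: patterns[p] for p in tt["valid_patterns"]}
    chosen = select(request["question"], candidates)
    request.update(
        pattern=chosen if chosen in candidates else "react",
        task_type=tt["name"],
        framing=tt.get("framing_prompt", ""),
    )
    return request

# Example with stub LLM callables standing in for the constrained calls:
req = route(
    {"question": "Assess the risk profile of Company X"},
    {
        "general": {"name": "general", "valid_patterns": ["react"]},
        "risk-assessment": {
            "name": "risk-assessment",
            "valid_patterns": ["supervisor", "plan-then-execute", "react"],
            "framing_prompt": "Analyse across financial, reputational, legal and operational dimensions.",
        },
    },
    {"react": {}, "plan-then-execute": {}, "supervisor": {}},
    identify=lambda q, names: "risk-assessment",
    select=lambda q, candidates: "supervisor",
)
```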
**Where it lives**: The meta-router is a phase within the agent-orchestrator,
not a separate service. The agent-orchestrator is a new executable that
uses the same service identity as the existing agent-manager-react, making
it a drop-in replacement on the same Pulsar queues. It includes the full
ReACT implementation alongside the new orchestration patterns. The
distinction between "route" and "iterate" is determined by whether the
request already has a pattern set.
### Pattern Dispatch
Once the meta-router has annotated the request with a pattern, the agent
manager dispatches to the appropriate pattern implementation. This is a
straightforward branch on the pattern field:
```
request arrives
├── history is empty → meta-router → annotate with pattern → re-emit
└── history is non-empty (or pattern is set)
├── pattern = "react" → ReACT iteration
├── pattern = "plan-then-execute" → PtE iteration
├── pattern = "supervisor" → Supervisor iteration
└── (no pattern) → ReACT iteration (default)
```
Each pattern implementation follows the same contract: receive a request, do
one iteration of work, then either emit a "next" message (continue) or emit a
response (done). The self-queuing loop doesn't change.
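The branch above might reduce to something like this (names illustrative; the default matches the tree's `(no pattern) → ReACT` leg):

```python
KNOWN_PATTERNS = {"react", "plan-then-execute", "supervisor"}

def dispatch(request: dict) -> str:
    """Pick the phase that handles an incoming AgentRequest."""
    if not request.get("history") and not request.get("pattern"):
        return "meta-router"            # new task: route before iterating
    pattern = request.get("pattern", "")
    return pattern if pattern in KNOWN_PATTERNS else "react"   # default leg
```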
### Pattern Implementations
#### ReACT (Existing)
No changes. The existing `AgentManager.react()` path continues to work
as-is.
#### Plan-then-Execute
Two-phase pattern:
**Planning phase** (first iteration):
- LLM receives the question plus task type framing.
- Produces a structured plan: an ordered list of steps, each with a goal,
expected tool, and dependencies on prior steps.
- The plan is recorded in the history as a special "plan" step.
- Emits a "next" message to begin execution.
**Execution phase** (subsequent iterations):
- Reads the plan from history.
- Identifies the next unexecuted step.
- Executes that step (tool call + observation), similar to a single ReACT
action.
- Records the result against the plan step.
- If all steps complete, synthesises a final answer.
- If a step fails or produces unexpected results, the LLM can revise the
remaining plan (bounded re-planning, not a full restart).
The plan lives in the history, so it travels with the message. No external
state is needed.
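The "next unexecuted step" lookup falls out of the `depends_on` indices. A minimal sketch over plan steps represented as dicts (field names follow the `PlanStep` schema below; the helper itself is illustrative):

```python
def next_step(plan: list) -> "int | None":
    """Index of the next executable plan step: pending, with all
    dependencies complete. None means nothing is runnable (done,
    blocked, or awaiting revision)."""
    for i, step in enumerate(plan):
        if step["status"] != "pending":
            continue
        if all(plan[d]["status"] == "complete" for d in step["depends_on"]):
            return i
    return None

plan = [
    {"goal": "gather filings", "depends_on": [], "status": "complete"},
    {"goal": "extract risk factors", "depends_on": [0], "status": "pending"},
    {"goal": "summarise", "depends_on": [1], "status": "pending"},
]
```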
#### Supervisor/Subagent
The supervisor pattern introduces fan-out and fan-in. This is the most
architecturally significant addition.
**Supervisor planning iteration**:
- LLM receives the question plus task type framing.
- Decomposes the task into independent subagent goals.
- For each subagent, emits a new `AgentRequest` with:
- A focused question (the subagent's specific goal)
- A shared correlation ID tying it to the parent task
- The subagent's own pattern (typically ReACT, but could be anything)
- Relevant context sliced from the parent request
**Subagent execution**:
- Each subagent request is picked up by an agent manager instance and runs its
own independent iteration loop.
- Subagents are ordinary agent executions — they self-queue, use tools, emit
provenance, stream feedback.
- When a subagent reaches a Final answer, it writes a completion record to the
knowledge graph under the shared correlation ID.
**Fan-in and synthesis**:
- An aggregator detects when all sibling subagents for a correlation ID have
completed.
- It emits a synthesis request to the supervisor carrying the correlation ID.
- The supervisor queries the graph for subagent results, reasons across them,
and decides whether to emit a final answer or iterate again.
**Supervisor re-iteration**:
- After synthesis, the supervisor may determine that the results are
incomplete, contradictory, or reveal gaps requiring further investigation.
- Rather than emitting a final answer, it can fan out again with new or
refined subagent goals under a new correlation ID. This is the same
self-queuing loop — the supervisor emits new subagent requests and stops,
the aggregator detects completion, and synthesis runs again.
- The supervisor's iteration count (planning + synthesis rounds) is bounded
to prevent unbounded looping.
This is detailed further in the Fan-Out / Fan-In section below.
---
## Message Schema Evolution
### Shared Schema Principle
The `AgentRequest` and `AgentResponse` schemas are the shared contract
between the agent-manager (existing ReACT execution) and the
agent-orchestrator (meta-routing, supervisor, plan-then-execute). Both
services consume from the same *agent request* topic using the same
schema. Any schema changes must be reflected in both — the schema is
the integration point, not the service implementation.
This means the orchestrator does not introduce separate message types for
its own use. Subagent requests, synthesis triggers, and meta-router
outputs are all `AgentRequest` messages with different field values. The
agent-manager ignores orchestration fields it doesn't use.
### New Fields
The `AgentRequest` schema needs new fields to carry orchestration
metadata.
```python
@dataclass
class AgentRequest:
# Existing fields (unchanged)
question: str = ""
state: str = ""
group: list[str] | None = None
history: list[AgentStep] = field(default_factory=list)
user: str = ""
collection: str = "default"
streaming: bool = False
session_id: str = ""
# New orchestration fields
conversation_id: str = "" # Optional caller-generated ID grouping related requests
pattern: str = "" # "react", "plan-then-execute", "supervisor", ""
task_type: str = "" # Identified task type name
framing: str = "" # Task type framing prompt injected into LLM context
correlation_id: str = "" # Shared ID linking subagents to parent
parent_session_id: str = "" # Parent's session_id (for subagents)
subagent_goal: str = "" # Focused goal for this subagent
expected_siblings: int = 0 # How many sibling subagents exist
```
The `AgentStep` schema also extends to accommodate non-ReACT iteration types:
```python
@dataclass
class AgentStep:
# Existing fields (unchanged)
thought: str = ""
action: str = ""
arguments: dict[str, str] = field(default_factory=dict)
observation: str = ""
user: str = ""
# New fields
step_type: str = "" # "react", "plan", "execute", "supervise", "synthesise"
plan: list[PlanStep] | None = None # For plan-then-execute: the structured plan
subagent_results: dict | None = None # For supervisor: collected subagent outputs
```
The `PlanStep` structure for Plan-then-Execute:
```python
@dataclass
class PlanStep:
goal: str = "" # What this step should accomplish
tool_hint: str = "" # Suggested tool (advisory, not binding)
depends_on: list[int] = field(default_factory=list) # Indices of prerequisite steps
status: str = "pending" # "pending", "complete", "failed", "revised"
result: str = "" # Observation from execution
```
---
## Fan-Out and Fan-In
### Why This Matters
Fan-out is the mechanism that makes multi-agent coordination genuinely
parallel rather than simulated. With Pulsar, emitting multiple messages means
multiple consumers can pick them up concurrently. This is not threading or
async simulation — it is real distributed parallelism across agent manager
instances.
### Fan-Out: Supervisor Emits Subagent Requests
When a supervisor iteration decides to decompose a task, it:
1. Generates a **correlation ID** — a UUID that groups the sibling subagents.
2. For each subagent, constructs a new `AgentRequest`:
- `question` = the subagent's focused goal (from `subagent_goal`)
- `correlation_id` = the shared correlation ID
- `parent_session_id` = the supervisor's session_id
- `pattern` = typically "react", but the supervisor can specify any pattern
- `session_id` = a new unique ID for this subagent's own provenance chain
- `expected_siblings` = total number of sibling subagents
- `history` = empty (fresh start, but framing context inherited)
- `group`, `user`, `collection` = inherited from parent
3. Emits each subagent request onto the agent request topic.
4. Records the fan-out decision in the provenance graph (see below).
The supervisor then **stops**. It does not wait. It does not poll. It has
emitted its messages and its iteration is complete. The graph and the
aggregator handle the rest.
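Steps 1 and 2 are pure message assembly. A sketch of the fan-out builder, using the field names from the schema section (the `fan_out` helper is illustrative):

```python
import uuid

def fan_out(parent: dict, goals: list) -> list:
    """Build one AgentRequest dict per subagent goal, sharing a
    freshly generated correlation ID."""
    correlation_id = str(uuid.uuid4())
    return [
        {
            "question": goal,
            "subagent_goal": goal,
            "correlation_id": correlation_id,
            "parent_session_id": parent["session_id"],
            "pattern": "react",                     # supervisor may override
            "session_id": str(uuid.uuid4()),        # own provenance chain
            "expected_siblings": len(goals),
            "history": [],
            "group": parent.get("group"),
            "user": parent.get("user", ""),
            "collection": parent.get("collection", "default"),
        }
        for goal in goals
    ]

requests = fan_out({"session_id": "sess-1"}, ["financial", "legal"])
```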
### Fan-In: Graph-Based Completion Detection
When a subagent reaches its Final answer, it writes a **completion node** to
the knowledge graph:
```
Completion node:
rdf:type tg:SubagentCompletion
tg:correlationId <shared correlation ID>
tg:subagentSessionId <this subagent's session_id>
tg:parentSessionId <supervisor's session_id>
tg:subagentGoal <what this subagent was asked to do>
tg:result → <document URI in librarian>
prov:wasGeneratedBy → <this subagent's conclusion entity>
```
The **aggregator** is a component that watches for completion nodes. When it
detects that all expected siblings for a correlation ID have written
completion nodes, it:
1. Collects all sibling results from the graph and librarian.
2. Constructs a **synthesis request** — a new `AgentRequest` addressed to the supervisor flow:
- `session_id` = the original supervisor's session_id
- `pattern` = "supervisor"
- `step_type` = "synthesise" (carried in history)
- `subagent_results` = the collected findings
- `history` = the supervisor's history up to the fan-out point, plus the synthesis step
3. Emits this onto the agent request topic.
The supervisor picks this up, reasons across the aggregated findings, and
produces its final answer.
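The synthesis request is an ordinary `AgentRequest`, per the shared schema principle. A sketch of its assembly (the helper is illustrative; whether the aggregator bundles results or the supervisor re-queries the graph is a design choice, so `results` may be empty in the latter case):

```python
def synthesis_request(supervisor: dict, results: dict) -> dict:
    """Assemble the fan-in trigger as an ordinary AgentRequest."""
    step = {
        "step_type": "synthesise",
        "subagent_results": results,
        "thought": "", "action": "", "arguments": {}, "observation": "",
    }
    return {
        "question": supervisor["question"],
        "session_id": supervisor["session_id"],   # original supervisor session
        "pattern": "supervisor",
        "correlation_id": supervisor["correlation_id"],
        "history": supervisor["history"] + [step],
    }

req = synthesis_request(
    {"question": "Assess Company X", "session_id": "sess-1",
     "correlation_id": "corr-abc123", "history": []},
    {"financial": "doc:fin-1"},
)
```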
### Aggregator Design
The aggregator is event-driven, consistent with TrustGraph's Pulsar-based
architecture. Polling would be an anti-pattern in a system where all
coordination is message-driven.
**Mechanism**: The aggregator is a Pulsar consumer on the explainability
topic. Subagent completion nodes are emitted as triples on this topic as
part of the existing provenance flow. When the aggregator receives a
`tg:SubagentCompletion` triple, it:
1. Extracts the `tg:correlationId` from the completion node.
2. Queries the graph to count how many siblings for that correlation ID
have completed.
3. If all `expected_siblings` are present, triggers fan-in immediately —
collects results and emits the synthesis request.
**State**: The aggregator is stateless in the same sense as the agent
manager — it holds no essential in-memory state. The graph is the source
of truth for completion counts. If the aggregator restarts, it can
re-process unacknowledged completion messages from Pulsar and re-check the
graph. No coordination state is lost.
**Consistency**: Because the completion check queries the graph rather than
relying on an in-memory counter, the aggregator is tolerant of duplicate
messages, out-of-order delivery, and restarts. The graph query is
idempotent — asking "are all siblings complete?" gives the same answer
regardless of how many times or in what order the events arrive.
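The idempotence argument can be made concrete. Assuming completion nodes come back from the graph as dicts, counting *distinct* subagent session IDs makes duplicate delivery harmless (helper and field names illustrative):

```python
def all_siblings_complete(completions: list, correlation_id: str, expected: int) -> bool:
    """Idempotent fan-in check over completion nodes queried from the
    graph. A set of distinct subagent session IDs absorbs duplicate
    or re-delivered completion events."""
    done = {
        c["subagent_session_id"]
        for c in completions
        if c["correlation_id"] == correlation_id
    }
    return len(done) >= expected

events = [
    {"correlation_id": "corr-abc123", "subagent_session_id": "a"},
    {"correlation_id": "corr-abc123", "subagent_session_id": "a"},  # duplicate delivery
    {"correlation_id": "corr-abc123", "subagent_session_id": "b"},
]
```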
### Timeout and Failure
- **Subagent timeout**: The aggregator records the timestamp of the first
sibling completion (from the graph). A periodic timeout check (the one
concession to polling — but over local state, not the graph) detects
stalled correlation IDs. If `expected_siblings` completions are not
reached within a configurable timeout, the aggregator emits a partial
synthesis request with whatever results are available, flagging the
incomplete subagents.
- **Subagent failure**: If a subagent errors out, it writes an error
completion node (with `tg:status = "error"` and an error message). The
aggregator treats this as a completion — the supervisor receives the error
in its synthesis input and can reason about partial results.
- **Supervisor iteration limit**: The supervisor's own iteration count
(planning + synthesis) is bounded by `max_iterations` just like any other
pattern.
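The stall check over local aggregator state reduces to a single predicate, sketched here with an explicit clock argument for testability (names illustrative):

```python
def is_stalled(first_completion_at: float, completed: int, expected: int,
               timeout_s: float, now: float) -> bool:
    """Periodic timeout check. True means: emit a partial synthesis
    request with whatever results have arrived, flagging the rest."""
    return completed < expected and (now - first_completion_at) > timeout_s
```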
---
## Provenance Model Extensions
### Routing Decision
The meta-router's task type and pattern selection is recorded as the first
provenance node in the session:
```
Routing node:
rdf:type prov:Entity, tg:RoutingDecision
prov:wasGeneratedBy → session (Question) activity
tg:taskType → TaskType node URI
tg:selectedPattern → Pattern node URI
tg:candidatePatterns → [Pattern node URIs] (what was available)
tg:routingRationale → document URI in librarian (LLM's reasoning)
```
This captures the constrained decision space: what candidates existed, which
was selected, and why. The candidates come from configuration; the rationale
is LLM-generated but verifiable against those candidates.
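Emission of this node is a straightforward mapping to triples. A sketch using CURIE-style predicate strings matching the node layout above (the helper is illustrative; concrete predicate URIs would come from the provenance constants):

```python
def routing_triples(routing: str, session: str, task_type: str,
                    pattern: str, candidates: list, rationale: str) -> list:
    """Render the routing node as (subject, predicate, object) tuples."""
    triples = [
        (routing, "rdf:type", "prov:Entity"),
        (routing, "rdf:type", "tg:RoutingDecision"),
        (routing, "prov:wasGeneratedBy", session),
        (routing, "tg:taskType", task_type),
        (routing, "tg:selectedPattern", pattern),
        (routing, "tg:routingRationale", rationale),
    ]
    # One triple per available candidate records the decision space.
    triples += [(routing, "tg:candidatePatterns", c) for c in candidates]
    return triples

ts = routing_triples("urn:r", "urn:s", "urn:tt", "urn:p", ["urn:p", "urn:q"], "urn:doc")
```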
### Fan-Out Provenance
When a supervisor fans out, the provenance records the decomposition:
```
FanOut node:
rdf:type prov:Entity, tg:FanOut
prov:wasDerivedFrom → supervisor's routing or planning iteration
tg:correlationId <correlation ID>
tg:subagentGoals → [document URIs for each subagent goal]
tg:expectedSiblings <count>
```
Each subagent's provenance chain is independent (its own session, iterations,
conclusion) but linked back to the parent via:
```
Subagent session:
rdf:type prov:Activity, tg:Question, tg:AgentQuestion
tg:parentCorrelationId <correlation ID>
tg:parentSessionId <supervisor session URI>
```
### Fan-In Provenance
The synthesis step links back to all subagent conclusions:
```
Synthesis node:
rdf:type prov:Entity, tg:Synthesis
prov:wasDerivedFrom → [all subagent Conclusion entities]
tg:correlationId <correlation ID>
```
This creates a DAG in the provenance graph: the supervisor's routing fans out
to N parallel subagent chains, which fan back in to a synthesis node. The
entire multi-agent execution is traceable from a single correlation ID.
### URI Scheme
Extending the existing `urn:trustgraph:agent:{session_id}` pattern:
| Entity | URI Pattern |
|---|---|
| Session (existing) | `urn:trustgraph:agent:{session_id}` |
| Iteration (existing) | `urn:trustgraph:agent:{session_id}/i{n}` |
| Conclusion (existing) | `urn:trustgraph:agent:{session_id}/answer` |
| Routing decision | `urn:trustgraph:agent:{session_id}/routing` |
| Fan-out record | `urn:trustgraph:agent:{session_id}/fanout/{correlation_id}` |
| Subagent completion | `urn:trustgraph:agent:{session_id}/completion` |
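All of these share one prefix, so a single builder covers the table (the `agent_uri` helper is illustrative, not existing code):

```python
def agent_uri(session_id: str, *suffix: str) -> str:
    """Build agent provenance URIs per the table above."""
    base = f"urn:trustgraph:agent:{session_id}"
    return "/".join([base, *suffix]) if suffix else base
```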
---
## Storage Responsibilities
Pattern and task type definitions live in the config API. Runtime state and
provenance live in the knowledge graph. The division is:
| Role | Storage | When Written | Content |
|---|---|---|---|
| Pattern definitions | Config API | At design time | Pattern properties, descriptions |
| Task type definitions | Config API | At design time | Domain framing, valid pattern lists |
| Routing decision trace | Knowledge graph | At request arrival | Why this task type and pattern were selected |
| Iteration decision trace | Knowledge graph | During execution | Each think/act/observe cycle, per existing model |
| Fan-out coordination | Knowledge graph | During fan-out | Subagent goals, correlation ID, expected count |
| Subagent completion | Knowledge graph | During fan-in | Per-subagent results under shared correlation ID |
| Execution audit trail | Knowledge graph | Post-execution | Full multi-agent reasoning trace as a DAG |
The config API holds the definitions that constrain decisions. The knowledge
graph holds the runtime decisions and their provenance. The fan-in
coordination state is part of the provenance automatically — subagent
completion nodes are both coordination signals and audit trail entries.
---
## Worked Example: Partner Risk Assessment
**Request**: "Assess the risk profile of Company X as a potential partner"
**1. Request arrives** on the *agent request* topic with empty history.
The agent manager picks it up.
**2. Meta-router**:
- Queries config API, finds task types: *Risk Assessment*, *Research*,
*Summarisation*, *General*.
- LLM identifies *Risk Assessment*. Framing prompt loaded: "analyse across
financial, reputational, legal and operational dimensions using structured
analytic techniques."
- Valid patterns for *Risk Assessment*: [*Supervisor/Subagent*,
*Plan-then-Execute*, *ReACT*].
- LLM selects *Supervisor/Subagent* — task has four independent investigative
dimensions, well-suited to parallel decomposition.
- Routing decision written to graph. Request re-emitted on the
*agent request* topic with `pattern="supervisor"`, framing populated.
**3. Supervisor iteration** (picked up from *agent request* topic):
- LLM receives question + framing. Reasons that four independent investigative
threads are required.
- Generates correlation ID `corr-abc123`.
- Emits four subagent requests on the *agent request* topic:
- Financial analysis (pattern="react", subagent_goal="Analyse financial
health and stability of Company X")
- Legal analysis (pattern="react", subagent_goal="Review regulatory filings,
sanctions, and legal exposure for Company X")
- Reputational analysis (pattern="react", subagent_goal="Analyse news
sentiment and public reputation of Company X")
- Operational analysis (pattern="react", subagent_goal="Assess supply chain
dependencies and operational risks for Company X")
- Fan-out node written to graph.
**4. Four subagents run in parallel** (each picked up from the *agent
request* topic by agent manager instances), each as an independent ReACT
loop:
- Financial — queries financial data services and knowledge graph
relationships
- Legal — searches regulatory filings and sanctions lists
- Reputational — searches news, analyses sentiment
- Operational — queries supply chain databases
Each self-queues its iterations on the *agent request* topic. Each writes
its own decision trace to the graph as it progresses. Each completes
independently.
**5. Fan-in**:
- Each subagent writes a `tg:SubagentCompletion` node to the graph on
completion, emitted on the *explainability* topic. The completion node
references the subagent's result document in the librarian.
- Aggregator (consuming the *explainability* topic) sees each completion
event. It queries the graph for the fan-out node to get the expected
sibling count, then checks how many completions exist for
`corr-abc123`.
- When all four siblings are complete, the aggregator emits a synthesis
request on the *agent request* topic with the correlation ID. It does
not fetch or bundle subagent results — the supervisor will query the
graph for those.
**6. Supervisor synthesis** (picked up from *agent request* topic):
- Receives the synthesis trigger carrying the correlation ID.
- Queries the graph for `tg:SubagentCompletion` nodes under
`corr-abc123`, retrieving each subagent's goal and result document
reference.
- Fetches the result documents from the librarian.
- Reasons across all four dimensions, produces a structured risk
assessment with confidence scores.
- Emits final answer on the *agent response* topic and writes conclusion
provenance to the graph.
**7. Response delivered** — the supervisor's synthesis streams on the
*agent response* topic as the LLM generates it, with `end_of_dialog`
on the final chunk. The collated answer is saved to the librarian and
referenced from conclusion provenance in the graph. The graph now holds
a complete, human-readable trace of the entire multi-agent execution —
from pattern selection through four parallel investigations to final
synthesis.
---
## Class Hierarchy
The `agent-orchestrator` executable uses the same service identity as
`agent-manager-react`, making it a drop-in replacement.
The pattern dispatch model suggests a class hierarchy where shared iteration
infrastructure lives in a base class and pattern-specific logic is in
subclasses:
```
AgentService (base — Pulsar consumer/producer specs, request handling)
└── Processor (agent-orchestrator service)
    ├── MetaRouter — task type identification, pattern selection
    ├── PatternBase — shared: tool filtering, provenance, streaming, history
    │   ├── ReactPattern — existing ReACT logic (extract from current AgentManager)
    │   ├── PlanThenExecutePattern — plan phase + execute phase
    │   └── SupervisorPattern — fan-out, synthesis
    └── Aggregator — fan-in completion detection
```
`PatternBase` captures what is currently spread across `Processor` and
`AgentManager`: tool filtering, LLM invocation, provenance triple emission,
streaming callbacks, history management. The pattern subclasses implement only
the decision logic specific to their execution strategy — what to do with the
LLM output, when to terminate, whether to fan out.
This refactoring is not strictly necessary for the first iteration — the
meta-router and pattern dispatch could be added as branches within the
existing `Processor.agent_request()` method. But the class hierarchy clarifies
where shared vs. pattern-specific logic lives and will prevent duplication as
more patterns are added.
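A minimal sketch of what the split might look like, assuming illustrative names (`PatternBase`, `step`, and the constructor arguments are not existing TrustGraph classes):

```python
# Illustrative sketch only — class and method names are assumptions,
# not the existing TrustGraph API.
from abc import ABC, abstractmethod

class PatternBase(ABC):
    """Shared iteration infrastructure: tool filtering, LLM invocation,
    provenance emission, streaming callbacks, history management."""

    def __init__(self, tools, llm, emit_chunk):
        self.tools = tools            # tool registry, pre-filtered by group
        self.llm = llm                # LLM invocation callable
        self.emit_chunk = emit_chunk  # streaming callback to the caller

    async def invoke_llm(self, prompt):
        # Shared: call the LLM; streaming and provenance hooks live here
        return await self.llm(prompt)

    @abstractmethod
    async def step(self, request, history):
        """Pattern-specific: interpret LLM output, decide whether to act,
        fan out, or terminate."""

class ReactPattern(PatternBase):
    async def step(self, request, history):
        # Thought -> action -> observation loop; terminate on final answer
        ...

class SupervisorPattern(PatternBase):
    async def step(self, request, history):
        # Decompose the goal, emit subagent requests, then await fan-in
        ...
```

The subclasses carry only decision logic; everything mechanical stays in the base.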
---
## Configuration
### Config API Seeding
Pattern and task type definitions are stored via the config API and need to
be seeded at deployment time. This is analogous to how flow blueprints and
parameter types are loaded — a bootstrap step that writes the initial
configuration.
The initial seed includes:
**Patterns** (config type `agent-pattern`):
- `react` — interleaved reasoning and action
- `plan-then-execute` — structured plan followed by step execution
- `supervisor` — decomposition, fan-out to subagents, synthesis
**Task types** (config type `agent-task-type`, initial set, expected to grow):
- `general` — no specific domain framing, all patterns valid
- `research` — open-ended investigation, valid patterns: react, plan-then-execute
- `risk-assessment` — multi-dimensional analysis, valid patterns: supervisor,
plan-then-execute, react
- `summarisation` — condense information, valid patterns: react
The seed data is configuration, not code. It can be extended via the config
API (or the configuration UI) without redeploying the agent manager.
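As an illustration of what the seed might contain (the exact key/value layout of the config API is an assumption here; values are shown as JSON strings, matching how other config types store entries):

```python
# Hypothetical seed payload — the config API key/value layout is an
# assumption for illustration, not the actual bootstrap format.
import json

seed = {
    "agent-pattern": {
        "react": json.dumps(
            {"description": "interleaved reasoning and action"}),
        "plan-then-execute": json.dumps(
            {"description": "structured plan followed by step execution"}),
        "supervisor": json.dumps(
            {"description": "decomposition, fan-out to subagents, synthesis"}),
    },
    "agent-task-type": {
        "general": json.dumps(
            {"patterns": ["react", "plan-then-execute", "supervisor"]}),
        "research": json.dumps(
            {"patterns": ["react", "plan-then-execute"]}),
        "risk-assessment": json.dumps(
            {"patterns": ["supervisor", "plan-then-execute", "react"]}),
        "summarisation": json.dumps(
            {"patterns": ["react"]}),
    },
}
```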
### Migration Path
The config API provides a practical starting point. If richer ontological
relationships between patterns, task types, and domain knowledge become
valuable, the definitions can be migrated to graph storage. The meta-router's
selection logic queries an abstract set of task types and patterns — the
storage backend is an implementation detail.
### Fallback Behaviour
If the config contains no patterns or task types:
- Task type defaults to `general`.
- Pattern defaults to `react`.
- The system degrades gracefully to existing behaviour.
---
## Design Decisions
| Decision | Resolution | Rationale |
|---|---|---|
| Task type identification | LLM interprets from plain text | Natural language too varied to formalise prematurely |
| Pattern/task type storage | Config API initially, graph later if needed | Avoids graph model complexity upfront; config API already has UI support; migration path is straightforward |
| Meta-router location | Phase within agent manager, not separate service | Avoids an extra network hop; routing is fast |
| Fan-in mechanism | Event-driven via explainability topic | Consistent with Pulsar-based architecture; graph query for completion count is idempotent and restart-safe |
| Aggregator deployment | Separate lightweight process | Decoupled from agent manager lifecycle |
| Subagent pattern selection | Supervisor specifies per-subagent | Supervisor has task context to make this choice |
| Plan storage | In message history | No external state needed; plan travels with message |
| Default pattern | Empty pattern field → ReACT | Sensible default when meta-router is not configured |
---
## Streaming Protocol
### Current Model
The existing agent response schema has two levels:
- **`end_of_message`** — marks the end of a complete thought, observation,
or answer. Chunks belonging to the same message arrive sequentially.
- **`end_of_dialog`** — marks the end of the entire agent execution. No
more messages will follow.
This works because the current system produces messages serially — one
thought at a time, one agent at a time.
### Fan-Out Breaks Serial Assumptions
With supervisor/subagent fan-out, multiple subagents stream chunks
concurrently on the same *agent response* topic. The caller receives
interleaved chunks from different sources and needs to demultiplex them.
### Resolution: Message ID
Each chunk carries a `message_id` — a per-message UUID generated when
the agent begins streaming a new thought, observation, or answer. The
caller groups chunks by `message_id` and assembles each message
independently.
```
Response chunk fields:
  message_id      UUID for this message (groups chunks)
  session_id      Which agent session produced this chunk
  chunk_type      "thought" | "observation" | "answer" | ...
  content         The chunk text
  end_of_message  True on the final chunk of this message
  end_of_dialog   True on the final message of the entire execution
```
A single subagent emits multiple messages (thought, observation, thought,
answer), each with a distinct `message_id`. The `session_id` identifies
which subagent the message belongs to. The caller can display, group, or
filter by either.
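Caller-side demultiplexing is then a small amount of bookkeeping. A sketch, assuming chunks arrive as dicts with the fields above:

```python
from collections import defaultdict

# Sketch: group chunks by message_id, assemble each message independently,
# and emit the collated text when its end_of_message chunk arrives.
buffers = defaultdict(list)
messages = []  # (session_id, chunk_type, collated text)

def on_chunk(chunk):
    buffers[chunk["message_id"]].append(chunk["content"])
    if chunk["end_of_message"]:
        text = "".join(buffers.pop(chunk["message_id"]))
        messages.append((chunk["session_id"], chunk["chunk_type"], text))
```

Interleaved chunks from different subagents land in different buffers, so concurrent streams never corrupt each other.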
### Provenance Trigger
`end_of_message` is the trigger for provenance storage. When a complete
message has been assembled from its chunks:
1. The collated text is saved to the librarian as a single document.
2. A provenance node is written to the graph referencing the document URI.
This follows the pattern established by GraphRAG, where streaming synthesis
chunks are delivered live but the stored provenance references the collated
answer text. Streaming is for the caller; provenance needs complete messages.
---
## Open Questions
- **Re-planning depth** (resolved): Runtime parameter on the
agent-orchestrator executable, default 2. Bounds how many times
Plan-then-Execute can revise its plan before forcing termination.
- **Nested fan-out** (phase B): A subagent can itself be a supervisor
that fans out further. The architecture supports this — correlation IDs
are independent and the aggregator is stateless. The protocols and
message schema should not preclude nested fan-out, but implementation
is deferred. Depth limits will need to be enforced to prevent runaway
decomposition.
- **Task type evolution** (resolved): Manually curated for now. See
Future Directions below for automated discovery.
- **Cost attribution** (deferred): Costs are measured at the
text-completion queue level as they are today. Per-request attribution
across subagents is not yet implemented and is not a blocker for
orchestration.
- **Conversation ID** (resolved): An optional `conversation_id` field on
`AgentRequest`, generated by the caller. When present, all objects
created during the execution (provenance nodes, librarian documents,
subagent completion records) are tagged with the conversation ID. This
enables querying all interactions in a conversation with a single
lookup, and provides the foundation for conversation-scoped memory.
No explicit open/close — the first request with a new conversation ID
implicitly starts the conversation. Omit for one-shot queries.
- **Tool scoping per subagent** (resolved): Subagents inherit the
parent's tool group by default. The supervisor can optionally override
the group per subagent to constrain capabilities (e.g. financial
subagent gets only financial tools). The `group` field on
`AgentRequest` already supports this — the supervisor just sets it
when constructing subagent requests.
---
## Future Directions
### Automated Task Type Discovery
Task types are manually curated in the initial implementation. However,
the architecture is well-suited to automated discovery because all agent
requests and their execution traces flow through Pulsar topics. A
learning service could consume these messages and analyse patterns in
how tasks are framed, which patterns are selected, and how successfully
they execute. Over time, it could propose new task types based on
clusters of similar requests that don't map well to existing types, or
suggest refinements to framing prompts based on which framings lead to
better outcomes. This service would write proposed task types to the
config API for human review — automated discovery, manual approval. The
agent-orchestrator does not need to change; it always reads task types
from config regardless of how they got there.

# Config Push "Notify" Pattern Technical Specification
## Overview
Replace the current config push mechanism — which broadcasts the full config
blob on a `state` class queue — with a lightweight "notify" notification
containing only the version number and affected types. Processors that care
about those types fetch the full config via the existing request/response
interface.
This solves the RabbitMQ late-subscriber problem: when a process restarts,
its fresh queue has no historical messages, so it never receives the current
config state. With the notify pattern, the push queue is only a signal — the
source of truth is the config service's request/response API, which is
always available.
## Problem
On Pulsar, `state` class queues are persistent topics. A new subscriber
with `InitialPosition.Earliest` reads from message 0 and receives the
last config push. On RabbitMQ, each subscriber gets a fresh per-subscriber
queue (named with a new UUID). Messages published before the queue existed
are gone. A restarting processor never gets the current config.
## Design
### The Notify Message
The `ConfigPush` schema changes from carrying the full config to carrying
just a version number and the list of affected config types:
```python
from dataclasses import dataclass, field

@dataclass
class ConfigPush:
    version: int = 0
    types: list[str] = field(default_factory=list)
```
When the config service handles a `put` or `delete`, it knows which types
were affected (from the request's `values[].type` or `keys[].type`). It
includes those in the notify. On startup, the config service sends a notify
with an empty types list (meaning "everything").
### Subscribe-then-Fetch Startup (No Race Condition)
The critical ordering to avoid missing an update:
1. **Subscribe** to the config push queue. Buffer incoming notify messages.
2. **Fetch** the full config via request/response (`operation: "config"`).
This returns the config dict and a version number.
3. **Apply** the fetched config to all registered handlers.
4. **Process** buffered notifys. For any notify with `version > fetched_version`,
re-fetch and re-apply. Discard notifys with `version <= fetched_version`.
5. **Enter steady state**. Process future notifys as they arrive.
This is safe because:
- If an update happens before the subscription, the fetch picks it up.
- If an update happens between subscribe and fetch, it's in the buffer.
- If an update happens after the fetch, it arrives on the queue normally.
- Version comparison ensures no duplicate processing.
### Processor API
The current API requires processors to understand the full config dict
structure. The new API should be cleaner — processors declare which config
types they care about and provide a handler that receives only the relevant
config subset.
#### Current API
```python
# In processor __init__:
self.register_config_handler(self.on_configure_flows)

# Handler receives the entire config dict:
async def on_configure_flows(self, config, version):
    if "active-flow" not in config:
        return
    if self.id in config["active-flow"]:
        flow_config = json.loads(config["active-flow"][self.id])
        # ...
```
#### New API
```python
# In processor __init__:
self.register_config_handler(
    handler=self.on_configure_flows,
    types=["active-flow"],
)

# Same signature; config still contains the full dict, but the handler
# is only called when the "active-flow" type changes (or on startup):
async def on_configure_flows(self, config, version):
    if "active-flow" not in config:
        return
    # ...
```
The `types` parameter is optional. If omitted, the handler is called for
every config change (backward compatible). If specified, the handler is
only invoked when the notify's `types` list intersects with the handler's
types, or on startup (empty types list = everything).
#### Internal Registration Structure
```python
# In AsyncProcessor:
def register_config_handler(self, handler, types=None):
    self.config_handlers.append({
        "handler": handler,
        "types": set(types) if types else None,  # None = all types
    })
```
#### Notify Processing Logic
```python
async def on_config_notify(self, message, consumer, flow):
    notify_version = message.value().version
    notify_types = set(message.value().types)

    # Skip if we already have this version or newer
    if notify_version <= self.config_version:
        return

    # Fetch full config from config service
    config, version = await self.config_client.config()
    self.config_version = version

    # Determine which handlers to invoke
    for entry in self.config_handlers:
        handler_types = entry["types"]
        if handler_types is None:
            # Handler cares about everything
            await entry["handler"](config, version)
        elif not notify_types or notify_types & handler_types:
            # notify_types empty = startup (invoke all),
            # or intersection with handler's types
            await entry["handler"](config, version)
```
### Config Service Changes
#### Push Method
The `push()` method changes to send only version + types:
```python
async def push(self, types=None):
    version = await self.config.get_version()
    resp = ConfigPush(
        version=version,
        types=types or [],
    )
    await self.config_push_producer.send(resp)
```
#### Put/Delete Handlers
Extract affected types and pass to push:
```python
async def handle_put(self, v):
    types = list(set(k.type for k in v.values))
    for k in v.values:
        await self.table_store.put_config(k.type, k.key, k.value)
    await self.inc_version()
    await self.push(types=types)

async def handle_delete(self, v):
    types = list(set(k.type for k in v.keys))
    for k in v.keys:
        await self.table_store.delete_key(k.type, k.key)
    await self.inc_version()
    await self.push(types=types)
```
#### Queue Class Change
The config push queue changes from `state` class to `flow` class. The push
is now a transient signal — the source of truth is the config service's
request/response API, not the queue. `flow` class is persistent (survives
broker restarts) but doesn't require last-message retention, which was the
root cause of the RabbitMQ problem.
```python
config_push_queue = queue('config', cls='flow') # was cls='state'
```
#### Startup Push
On startup, the config service sends a notify with empty types list
(signalling "everything changed"):
```python
async def start(self):
    await self.push(types=[])  # Empty = all types
    await self.config_request_consumer.start()
```
### AsyncProcessor Changes
The `AsyncProcessor` needs a config request/response client alongside the
push consumer. The startup sequence becomes:
```python
async def start(self):
    # 1. Start the push consumer (begins buffering notifys)
    await self.config_sub_task.start()

    # 2. Fetch current config via request/response
    config, version = await self.config_client.config()
    self.config_version = version

    # 3. Apply to all handlers (startup = all handlers invoked)
    for entry in self.config_handlers:
        await entry["handler"](config, version)

    # 4. Buffered notifys are now processed by on_config_notify,
    #    which skips versions <= self.config_version
```
The config client needs to be created in `__init__` using the existing
request/response queue infrastructure. The `ConfigClient` from
`trustgraph.clients.config_client` already exists but uses a synchronous
blocking pattern. An async variant or integration with the processor's
pub/sub backend is needed.
### Existing Config Handler Types
For reference, the config types currently used by handlers:
| Handler | Type(s) | Used By |
|---------|---------|---------|
| `on_configure_flows` | `active-flow` | All FlowProcessor subclasses |
| `on_collection_config` | `collection` | Storage services (triples, embeddings, rows) |
| `on_prompt_config` | `prompt` | Prompt template service, agent extract |
| `on_schema_config` | `schema` | Rows storage, row embeddings, NLP query, structured diag |
| `on_cost_config` | `token-costs` | Metering service |
| `on_ontology_config` | `ontology` | Ontology extraction |
| `on_librarian_config` | `librarian` | Librarian service |
| `on_mcp_config` | `mcp-tool` | MCP tool service |
| `on_knowledge_config` | `kg-core` | Cores service |
## Implementation Order
1. **Update ConfigPush schema** — change `config` field to `types` field.
2. **Update config service** — modify `push()` to send version + types.
Modify `handle_put`/`handle_delete` to extract affected types.
3. **Add async config query to AsyncProcessor** — create a
request/response client for config queries within the processor's
event loop.
4. **Implement subscribe-then-fetch startup** — reorder
`AsyncProcessor.start()` to subscribe first, then fetch, then
process buffered notifys with version comparison.
5. **Update register_config_handler** — add optional `types` parameter.
Update `on_config_notify` to filter by type intersection.
6. **Update existing handlers** — add `types` parameter to all
`register_config_handler` calls across the codebase.
7. **Backward compatibility** — handlers without `types` parameter
continue to work (invoked for all changes).
## Risks
- **Thundering herd**: if many processors restart simultaneously, they
all hit the config service API at once. Mitigated by the config service
already being designed for request/response load, and the number of
processors being small (tens, not thousands).
- **Config service availability**: processors now depend on the config
service being up at startup, not just having received a push. This is
already the case in practice — without config, processors can't do
anything useful.

# Pub/Sub Abstraction: Broker-Independent Messaging
## Problem
TrustGraph's messaging infrastructure is deeply coupled to Apache Pulsar in ways that go beyond the transport layer. This coupling creates several concrete problems.
### 1. Schema system is Pulsar-native
Every message type in the system is defined as a `pulsar.schema.Record` subclass using Pulsar field types (`String()`, `Integer()`, `Boolean()`, etc.). This means:
- The `pulsar` Python package is a build dependency for `trustgraph-base`, even though `trustgraph-base` contains no transport logic
- Any code that imports a message schema transitively depends on Pulsar
- The schema definitions cannot be reused with a different broker without the Pulsar library installed
- What's actually happening on the wire is JSON serialisation — the Pulsar schema machinery adds complexity without adding value over plain JSON encode/decode
### 2. Translators are named after the broker
The translator layer that converts between internal Python objects and wire format uses methods called `to_pulsar()` and `from_pulsar()`. These are really just JSON encode/decode operations — they have nothing to do with Pulsar specifically. The naming creates a false impression that the translation is broker-specific, when in reality any broker that carries JSON payloads would use identical logic.
### 3. Queue names use Pulsar URI format
Queue identifiers throughout the codebase use Pulsar's `persistent://tenant/namespace/topic` or `non-persistent://tenant/namespace/topic` URI format. These are hardcoded in schema definitions and referenced across services. RabbitMQ, Redis Streams, or any other broker would use completely different naming conventions. There is no abstraction between the logical identity of a queue and its broker-specific address.
### 4. Broker selection is not configurable
There is no mechanism to select a different pub/sub backend at deployment time. The Pulsar client is instantiated directly in the gateway and via `PulsarClient` in the base processor. Switching to a different broker would require code changes across multiple packages, not a configuration change.
### 5. Architectural requirements are implicit
TrustGraph relies on specific pub/sub behaviours — shared subscriptions for load balancing, message acknowledgement for reliability, message properties for correlation — but these requirements are not documented. This makes it difficult to evaluate whether a candidate broker (RabbitMQ, Redis Streams, NATS, etc.) actually satisfies the system's needs, or where the gaps would be.
## Design Goals
### Goal 1: Remove the link between Pulsar schemas and application code
Message types should be plain Python objects (dataclasses) that know how to serialise to and from JSON. The `pulsar.schema.Record` base class and Pulsar field types should not appear in schema definitions. The pub/sub transport layer sends and receives JSON bytes; the schema layer handles the mapping between JSON and typed Python objects independently.
### Goal 2: Remove `to_pulsar` / `from_pulsar` naming
The translator methods should reflect what they actually do: encode a Python object to a JSON-compatible dict, and decode a JSON-compatible dict back to a Python object. The naming should be broker-neutral (e.g. `encode` / `decode`, or `to_dict` / `from_dict`).
### Goal 3: Schema objects provide encode/decode
Each message type should be a Python dataclass (or similar) with a well-defined mapping to and from JSON. For example:
```python
from dataclasses import dataclass

@dataclass
class TextCompletionRequest:
    system: str
    prompt: str
    streaming: bool = False
```
Given `{"system": "You are helpful", "prompt": "Hello", "streaming": false}` on the wire, decoding produces an object where `request.system` is `"You are helpful"`, `request.prompt` is `"Hello"`, and `request.streaming` is `False`. Encoding does the reverse. This is the schema's concern, not the broker's.
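Sketched with stdlib dataclasses (the `encode`/`decode` helper names are illustrative, not an existing TrustGraph API):

```python
from dataclasses import dataclass, asdict
import json

@dataclass
class TextCompletionRequest:
    system: str
    prompt: str
    streaming: bool = False

    def encode(self) -> bytes:
        # Object -> JSON bytes for the wire; no broker types involved
        return json.dumps(asdict(self)).encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> "TextCompletionRequest":
        # JSON bytes -> typed object
        return cls(**json.loads(data))
```

The transport layer only ever sees `bytes`; the typed mapping lives entirely in the schema class.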
### Goal 4: Abstract queue naming
Queue identifiers should not use Pulsar URI format (`persistent://tg/flow/topic`). A broker-neutral naming scheme is needed so that each backend can map logical queue names to its native format. The right approach here is not yet clear and needs to be worked through — considerations include how to express quality-of-service, multi-tenancy, and namespace separation without leaking broker concepts.
### Goal 5: Document pub/sub architectural requirements
TrustGraph's actual requirements from the pub/sub layer need to be formally specified. This includes:
- **Delivery semantics**: Which queues need at-least-once delivery? Are any fire-and-forget?
- **Consumer patterns**: Shared subscriptions (competing consumers for load balancing), exclusive subscriptions, fan-out/broadcast
- **Message acknowledgement**: Positive ack, negative ack (redelivery), timeout-based redelivery
- **Message properties**: Key-value metadata on messages used for correlation (e.g. request IDs, flow routing)
- **Ordering guarantees**: Per-topic ordering, per-key ordering, or no ordering required
- **Message size**: Typical and maximum message sizes (some payloads include base64-encoded documents)
- **Persistence**: Which messages must survive broker restarts
- **Consumer positioning**: Ability to consume from earliest (replay) vs latest (live tail)
- **Connection model**: Long-lived connections with reconnection, or transient
Documenting these requirements makes it possible to evaluate RabbitMQ or any other candidate against concrete criteria rather than discovering gaps during implementation.
## Pub/Sub Architectural Requirements (As-Is)
This section documents what TrustGraph currently needs from its pub/sub layer. These are the as-is requirements — some may be revisited or relaxed in a future design if it makes broker portability easier.
### Consumer model
All consumers use **shared subscriptions** (competing consumers). Multiple instances of the same processor read from the same subscription, and each message is delivered to exactly one instance. This is the load-balancing mechanism.
No exclusive or failover subscriptions are used anywhere in the codebase, despite infrastructure support for them.
Consumers support configurable concurrency — multiple async tasks within a single process can independently call `receive()` on the same subscription.
### Delivery semantics
Almost all queues are **non-persistent / best-effort (q0)**. The only persistent queue is `config_push_queue` (q2, exactly-once), which pushes full configuration state to processors. Since config pushes are idempotent (full state, not deltas), the persistence requirement here is about surviving broker restarts, not about exactly-once semantics per se.
Flow processing queues (request/response pairs for LLM, RAG, agent, etc.) are all non-persistent. Messages in flight are lost on broker restart. This is acceptable because:
- Requests originate from a client that will time out and retry
- There is no durable work-in-progress that would be corrupted by message loss
- The system is designed for real-time query processing, not batch pipelines
### Message acknowledgement
**Positive acknowledgement**: After successful handler execution, the message is acknowledged. This removes it from the subscription.
**Negative acknowledgement**: On handler failure (unhandled exception or rate-limit timeout), the message is negatively acknowledged, which triggers redelivery by the broker. Rate-limited messages retry for up to 7200 seconds before giving up and negatively acknowledging.
**Orphaned messages**: In the request-response subscriber pattern, messages that arrive with no matching waiter (e.g. the requester timed out) are positively acknowledged and discarded. This prevents redelivery storms.
### Message properties
Messages carry a small set of key-value string properties as metadata, separate from the payload. The primary use is a `"id"` property for request-response correlation — the requester generates a unique ID, attaches it as a property, and the responder echoes it back so the subscriber can match responses to waiters.
Agent orchestration correlation (`correlation_id`, `parent_session_id`) is carried in the message payload, not in properties.
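The `"id"`-property correlation described above can be sketched as follows; the producer and message interfaces here are illustrative stand-ins for the backend API, not the actual TrustGraph classes:

```python
import asyncio
import uuid

# Sketch of request-response correlation via an "id" message property.
waiters: dict[str, asyncio.Future] = {}

async def request(producer, payload):
    # Requester generates a unique ID and attaches it as a property
    corr_id = str(uuid.uuid4())
    fut = asyncio.get_running_loop().create_future()
    waiters[corr_id] = fut
    await producer.send(payload, properties={"id": corr_id})
    return await fut  # resolved when the matching response arrives

def on_response(message):
    # Responder echoed the "id" back; match it to a waiting future
    corr_id = message.properties().get("id")
    fut = waiters.pop(corr_id, None)
    if fut is None:
        return "ack"  # orphaned: requester timed out; ack and discard
    fut.set_result(message.value())
    return "ack"
```

The orphaned-message branch mirrors the behaviour noted earlier: unmatched responses are acknowledged and dropped rather than redelivered.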
### Consumer positioning
Two modes are used:
- **Earliest**: The configuration consumer starts from the beginning of the topic to receive full configuration history on startup. This is the only use of earliest positioning.
- **Latest** (default): All flow consumers start from the current position, processing only new messages.
### Message ordering
**Not required.** The codebase explicitly does not depend on message ordering:
- Shared subscriptions distribute messages across consumers without ordering guarantees
- Concurrent handler tasks within a consumer process messages in arbitrary order
- Request-response correlation uses IDs, not positional ordering
- The supervisor fan-out/fan-in pattern collects results in a dictionary, order-independent
- Configuration pushes are full state snapshots, not ordered deltas
### Message sizes
Most messages are small JSON payloads (< 10KB). The exceptions:
- **Document content**: Large documents (PDFs, text files) can be sent through the chunking service with base64 encoding. Pulsar's chunking feature (`chunking_enabled`) handles automatic splitting of oversized messages.
- **Agent observations**: LLM-generated text can be several KB but rarely exceeds typical message size limits.
A replacement broker needs to either support large messages natively or provide a chunking/streaming mechanism. Alternatively, the large-document path could be refactored to use a side-channel (e.g. object store reference) instead of inline payload.
### Fan-out patterns
**Supervisor fan-out**: One supervisor request decomposes into N independent sub-agent requests, each emitted as a separate message on the agent request queue. Different agent instances pick them up via the shared subscription. A correlation ID links the completions back to the original decomposition. This is not pub/sub fan-out (one message to many consumers) — it's application-level fan-out (many messages to one queue).
**Request-response isolation**: Each client creates a unique subscription name on response queues so it only receives its own responses. This means the response queue effectively has many independent subscribers, each seeing a filtered subset of messages based on the `"id"` property match.
### Reconnection and resilience
Reconnection logic lives in the Consumer/Producer/Publisher/Subscriber classes, not in the broker client. These classes handle:
- Automatic reconnection on connection loss
- Retry loops with backoff
- Graceful shutdown (unsubscribe, close)
The broker client itself is expected to provide a basic connection that can fail, and the wrapper classes handle recovery. This is important for the abstraction — the backend interface can be simple because resilience is handled above it.
### Queue inventory
| Queue | Persistence | Purpose |
|-------|-------------|---------|
| config push | Persistent (q2) | Full configuration state broadcast |
| config request/response | Non-persistent | Configuration queries |
| flow request/response | Non-persistent | Flow management |
| knowledge request/response | Non-persistent | Knowledge graph operations |
| librarian request/response | Non-persistent | Document storage operations |
| document embeddings request/response | Non-persistent | Document vector queries |
| row embeddings request/response | Non-persistent | Row vector queries |
| collection request/response | Non-persistent | Collection management |
Additionally, each processing service (LLM, RAG, agent, prompt, embeddings, etc.) has dynamically defined request/response queue pairs configured at deployment time.
### Summary of hard requirements for a replacement broker
1. **Shared subscription / competing consumers** — multiple consumers on one queue, each message delivered to exactly one
2. **Message acknowledgement** — positive ack (remove from queue) and negative ack (trigger redelivery)
3. **Message properties** — key-value metadata on messages, at minimum a string `"id"` field
4. **Two consumer start positions** — from beginning of topic and from current position
5. **Persistence for at least one queue** — config state must survive broker restart
6. **Messages up to several MB** — or a chunking mechanism for large payloads
7. **No ordering requirement** — simplifies broker selection significantly
## Candidate Brokers
A quick assessment of alternatives against the hard requirements above.
### RabbitMQ
The primary candidate. Mature, widely deployed, well understood.
- **Competing consumers**: Yes — multiple consumers on a queue, round-robin delivery. This is RabbitMQ's native model.
- **Acknowledgement**: Yes — `basic.ack` and `basic.nack` with requeue flag.
- **Message properties**: Yes — headers and properties on every message. The `correlation_id` and `message_id` fields are first-class concepts.
- **Consumer positioning**: Yes, via RabbitMQ Streams (3.9+). Streams are append-only logs that support reading from any offset — beginning, end, or timestamp. Classic queues are consumed destructively (no replay), but streams solve this cleanly. The `state` queue class maps to a RabbitMQ stream. Additionally, the Last Value Cache Exchange plugin can retain the most recent message per routing key for new consumers.
- **Persistence**: Yes — durable queues and persistent messages survive broker restart.
- **Large messages**: No hard limit but not designed for very large payloads. Practical limit around 128MB with default config. Adequate for current use.
- **Ordering**: FIFO per queue (stronger than required).
- **Operational complexity**: Low. Single binary, no ZooKeeper/BookKeeper dependencies. Significantly simpler to operate than Pulsar.
- **Ecosystem**: Excellent client libraries, management UI, mature tooling.
**Gaps**: None significant. RabbitMQ Streams cover the replay/earliest positioning requirement.
### Apache Kafka
High-throughput distributed log. More infrastructure than TrustGraph likely needs.
- **Competing consumers**: Yes — consumer groups with partition assignment.
- **Acknowledgement**: Yes — offset commits. No per-message negative ack; failed messages require application-level retry or dead-letter handling.
- **Message properties**: Yes — message headers (key-value byte arrays).
- **Consumer positioning**: Yes — seek to earliest or latest offset. Supports full replay.
- **Persistence**: Yes — all messages are persisted to the log by default.
- **Large messages**: Configurable (`max.message.bytes`), default 1MB, can be increased. Large payloads are discouraged by design.
- **Ordering**: Per-partition ordering (stronger than required).
- **Operational complexity**: High. Requires ZooKeeper (or KRaft), partition management, replication config. Overkill for typical TrustGraph deployments.
- **Ecosystem**: Excellent client libraries, schema registry, Connect framework.
**Gaps**: No native negative acknowledgement. Operational complexity is high for small-to-medium deployments. Partition count must be planned upfront for parallelism.
### Redis Streams
Lightweight option using Redis as a message broker.
- **Competing consumers**: Yes — consumer groups with `XREADGROUP`.
- **Acknowledgement**: Yes — `XACK`. Pending entries list tracks unacknowledged messages. No explicit negative ack but unacknowledged messages can be claimed after timeout via `XAUTOCLAIM`.
- **Message properties**: No native separation between properties and payload. Would need to encode properties as fields within the stream entry or in the payload.
- **Consumer positioning**: Yes — `0` (earliest) or `$` (latest) on group creation.
- **Persistence**: Yes — Redis persistence (RDB/AOF), though Redis is primarily an in-memory system.
- **Large messages**: Practical limit tied to Redis memory. Not suited for large payloads.
- **Ordering**: Per-stream ordering (stronger than required).
- **Operational complexity**: Low if Redis is already in the stack. No additional infrastructure.
**Gaps**: No native message properties. Memory-bound. Persistence depends on Redis configuration. Not a natural fit for message broker patterns.
### NATS / NATS JetStream
Lightweight, high-performance messaging. JetStream adds persistence.
- **Competing consumers**: Yes — queue groups in core NATS; consumer groups in JetStream.
- **Acknowledgement**: JetStream only — `Ack`, `Nak` (with redelivery), `InProgress` (extend timeout).
- **Message properties**: Yes — message headers (key-value).
- **Consumer positioning**: JetStream — deliver all, deliver last, deliver new, deliver by sequence/time.
- **Persistence**: JetStream only. Core NATS is fire-and-forget.
- **Large messages**: Default 1MB, configurable up to 64MB.
- **Ordering**: Per-subject ordering.
- **Operational complexity**: Very low. Single binary, no dependencies. Clustering is straightforward.
**Gaps**: Requires JetStream for persistence and acknowledgement. Smaller ecosystem than RabbitMQ/Kafka.
### Assessment Summary
| Requirement | RabbitMQ | Kafka | Redis Streams | NATS JetStream |
|---|---|---|---|---|
| Competing consumers | Yes | Yes | Yes | Yes |
| Positive/negative ack | Yes | Partial | Partial | Yes |
| Message properties | Yes | Yes | No | Yes |
| Earliest positioning | Yes (Streams) | Yes | Yes | Yes |
| Persistence | Yes | Yes | Partial | Yes |
| Large messages | Yes | Configurable | No | Configurable |
| Operational simplicity | Good | Poor | Good | Good |
**RabbitMQ** is the strongest candidate given TrustGraph's requirements and deployment profile. The one requirement that doesn't map onto classic queues (earliest consumer positioning for config) is covered by RabbitMQ Streams, with the last-value-cache and fanout patterns as fallbacks. Operational simplicity is a significant advantage over Pulsar.
## Approach
### Current state
The codebase has already undergone a partial abstraction. The picture is better than the problem statement might suggest:
- **Backend abstraction exists**: `backend.py` defines Protocol-based interfaces (`PubSubBackend`, `BackendProducer`, `BackendConsumer`, `Message`). The Pulsar implementation lives in `pulsar_backend.py`.
- **Schemas are already dataclasses**: Message types in `schema/services/*.py` are plain Python dataclasses with type hints, not Pulsar `Record` subclasses. This was the hardest part of the old spec and it's done.
- **Serialization is JSON-based**: `pulsar_backend.py` contains `dataclass_to_dict()` and `dict_to_dataclass()` helpers that handle the round-trip. The wire format is JSON.
- **Factory pattern exists**: `pubsub.py` has `get_pubsub()` which creates a backend from configuration. Currently only Pulsar is implemented.
- **Consumer/Producer/Publisher/Subscriber are backend-agnostic**: These classes accept a `backend` parameter and delegate transport operations to it. They own retry, reconnection, metrics, and concurrency.
What remains is cleanup, not a rewrite.
### What needs to change
#### 1. Rename translator methods
The translator base class (`messaging/translators/base.py`) defines `to_pulsar()` and `from_pulsar()` as abstract methods. Every translator implements these. The methods convert between external API dicts and internal dataclass objects — nothing Pulsar-specific happens in them.
**Change**: Rename to `decode()` (external dict → dataclass) and `encode()` (dataclass → external dict). Update all translator subclasses and all call sites.
This is a mechanical rename. The method bodies don't change.
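A sketch of what the renamed base class and one subclass might look like. The `decode`/`encode` names and the translator role come from the spec above; the `TextCompletionRequest` fields and the subclass body are illustrative, not the real code:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class TextCompletionRequest:
    # Hypothetical internal schema dataclass, for illustration only
    system: str = ""
    prompt: str = ""


class MessageTranslator(ABC):
    """Translates between the external API dict representation and the
    internal schema dataclass. Nothing transport-specific happens here."""

    @abstractmethod
    def decode(self, data: dict):
        """External API dict -> internal dataclass."""

    @abstractmethod
    def encode(self, obj) -> dict:
        """Internal dataclass -> external API dict."""


class TextCompletionTranslator(MessageTranslator):
    def decode(self, data: dict) -> TextCompletionRequest:
        return TextCompletionRequest(
            system=data.get("system", ""),
            prompt=data.get("prompt", ""),
        )

    def encode(self, obj: TextCompletionRequest) -> dict:
        return {"system": obj.system, "prompt": obj.prompt}
```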
#### 2. Rename translator base classes
The base classes `Translator`, `MessageTranslator`, and `SendTranslator` reference "pulsar" in docstrings and parameter names. Clean these up so the naming reflects what the layer actually does: translating between the external API representation (JSON dicts from HTTP/WebSocket) and the internal schema (dataclasses).
#### 3. Move serialization out of the Pulsar backend
`dataclass_to_dict()` and `dict_to_dataclass()` currently live in `pulsar_backend.py` but are not Pulsar-specific. They handle the conversion between dataclasses and JSON-compatible dicts, which every backend needs.
**Change**: Move these to a shared location (e.g. `trustgraph/base/serialization.py` or alongside the schema definitions). The backend interface sends and receives dicts; serialization to/from dataclasses happens at a layer above.
This means the backend Protocol simplifies: `send()` accepts a dict and properties, `value()` returns a dict. The Consumer/Producer layer handles dataclass ↔ dict conversion using the shared serializers.
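A minimal sketch of what these helpers do, assuming the real implementations in `pulsar_backend.py` handle more cases (nested lists, optional fields, type coercion, string annotations):

```python
from dataclasses import asdict, fields, is_dataclass


def dataclass_to_dict(obj):
    """Dataclass instance -> JSON-compatible dict (recursive)."""
    return asdict(obj)


def dict_to_dataclass(cls, data: dict):
    """JSON dict -> dataclass instance, recursing into nested dataclass
    fields and ignoring unknown keys. Assumes f.type is a real class;
    with PEP 563 string annotations you'd resolve via
    typing.get_type_hints() first."""
    kwargs = {}
    for f in fields(cls):
        if f.name not in data:
            continue
        value = data[f.name]
        if is_dataclass(f.type) and isinstance(value, dict):
            value = dict_to_dataclass(f.type, value)
        kwargs[f.name] = value
    return cls(**kwargs)
```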
#### 4. Abstract queue naming
Queue names currently use the format `q0/tg/flow/queue-name` or `q2/tg/config/queue-name`, which the Pulsar backend maps to `non-persistent://tg/flow/queue-name` or `persistent://tg/config/queue-name`.
This is an open design question. Options:
**Option A: Simple string names.** Queues are just strings like `"text-completion-request"`. The backend is responsible for mapping to its native format (Pulsar adds `persistent://tg/flow/` prefix, RabbitMQ uses the string as-is or adds a vhost prefix). Persistence and namespace are configuration concerns, not embedded in the name.
**Option B: Structured queue descriptor.** A small object that carries the logical name plus metadata:
```python
@dataclass
class QueueDescriptor:
    name: str                # e.g. "text-completion-request"
    namespace: str = "flow"  # logical grouping
    persistent: bool = False # must survive broker restart
```
The backend maps this to its native format.
**Option C: Keep the current format** (`q0/tg/flow/name`) but document it as a TrustGraph convention, not a Pulsar convention. Backends parse it.
Option B is the most explicit. Option A is the simplest. Either is workable. The key constraint is that persistence is a property of the queue definition, not a runtime choice — the config push queue is persistent, everything else is not.
#### 5. Implement RabbitMQ backend
Write `rabbitmq_backend.py` implementing the `PubSubBackend` Protocol:
- **`create_producer()`**: Creates a channel and declares the target queue. `send()` publishes to the default exchange with the queue name as routing key. Properties map to AMQP basic properties (specifically `message_id` for the `"id"` property).
- **`create_consumer()`**: Declares the queue and starts consuming with `basic_consume`. Shared subscription is the default RabbitMQ model — multiple consumers on one queue get round-robin delivery. `acknowledge()` maps to `basic_ack`, `negative_acknowledge()` maps to `basic_nack` with `requeue=True`.
- **Persistence**: For persistent queues, declare as durable with `delivery_mode=2` on messages. For non-persistent queues, declare as non-durable.
- **Consumer positioning**: RabbitMQ queues are consumed destructively, so "earliest" doesn't apply in the Pulsar sense. For the config push use case, use a **fanout exchange with per-consumer exclusive queues** — each new processor gets its own queue that receives all config publishes, plus the last-value can be handled by having the config service re-publish on startup.
- **Large messages**: RabbitMQ handles messages up to `rabbit.max_message_size` (default 128MB). No chunking needed.
The factory in `pubsub.py` gets a new branch:
```python
if backend_type == 'rabbitmq':
    return RabbitMQBackend(
        host=config.get('rabbitmq_host'),
        port=config.get('rabbitmq_port'),
        username=config.get('rabbitmq_username'),
        password=config.get('rabbitmq_password'),
        vhost=config.get('rabbitmq_vhost', '/'),
    )
```
Backend selection via `PUBSUB_BACKEND=rabbitmq` environment variable or `--pubsub-backend rabbitmq` CLI flag.
#### 6. Clean up remaining Pulsar references
After the above changes, Pulsar-specific code should be confined to:
- `pulsar_backend.py` — the Pulsar implementation
- `pubsub.py` — the factory that imports it
Audit and remove any remaining Pulsar imports, Pulsar exception handling, or Pulsar-specific concepts from:
- `async_processor.py` (currently catches `_pulsar.Interrupted`)
- `consumer.py`, `subscriber.py` (if any Pulsar exceptions leak through)
- Schema files (should be clean already, but verify)
- Gateway service (currently instantiates Pulsar client directly)
The gateway is a special case — it currently bypasses the abstraction layer and creates a Pulsar client directly for dispatching API requests. It should use the same `get_pubsub()` factory as everything else.
### What stays the same
- **Schema definitions**: Already dataclasses. No changes needed.
- **Consumer/Producer/Publisher/Subscriber**: Already backend-agnostic. No changes to their core logic.
- **FlowProcessor and spec wiring**: Already uses `processor.pubsub` to create backend instances. No changes.
- **Backend Protocol**: The interface in `backend.py` is sound. Minor refinement possible (dict vs dataclass at the boundary) but the shape is right.
### Concrete cleanups
The following files have Pulsar-specific imports that should not be there after the abstraction is complete. Pulsar imports should be confined to `pulsar_backend.py` and the factory in `pubsub.py`.
**Dead imports (unused, can just be removed):**
- `trustgraph-base/trustgraph/base/pubsub.py``from pulsar.schema import JsonSchema`, `import pulsar`, `import _pulsar`. The `JsonSchema` import is unused since the switch to `BytesSchema`. The `pulsar`/`_pulsar` imports are only used by the legacy `PulsarClient` class which should be removed (superseded by `PulsarBackend`).
- `trustgraph-base/trustgraph/base/flow_processor.py``from pulsar.schema import JsonSchema`. Unused.
**Legacy `PulsarClient` class:**
- `trustgraph-base/trustgraph/base/pubsub.py` — The `PulsarClient` class is a leftover from before the backend abstraction. `get_pubsub()` still references `PulsarClient.default_pulsar_host` for defaults. Move the defaults to `PulsarBackend` or to environment variable reads in the factory, then delete `PulsarClient`.
**Client libraries using Pulsar directly:**
- `trustgraph-base/trustgraph/clients/base.py``import pulsar`, `import _pulsar`, `from pulsar.schema import JsonSchema`. This is the base class for the old synchronous client library. These clients predate the backend abstraction and use Pulsar directly.
- `trustgraph-base/trustgraph/clients/embeddings_client.py``from pulsar.schema import JsonSchema`, `import _pulsar`.
- `trustgraph-base/trustgraph/clients/*.py` (agent, config, document_embeddings, document_rag, graph_embeddings, graph_rag, llm, prompt, row_embeddings, triples_query) — all import `_pulsar` for exception handling.
These clients are the internal request-response clients used by processors. They need to be migrated to use the backend abstraction or their Pulsar exception handling needs to be wrapped behind a backend-agnostic exception type.
**Translator base class:**
- `trustgraph-base/trustgraph/messaging/translators/base.py``from pulsar.schema import Record`. Used in type hints. Should be removed when `to_pulsar`/`from_pulsar` are renamed.
**Gateway service (bypasses abstraction):**
- `trustgraph-flow/trustgraph/gateway/service.py``import pulsar`. Creates a Pulsar client directly.
- `trustgraph-flow/trustgraph/gateway/config/receiver.py``import pulsar`. Direct Pulsar usage.
The gateway should use `get_pubsub()` like everything else.
**Storage writers:**
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py``import pulsar`
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py``import pulsar`
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py``import pulsar`
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py``import pulsar`
These need investigation — likely Pulsar exception handling or direct client usage that should go through the abstraction.
**Log level:**
- `trustgraph-base/trustgraph/log_level.py``import _pulsar`. Used to set Pulsar's log level. Should be moved into `pulsar_backend.py`.
### Queue naming
The current scheme encodes QoS, tenant, namespace, and queue name into a slash-separated string (`q0/tg/request/config`) which the Pulsar backend parses and maps to a Pulsar URI (`non-persistent://tg/request/config`). This was an attempt at abstraction but it has problems:
- QoS in the name was a mistake — it's a property of the queue definition, not something that belongs in the name. A queue is either persistent or it isn't; that's decided once when the queue is defined.
- The tenant/namespace structure mirrors Pulsar's model. RabbitMQ doesn't use this — it has vhosts and exchange/queue names. Pretending the naming isn't TrustGraph-specific just leaks Pulsar concepts.
- The `topic()` helper generates these strings, and the backend parses them apart. This is unnecessary indirection.
There are two categories of queue in TrustGraph:
**Infrastructure queues** — defined in code, used for system services. These are fixed and well-known:
| Queue | Persistent | Purpose |
|-------|------------|---------|
| `config-request` | No | Config queries |
| `config-response` | No | Config query responses |
| `config-push` | Yes | Config state broadcast |
| `flow-request` | No | Flow management queries |
| `flow-response` | No | Flow management responses |
| `librarian-request` | No | Document storage operations |
| `librarian-response` | No | Document storage responses |
| `knowledge-request` | No | Knowledge graph operations |
| `knowledge-response` | No | Knowledge graph responses |
| `document-embeddings-request` | No | Document vector queries |
| `document-embeddings-response` | No | Document vector responses |
| `row-embeddings-request` | No | Row vector queries |
| `row-embeddings-response` | No | Row vector responses |
| `collection-request` | No | Collection management |
| `collection-response` | No | Collection management responses |
**Flow queues** — defined in configuration, created dynamically per flow. The queue names come from the config service (e.g. `text-completion-request`, `graph-rag-request`, `agent-request`). Each flow instance has its own set of these queues.
For infrastructure queues, the name is just a string. Persistence is a property of the queue definition, not encoded in the name. The backend maps the name to whatever its native format requires.
For flow queues, the name comes from configuration. The config service already distributes queue names as strings — the backend just needs to be able to use them.
#### Proposed scheme: CLASS:TOPICSPACE:TOPIC
A queue name has three parts separated by colons:
- **CLASS** — a small enum that defines the queue's operational characteristics. The backend knows what each class means in terms of persistence, TTL, memory limits, etc. There are only four classes:
| Class | Persistent | TTL | Behaviour |
|-------|------------|-----|-----------|
| `flow` | Yes | Long | Processing pipeline queues. Messages survive broker restart. |
| `request` | No | Short | Transient request-response. Low TTL, no persistence needed — clients retry on failure. |
| `response` | No | Short | Same as request, for the response side. |
| `state` | Yes | Retained | Last-value state broadcast. Consumers need the most recent value on startup, plus any future updates. Config push is the primary example. |
- **TOPICSPACE** — deployment isolation. Keeps different TrustGraph deployments separate when sharing the same pub/sub infrastructure. Most deployments just use `tg`. Avoids the overloaded terms "tenant" and "namespace".
- **TOPIC** — the logical queue identity. What the queue is for.
**Examples:**
```
flow:tg:text-completion-request
flow:tg:graph-rag-request
flow:tg:agent-request
request:tg:librarian
response:tg:librarian
request:tg:config
response:tg:config
state:tg:config
request:tg:flow
response:tg:flow
```
**Backend mapping:**
Each backend parses the three parts and maps them to its native concepts:
- **Pulsar**: `flow:tg:text-completion-request``persistent://tg/flow/text-completion-request`. Class maps to persistent/non-persistent and namespace. State class uses persistent topic with earliest consumer positioning.
- **RabbitMQ**: Topicspace maps to vhost. Class determines queue durability and TTL policy. State class uses a last-value queue (via plugin) or a fanout exchange pattern where each consumer gets the retained state on connect.
- **Kafka**: `flow.tg.text-completion-request` as topic name. Class determines retention and compaction policy. State class maps to a compacted topic (last value per key).
**Why this works:**
- The class enum is small and stable — adding a new class is rare and deliberate
- Queue properties (persistence, TTL) are implied by class, not encoded in the name
- Dynamic registration works naturally — the config service publishes `flow:tg:text-completion-request` and the backend knows how to declare it from the `flow` class
- The colon separator is unambiguous, easy to split, doesn't conflict with URIs or path separators that backends use internally
- No pretence of being generic — this is a TrustGraph convention, and that's fine
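The parsing and per-backend mapping described above can be sketched in a few lines. This is a sketch under the scheme's own rules, not the real backend code; the RabbitMQ TTL value and declare-argument shapes are illustrative assumptions:

```python
from dataclasses import dataclass
from enum import Enum


class QueueClass(Enum):
    FLOW = "flow"
    REQUEST = "request"
    RESPONSE = "response"
    STATE = "state"


@dataclass
class QueueName:
    cls: QueueClass
    topicspace: str
    topic: str


def parse_queue_name(name: str) -> QueueName:
    """Split CLASS:TOPICSPACE:TOPIC into its three parts."""
    cls, topicspace, topic = name.split(":", 2)
    return QueueName(QueueClass(cls), topicspace, topic)


def to_pulsar_topic(q: QueueName) -> str:
    """Pulsar mapping: class decides persistent vs non-persistent and
    doubles as the namespace, topicspace is the tenant."""
    persistent = q.cls in (QueueClass.FLOW, QueueClass.STATE)
    scheme = "persistent" if persistent else "non-persistent"
    return f"{scheme}://{q.topicspace}/{q.cls.value}/{q.topic}"


def to_rabbitmq_declare_args(q: QueueName) -> dict:
    """RabbitMQ mapping: class determines durability and TTL policy.
    (TTL value is illustrative.)"""
    if q.cls is QueueClass.STATE:
        # state broadcast maps to a RabbitMQ stream
        return {"durable": True, "arguments": {"x-queue-type": "stream"}}
    if q.cls is QueueClass.FLOW:
        return {"durable": True, "arguments": {}}
    # request/response: transient, short TTL, clients retry on failure
    return {"durable": False, "arguments": {"x-message-ttl": 30_000}}
```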
### Serialization boundary
**Decision: the backend owns the wire format.**
The contract between the Consumer/Producer layer and the backend is dataclass objects in, dataclass objects out:
- `send()` accepts a dataclass instance and properties dict
- `receive()` returns a message whose `value()` is a dataclass instance
What happens on the wire is the backend's concern. The Pulsar backend uses JSON (via `dataclass_to_dict` / `dict_to_dataclass`). A RabbitMQ backend would likely also use JSON. A future backend could use Protobuf, MessagePack, or Avro if the broker benefits from it.
The serialization helpers stay inside the backend that uses them — they are not shared infrastructure (this supersedes the shared-location option floated in change #3 above). Each backend brings its own serialization strategy. The Consumer/Producer layer never thinks about wire format.
### Gateway service
**Decision: the gateway uses the backend abstraction like any other component.**
The gateway currently bridges WebSocket/REST to Pulsar directly, bypassing the abstraction layer. It translates incoming API JSON to Pulsar schema objects, sends them, receives responses as Pulsar schema objects, and translates back to API JSON. Since the wire format is JSON in both directions, this is effectively a no-op round trip through the schema machinery.
With the backend abstraction, the gateway follows the same pattern as every other component:
1. Incoming API JSON → translator `decode()` → dataclass
2. Dataclass → backend `send()` (backend handles wire format)
3. Backend `receive()` → dataclass
4. Dataclass → translator `encode()` → API JSON → WebSocket/REST client
This is architecturally simple — one code path, no special cases. The gateway depends on the schema dataclasses and the translator layer, which it already does. The overhead of deserialize-then-reserialize is negligible for the message sizes involved. And it keeps all options open — if a future backend uses a non-JSON wire format, the gateway still works without changes.
## Implementation Order
### Phase 1: Rename translators
Rename `to_pulsar()``decode()`, `from_pulsar()``encode()` across all translator classes and call sites. Remove `from pulsar.schema import Record` from the translator base class. Mechanical find-and-replace, no behavioural changes.
### Phase 2: Queue naming
Replace the `topic()` helper with the CLASS:TOPICSPACE:TOPIC scheme. Update all queue definitions in `schema/services/*.py` and `schema/knowledge/*.py`. Update `PulsarBackend.map_topic()` to parse the new format. Verify all existing functionality still works with Pulsar.
### Phase 3: Clean up Pulsar leaks
Work through the concrete cleanups list: remove dead imports, delete the legacy `PulsarClient` class, migrate the client libraries and gateway to use the backend abstraction. After this phase, `pulsar` imports exist only in `pulsar_backend.py`.
### Phase 4: RabbitMQ backend
Implement `rabbitmq_backend.py` against the existing `PubSubBackend` Protocol. Map queue classes to RabbitMQ concepts: `flow` → durable queues, `request`/`response` → non-durable queues with TTL, `state` → RabbitMQ streams. Add `rabbitmq` as a backend option in the factory. Test end-to-end with `PUBSUB_BACKEND=rabbitmq`.
Phases 1-3 are safe to do on main — they don't change behaviour, just clean up. Phase 4 is additive — it adds a new backend without touching the existing one.
### Config distribution on RabbitMQ
The `state` queue class needs "start from earliest" semantics — a newly started processor must receive the current configuration state.
RabbitMQ Streams (available since 3.9) solve this directly. Streams are persistent, append-only logs that support consumer offset positioning. The RabbitMQ backend maps the `state` class to a stream, and consumers attach with offset `first` to read from the beginning, or `last` to read the most recent entry plus future updates.
Since config pushes are full state snapshots (not deltas), a consumer only needs the most recent entry. The RabbitMQ backend can use `last` offset positioning for `state` class consumers, which delivers the last message in the stream followed by any new messages. This matches the current behaviour where processors read config on startup and then react to updates.


Explainability events stream to client as the query executes:
3. Edges selected with reasoning → event emitted
4. Answer synthesized → event emitted
Client receives `explain_id`, `explain_graph`, and `explain_triples` inline
in each explain message. The triples contain the full provenance data for
that step — no follow-up graph query needed. The `explain_id` serves as
the root entity URI within the triples. Data is also written to the
knowledge graph for later audit/analysis.
## URI Structure
```python
@dataclass
class GraphRagResponse:
    response: str = ""
    end_of_stream: bool = False
    explain_id: str | None = None
    explain_graph: str | None = None
    explain_triples: list[Triple] = field(default_factory=list)
    message_type: str = ""  # "chunk" or "explain"
    end_of_session: bool = False
```
| message_type | Purpose |
|--------------|---------|
| `chunk` | Response text (streaming or final) |
| `explain` | Explainability event with inline provenance triples |
### Session Lifecycle


# SPARQL Query Service Technical Specification
## Overview
A pub/sub-hosted SPARQL query service that accepts SPARQL queries, decomposes
them into triple pattern lookups via the existing triples query pub/sub
interface, performs in-memory joins/filters/projections, and returns SPARQL
result bindings.
This makes the triple store queryable using a standard graph query language
without coupling to any specific backend (Neo4j, Cassandra, FalkorDB, etc.).
## Goals
- **SPARQL 1.1 support**: SELECT, ASK, CONSTRUCT, DESCRIBE queries
- **Backend-agnostic**: query via the pub/sub triples interface, not direct
database access
- **Standard service pattern**: FlowProcessor with ConsumerSpec/ProducerSpec,
using TriplesClientSpec to call the triples query service
- **Correct SPARQL semantics**: proper BGP evaluation, joins, OPTIONAL, UNION,
FILTER, BIND, aggregation, solution modifiers (ORDER BY, LIMIT, OFFSET,
DISTINCT)
## Background
The triples query service provides a single-pattern lookup: given optional
(s, p, o) values, return matching triples. This is the equivalent of one
triple pattern in a SPARQL Basic Graph Pattern.
To evaluate a full SPARQL query, we need to:
1. Parse the SPARQL string into an algebra tree
2. Walk the algebra tree, issuing triple pattern lookups for each BGP pattern
3. Join results across patterns (nested-loop or hash join)
4. Apply filters, optionals, unions, and aggregations in-memory
5. Project and return the requested variables
rdflib (already a dependency) provides a SPARQL 1.1 parser and algebra
compiler. We use rdflib to parse queries into algebra trees, then evaluate
the algebra ourselves using the triples query client as the data source.
## Technical Design
### Architecture
```
pub/sub
[Client] ──request──> [SPARQL Query Service] ──triples-request──> [Triples Query Service]
[Client] <─response── [SPARQL Query Service] <─triples-response── [Triples Query Service]
```
The service is a FlowProcessor that:
- Consumes SPARQL query requests
- Uses TriplesClientSpec to issue triple pattern lookups
- Evaluates the SPARQL algebra in-memory
- Produces result responses
### Components
1. **SPARQL Query Service (FlowProcessor)**
- ConsumerSpec for incoming SPARQL requests
- ProducerSpec for outgoing results
- TriplesClientSpec for calling the triples query service
- Delegates parsing and evaluation to the components below
Module: `trustgraph-flow/trustgraph/query/sparql/service.py`
2. **SPARQL Parser (rdflib wrapper)**
- Uses `rdflib.plugins.sparql.prepareQuery` / `parseQuery` and
`rdflib.plugins.sparql.algebra.translateQuery` to produce an algebra tree
- Extracts PREFIX declarations, query type (SELECT/ASK/CONSTRUCT/DESCRIBE),
and the algebra root
Module: `trustgraph-flow/trustgraph/query/sparql/parser.py`
3. **Algebra Evaluator**
- Recursive evaluator over the rdflib algebra tree
- Each algebra node type maps to an evaluation function
- BGP nodes issue triple pattern queries via TriplesClient
- Join/Filter/Optional/Union etc. operate on in-memory solution sequences
Module: `trustgraph-flow/trustgraph/query/sparql/algebra.py`
4. **Solution Sequence**
- A solution is a dict mapping variable names to Term values
- Solution sequences are lists of solutions
- Join: hash join on shared variables
- LeftJoin (OPTIONAL): hash join preserving unmatched left rows
- Union: concatenation
- Filter: evaluate SPARQL expressions against each solution
- Projection/Distinct/Order/Slice: standard post-processing
Module: `trustgraph-flow/trustgraph/query/sparql/solutions.py`
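The join and left join described above can be sketched directly over dict-based solutions. This is a simplification: it assumes every solution in a sequence binds the same variable set, whereas full SPARQL semantics must also handle partially bound solutions:

```python
def _hash_table(rows, shared):
    """Index rows by their values on the shared variables."""
    key = lambda sol: tuple(sol[v] for v in sorted(shared))
    table = {}
    for sol in rows:
        table.setdefault(key(sol), []).append(sol)
    return table, key


def join(left, right):
    """Hash join of two solution sequences on their shared variables.
    A solution is a dict mapping variable name -> term."""
    if not left or not right:
        return []
    shared = set(left[0]) & set(right[0])
    table, key = _hash_table(right, shared)
    return [{**sol, **m} for sol in left for m in table.get(key(sol), [])]


def left_join(left, right):
    """LeftJoin (OPTIONAL): like join, but unmatched left rows survive."""
    if not right:
        return list(left)
    shared = set(left[0]) & set(right[0]) if left else set()
    table, key = _hash_table(right, shared)
    out = []
    for sol in left:
        matches = table.get(key(sol), [])
        if matches:
            out.extend({**sol, **m} for m in matches)
        else:
            out.append(sol)
    return out
```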
### Data Models
#### Request
```python
@dataclass
class SparqlQueryRequest:
    user: str = ""
    collection: str = ""
    query: str = ""     # SPARQL query string
    limit: int = 10000  # Safety limit on results
```
#### Response
```python
@dataclass
class SparqlQueryResponse:
    error: Error | None = None
    query_type: str = ""  # "select", "ask", "construct", "describe"

    # For SELECT queries
    variables: list[str] = field(default_factory=list)
    bindings: list[SparqlBinding] = field(default_factory=list)

    # For ASK queries
    ask_result: bool = False

    # For CONSTRUCT/DESCRIBE queries
    triples: list[Triple] = field(default_factory=list)


@dataclass
class SparqlBinding:
    values: list[Term | None] = field(default_factory=list)
```
### BGP Evaluation Strategy
For each triple pattern in a BGP:
- Extract bound terms (concrete IRIs/literals) and variables
- Call `TriplesClient.query_stream(s, p, o)` with bound terms, None for
variables
- Map returned triples back to variable bindings
For multi-pattern BGPs, join solutions incrementally:
- Order patterns by selectivity (patterns with more bound terms first)
- For each subsequent pattern, substitute bound variables from the current
solution sequence before querying
- This avoids full cross-products and reduces the number of triples queries
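The ordering and substitution steps above can be sketched as follows. For illustration, `'?'`-prefixed strings stand in for rdflib `Variable` terms and a pattern is an `(s, p, o)` tuple; the real evaluator works on rdflib terms:

```python
def order_by_selectivity(patterns):
    """Order triple patterns so those with more bound (non-variable)
    terms run first -- they return fewer triples."""
    def bound_count(pattern):
        return sum(
            1 for t in pattern
            if not (isinstance(t, str) and t.startswith("?"))
        )
    return sorted(patterns, key=bound_count, reverse=True)


def substitute(pattern, solution):
    """Replace variables already bound in the current solution with
    their values before issuing the triples query."""
    return tuple(
        solution.get(t, t) if isinstance(t, str) and t.startswith("?") else t
        for t in pattern
    )
```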
### Streaming and Early Termination
The triples query service supports streaming responses (batched delivery via
`TriplesClient.query_stream`). The SPARQL evaluator should use streaming
from the start, not as an optimisation. This is important because:
- **Early termination**: when the SPARQL query has a LIMIT, or when only one
solution is needed (ASK queries), we can stop consuming triples as soon as
we have enough results. Without streaming, a wildcard pattern like
`?s ?p ?o` would fetch the entire graph before we could apply the limit.
- **Memory efficiency**: results are processed batch-by-batch rather than
materialising the full result set in memory before joining.
The batch callback in `query_stream` returns a boolean to signal completion.
The evaluator should signal completion (return True) as soon as sufficient
solutions have been produced, allowing the underlying pub/sub consumer to
stop pulling batches.
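The callback contract can be sketched like this, assuming for illustration a synchronous `query_stream(s, p, o, on_batch)` shape (the real `TriplesClient` API may differ, e.g. be async):

```python
def collect_with_limit(query_stream, s, p, o, limit):
    """Consume streamed triple batches, stopping as soon as `limit`
    results are in hand. The callback returning True signals the
    streaming client to stop pulling batches."""
    results = []

    def on_batch(triples):
        results.extend(triples)
        return len(results) >= limit  # True => stop streaming

    query_stream(s, p, o, on_batch)
    return results[:limit]
```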
### Parallel BGP Execution (Phase 2 Optimisation)
Within a BGP, patterns that share variables benefit from sequential
evaluation with bound-variable substitution (query results from earlier
patterns narrow later queries). However, patterns with no shared variables
are independent and could be issued concurrently via `asyncio.gather`.
A practical approach for a future optimisation pass:
- Analyse BGP patterns and identify connected components (groups of
patterns linked by shared variables)
- Execute independent components in parallel
- Within each component, evaluate patterns sequentially with substitution
This is not needed for correctness -- the sequential approach works for all
cases -- but could significantly reduce latency for queries with independent
pattern groups. Flagged as a phase 2 optimisation.
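The component analysis is straightforward. A sketch, again using `'?'`-prefixed strings to stand in for variables:

```python
def connected_components(patterns):
    """Group BGP patterns into components linked by shared variables.
    Patterns in different components are independent and could be
    evaluated concurrently."""
    def variables(p):
        return {t for t in p if isinstance(t, str) and t.startswith("?")}

    components = []  # list of (vars, patterns) pairs, var-disjoint
    for p in patterns:
        vs = variables(p)
        merged_vars, merged_pats = set(vs), [p]
        rest = []
        for cvars, cpats in components:
            if cvars & vs:
                # this pattern bridges into an existing component
                merged_vars |= cvars
                merged_pats = cpats + merged_pats
            else:
                rest.append((cvars, cpats))
        components = rest + [(merged_vars, merged_pats)]
    return [pats for _, pats in components]
```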
### FILTER Expression Evaluation
rdflib's algebra represents FILTER expressions as expression trees. We
evaluate these against each solution row, supporting:
- Comparison operators (=, !=, <, >, <=, >=)
- Logical operators (&&, ||, !)
- SPARQL built-in functions (isIRI, isLiteral, isBlank, str, lang,
datatype, bound, regex, etc.)
- Arithmetic operators (+, -, *, /)
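The recursive shape of the evaluator can be sketched with a tuple-based expression tree. rdflib's algebra uses its own `Expr` nodes rather than tuples, so this is a structural illustration, not the real representation.

```python
# Minimal FILTER expression evaluator over a solution row.
# Each node is (operator, *args); "var" and "lit" are leaves.
import re

def eval_expr(expr, row):
    op, *args = expr
    if op == "var":
        return row[args[0]]
    if op == "lit":
        return args[0]
    if op == "&&":
        return eval_expr(args[0], row) and eval_expr(args[1], row)
    if op == "||":
        return eval_expr(args[0], row) or eval_expr(args[1], row)
    if op == "!":
        return not eval_expr(args[0], row)
    if op in ("=", "!=", "<", ">", "<=", ">="):
        a, b = eval_expr(args[0], row), eval_expr(args[1], row)
        return {"=": a == b, "!=": a != b, "<": a < b,
                ">": a > b, "<=": a <= b, ">=": a >= b}[op]
    if op == "regex":
        return re.search(eval_expr(args[1], row),
                         eval_expr(args[0], row)) is not None
    if op == "bound":
        # bound(?x): is the variable present in this solution?
        return args[0] in row
    raise ValueError(f"unsupported operator: {op}")
```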
## Implementation Order
1. **Schema and service skeleton** -- define SparqlQueryRequest/Response
dataclasses, create the FlowProcessor subclass with ConsumerSpec,
ProducerSpec, and TriplesClientSpec wired up. Verify it starts and
connects.
2. **SPARQL parsing** -- wrap rdflib's parser to produce algebra trees from
SPARQL strings. Handle parse errors gracefully. Unit test with a range of
query shapes.
3. **BGP evaluation** -- implement single-pattern and multi-pattern BGP
evaluation using TriplesClient. This is the core building block. Test
with simple SELECT WHERE { ?s ?p ?o } queries.
4. **Joins and solution sequences** -- implement hash join, left join (for
OPTIONAL), and union. Test with multi-pattern queries.
5. **FILTER evaluation** -- implement the expression evaluator for FILTER
clauses. Start with comparisons and logical operators, then add built-in
functions incrementally.
6. **Solution modifiers** -- DISTINCT, ORDER BY, LIMIT, OFFSET, projection.
7. **ASK / CONSTRUCT / DESCRIBE** -- extend beyond SELECT. ASK is trivial
(non-empty result = true). CONSTRUCT builds triples from a template.
DESCRIBE fetches all triples for matched resources.
8. **Aggregation** -- GROUP BY, HAVING, COUNT, SUM, AVG, MIN, MAX,
GROUP_CONCAT, SAMPLE.
9. **BIND, VALUES, subqueries** -- remaining SPARQL 1.1 features.
10. **API gateway integration** -- add SparqlQueryRequestor dispatcher,
request/response translators, and API endpoint so that the SPARQL
service is accessible via the HTTP gateway.
11. **SDK support** -- add `sparql_query()` method to FlowInstance in the
Python API SDK, following the same pattern as `triples_query()`.
12. **CLI command** -- add a `tg-sparql-query` CLI command that takes a
SPARQL query string (or reads from a file/stdin), submits it via the
SDK, and prints results in a readable format (table for SELECT,
true/false for ASK, Turtle for CONSTRUCT/DESCRIBE).
## Performance Considerations
In-memory join over pub/sub round-trips will be slower than native SPARQL on
a graph database. Key mitigations:
- **Streaming with early termination**: use `query_stream` so that
limit-bound queries don't fetch entire result sets. A `SELECT ... LIMIT 1`
against a wildcard pattern fetches one batch, not the whole graph.
- **Bound-variable substitution**: when evaluating BGP patterns sequentially,
substitute known bindings into subsequent patterns to issue narrow queries
rather than broad ones followed by in-memory filtering.
- **Parallel independent patterns** (phase 2): patterns with no shared
variables can be issued concurrently.
- **Query complexity limits**: may need a cap on the number of triple pattern
queries issued per SPARQL query to prevent runaway evaluation.
### Named Graph Mapping
SPARQL's `GRAPH ?g { ... }` and `GRAPH <uri> { ... }` clauses map to the
triples query service's graph filter parameter:
- `GRAPH <uri> { ?s ?p ?o }` — pass `g=uri` to the triples query
- Patterns outside any GRAPH clause — pass `g=""` (default graph only)
- `GRAPH ?g { ?s ?p ?o }` — pass `g="*"` (all graphs), then bind `?g` from
the returned triple's graph field
The triples query interface does not support a wildcard graph natively in
the SPARQL sense, but `g="*"` (all graphs) combined with client-side
filtering on the returned graph values achieves the same effect.
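The mapping reduces to a small dispatch on the GRAPH term; a sketch using the `g` conventions described above (the function name is illustrative):

```python
# Map a pattern's enclosing GRAPH clause to the triples query
# graph filter: None = no GRAPH clause, "?g" = graph variable,
# otherwise a specific named-graph IRI.

def graph_filter(graph_term):
    if graph_term is None:
        return ""          # default graph only
    if graph_term.startswith("?"):
        return "*"         # all graphs; bind ?g client-side
    return graph_term      # one specific named graph
```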
## Open Questions
- **SPARQL 1.2**: the extent of rdflib's parser support for 1.2 features is
  unclear (property paths are already in 1.1; 1.2 adds lateral joins, ADJUST,
  etc.). Start with 1.1 and extend as rdflib support matures.

File diff suppressed because one or more lines are too long


@ -10,6 +10,7 @@ properties:
- observation
- answer
- final-answer
- explain
- error
example: answer
content:
@ -29,6 +30,11 @@ properties:
type: string
description: Named graph containing the explainability data
example: urn:graph:retrieval
explain_triples:
type: array
description: Provenance triples for this explain event (inline, no follow-up query needed)
items:
$ref: '../common/Triple.yaml'
end-of-message:
type: boolean
description: Current chunk type is complete (streaming mode)


@ -3,6 +3,9 @@ description: |
Librarian service request for document library management.
Operations: add-document, remove-document, list-documents,
get-document-metadata, stream-document, add-child-document,
list-children, begin-upload, upload-chunk, complete-upload,
abort-upload, get-upload-status, list-uploads,
start-processing, stop-processing, list-processing
required:
- operation
@ -13,6 +16,17 @@ properties:
- add-document
- remove-document
- list-documents
- get-document-metadata
- get-document-content
- stream-document
- add-child-document
- list-children
- begin-upload
- upload-chunk
- complete-upload
- abort-upload
- get-upload-status
- list-uploads
- start-processing
- stop-processing
- list-processing
@ -21,6 +35,21 @@ properties:
- `add-document`: Add document to library
- `remove-document`: Remove document from library
- `list-documents`: List documents in library
- `get-document-metadata`: Get document metadata
- `get-document-content`: Get full document content in a single response.
**Deprecated** — use `stream-document` instead. Fails for documents
exceeding the broker's max message size.
- `stream-document`: Stream document content in chunks. Each response
includes `chunk_index` and `is_final`. Preferred over `get-document-content`
for all document sizes.
- `add-child-document`: Add a child document (e.g. page, chunk)
- `list-children`: List child documents of a parent
- `begin-upload`: Start a chunked upload session
- `upload-chunk`: Upload a chunk of data
- `complete-upload`: Finalize a chunked upload
- `abort-upload`: Cancel a chunked upload
- `get-upload-status`: Check upload progress
- `list-uploads`: List active upload sessions
- `start-processing`: Start processing library documents
- `stop-processing`: Stop library processing
- `list-processing`: List processing status


@ -8,8 +8,7 @@ required:
properties:
text:
type: string
description: Text content (base64 encoded)
format: byte
description: Text content, either raw text or base64 encoded for compatibility with older clients
example: VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==
id:
type: string


@ -18,6 +18,11 @@ properties:
type: string
description: Named graph containing the explainability data
example: urn:graph:retrieval
explain_triples:
type: array
description: Provenance triples for this explain event (inline, no follow-up query needed)
items:
$ref: '../common/Triple.yaml'
end-of-stream:
type: boolean
description: Indicates LLM response stream is complete


@ -18,6 +18,11 @@ properties:
type: string
description: Named graph containing the explainability data
example: urn:graph:retrieval
explain_triples:
type: array
description: Provenance triples for this explain event (inline, no follow-up query needed)
items:
$ref: '../common/Triple.yaml'
end_of_stream:
type: boolean
description: Indicates LLM response stream is complete


@ -2,7 +2,7 @@ openapi: 3.1.0
info:
title: TrustGraph API Gateway
version: "2.1"
version: "2.2"
description: |
REST API for TrustGraph - an AI-powered knowledge graph and RAG system.
@ -28,7 +28,7 @@ info:
Require running flow instance, accessed via `/api/v1/flow/{flow}/service/{kind}`:
- AI services: agent, text-completion, prompt, RAG (document/graph)
- Embeddings: embeddings, graph-embeddings, document-embeddings
- Query: triples, rows, nlp-query, structured-query, row-embeddings
- Query: triples, rows, nlp-query, structured-query, sparql-query, row-embeddings
- Data loading: text-load, document-load
- Utilities: mcp-tool, structured-diag
@ -139,6 +139,8 @@ paths:
$ref: './paths/flow/text-load.yaml'
/api/v1/flow/{flow}/service/document-load:
$ref: './paths/flow/document-load.yaml'
/api/v1/flow/{flow}/service/sparql-query:
$ref: './paths/flow/sparql-query.yaml'
# Document streaming
/api/v1/document-stream:


@ -29,6 +29,7 @@ post:
- `action`: Action being taken
- `observation`: Result from action
- `answer`: Final response to user
- `explain`: Provenance event with inline triples (`explain_triples`)
- `error`: Error occurred
Each chunk may have multiple messages. Check flags:
@ -116,6 +117,22 @@ post:
content: ""
end-of-message: true
end-of-dialog: true
explainEvent:
summary: Explain event with inline provenance triples
value:
chunk-type: explain
content: ""
explain_id: urn:trustgraph:agent:abc123
explain_graph: urn:graph:retrieval
explain_triples:
- s: {t: i, i: "urn:trustgraph:agent:abc123"}
p: {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
o: {t: i, i: "https://trustgraph.ai/ns/AgentSession"}
- s: {t: i, i: "urn:trustgraph:agent:abc123"}
p: {t: i, i: "https://trustgraph.ai/ns/query"}
o: {t: l, v: "Explain quantum computing"}
end-of-message: true
end-of-dialog: false
legacyResponse:
summary: Legacy non-streaming response
value:


@ -24,8 +24,13 @@ post:
## Streaming
Enable `streaming: true` to receive the answer as it's generated:
- Multiple messages with `response` content
- Multiple `chunk` messages with `response` content
- `explain` messages with inline provenance triples (`explain_triples`)
- Final message with `end-of-stream: true`
- Session ends with `end_of_session: true`
Explain events carry `explain_id`, `explain_graph`, and `explain_triples`
inline in the stream, so no follow-up knowledge graph query is needed.
Without streaming, returns complete answer in single response.
@ -96,6 +101,21 @@ post:
value:
response: "The research papers present three"
end-of-stream: false
explainEvent:
summary: Explain event with inline provenance triples
value:
message_type: explain
explain_id: urn:trustgraph:question:abc123
explain_graph: urn:graph:retrieval
explain_triples:
- s: {t: i, i: "urn:trustgraph:question:abc123"}
p: {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
o: {t: i, i: "https://trustgraph.ai/ns/DocumentRagQuestion"}
- s: {t: i, i: "urn:trustgraph:question:abc123"}
p: {t: i, i: "https://trustgraph.ai/ns/query"}
o: {t: l, v: "What are the key findings in the research papers?"}
end-of-stream: false
end_of_session: false
streamingComplete:
summary: Streaming complete marker
value:


@ -25,8 +25,13 @@ post:
## Streaming
Enable `streaming: true` to receive the answer as it's generated:
- Multiple messages with `response` content
- Multiple `chunk` messages with `response` content
- `explain` messages with inline provenance triples (`explain_triples`)
- Final message with `end-of-stream: true`
- Session ends with `end_of_session: true`
Explain events carry `explain_id`, `explain_graph`, and `explain_triples`
inline in the stream, so no follow-up knowledge graph query is needed.
Without streaming, returns complete answer in single response.
@ -116,6 +121,21 @@ post:
value:
response: "Quantum physics and computer science intersect"
end-of-stream: false
explainEvent:
summary: Explain event with inline provenance triples
value:
message_type: explain
explain_id: urn:trustgraph:question:abc123
explain_graph: urn:graph:retrieval
explain_triples:
- s: {t: i, i: "urn:trustgraph:question:abc123"}
p: {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
o: {t: i, i: "https://trustgraph.ai/ns/GraphRagQuestion"}
- s: {t: i, i: "urn:trustgraph:question:abc123"}
p: {t: i, i: "https://trustgraph.ai/ns/query"}
o: {t: l, v: "What connections exist between quantum physics and computer science?"}
end_of_stream: false
end_of_session: false
streamingComplete:
summary: Streaming complete marker
value:


@ -0,0 +1,145 @@
post:
tags:
- Flow Services
summary: SPARQL query - execute SPARQL 1.1 queries against the knowledge graph
description: |
Execute a SPARQL 1.1 query against the knowledge graph.
## Supported Query Types
- **SELECT**: Returns variable bindings as a table of results
- **ASK**: Returns true/false for existence checks
- **CONSTRUCT**: Returns a set of triples built from a template
- **DESCRIBE**: Returns triples describing matched resources
## SPARQL Features
Supports standard SPARQL 1.1 features including:
- Basic Graph Patterns (BGPs) with triple pattern matching
- OPTIONAL, UNION, FILTER
- BIND, VALUES
- ORDER BY, LIMIT, OFFSET, DISTINCT
- GROUP BY with aggregates (COUNT, SUM, AVG, MIN, MAX, GROUP_CONCAT)
- Built-in functions (isIRI, STR, REGEX, CONTAINS, etc.)
## Query Examples
Find all entities of a type:
```sparql
SELECT ?s ?label WHERE {
?s <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://example.com/Person> .
?s <http://www.w3.org/2000/01/rdf-schema#label> ?label .
}
LIMIT 10
```
Check if an entity exists:
```sparql
ASK { <http://example.com/alice> ?p ?o }
```
operationId: sparqlQueryService
security:
- bearerAuth: []
parameters:
- name: flow
in: path
required: true
schema:
type: string
description: Flow instance ID
example: my-flow
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- query
properties:
query:
type: string
description: SPARQL 1.1 query string
user:
type: string
default: trustgraph
description: User/keyspace identifier
collection:
type: string
default: default
description: Collection identifier
limit:
type: integer
default: 10000
description: Safety limit on number of results
examples:
selectQuery:
summary: SELECT query
value:
query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"
user: trustgraph
collection: default
askQuery:
summary: ASK query
value:
query: "ASK { <http://example.com/alice> ?p ?o }"
responses:
'200':
description: Successful response
content:
application/json:
schema:
type: object
properties:
query-type:
type: string
enum: [select, ask, construct, describe]
variables:
type: array
items:
type: string
description: Variable names (SELECT only)
bindings:
type: array
items:
type: object
properties:
values:
type: array
items:
$ref: '../../components/schemas/common/RdfValue.yaml'
description: Result rows (SELECT only)
ask-result:
type: boolean
description: Boolean result (ASK only)
triples:
type: array
description: Result triples (CONSTRUCT/DESCRIBE only)
error:
type: object
properties:
type:
type: string
message:
type: string
examples:
selectResult:
summary: SELECT result
value:
query-type: select
variables: [s, p, o]
bindings:
- values:
- {t: i, i: "http://example.com/alice"}
- {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"}
- {t: i, i: "http://example.com/Person"}
askResult:
summary: ASK result
value:
query-type: ask
ask-result: true
'401':
$ref: '../../components/responses/Unauthorized.yaml'
'500':
$ref: '../../components/responses/Error.yaml'


@ -8,7 +8,7 @@ post:
## Text Load Overview
Fire-and-forget document loading:
- **Input**: Text content (base64 encoded)
- **Input**: Text content (raw UTF-8 or base64 encoded)
- **Process**: Chunk, embed, store
- **Output**: None (202 Accepted)
@ -26,7 +26,14 @@ post:
## Text Format
Text must be base64 encoded:
Text may be sent as raw UTF-8 text:
```
{
"text": "Cancer survival: 2.74× higher hazard ratio"
}
```
Older clients may still send base64 encoded text:
```
text_content = "This is the document..."
encoded = base64.b64encode(text_content.encode('utf-8'))
@ -78,12 +85,12 @@ post:
simpleLoad:
summary: Load text document
value:
text: VGhpcyBpcyB0aGUgZG9jdW1lbnQgdGV4dC4uLg==
text: This is the document text...
id: doc-123
user: alice
collection: research
withMetadata:
summary: Load with RDF metadata
summary: Load with RDF metadata using base64 text
value:
text: UXVhbnR1bSBjb21wdXRpbmcgdXNlcyBxdWFudHVtIG1lY2hhbmljcyBwcmluY2lwbGVzLi4u
id: doc-456


@ -2,7 +2,7 @@ asyncapi: 3.0.0
info:
title: TrustGraph WebSocket API
version: "2.1"
version: "2.2"
description: |
WebSocket API for TrustGraph - providing multiplexed, asynchronous access to all services.
@ -31,7 +31,7 @@ info:
**Flow-Hosted Services** (require `flow` parameter):
- agent, text-completion, prompt, document-rag, graph-rag
- embeddings, graph-embeddings, document-embeddings
- triples, rows, nlp-query, structured-query, structured-diag, row-embeddings
- triples, rows, nlp-query, structured-query, sparql-query, structured-diag, row-embeddings
- text-load, document-load, mcp-tool
## Schema Reuse


@ -34,6 +34,7 @@ payload:
- $ref: './requests/RowEmbeddingsRequest.yaml'
- $ref: './requests/TextLoadRequest.yaml'
- $ref: './requests/DocumentLoadRequest.yaml'
- $ref: './requests/SparqlQueryRequest.yaml'
examples:
- name: Config service request


@ -0,0 +1,46 @@
type: object
description: WebSocket request for sparql-query service (flow-hosted service)
required:
- id
- service
- flow
- request
properties:
id:
type: string
description: Unique request identifier
service:
type: string
const: sparql-query
description: Service identifier for sparql-query service
flow:
type: string
description: Flow ID
request:
type: object
required:
- query
properties:
query:
type: string
description: SPARQL 1.1 query string
user:
type: string
default: trustgraph
description: User/keyspace identifier
collection:
type: string
default: default
description: Collection identifier
limit:
type: integer
default: 10000
description: Safety limit on number of results
examples:
- id: req-1
service: sparql-query
flow: my-flow
request:
query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"
user: trustgraph
collection: default


@ -87,10 +87,11 @@ def sample_message_data():
"history": []
},
"AgentResponse": {
"answer": "Machine learning is a subset of AI.",
"chunk_type": "answer",
"content": "Machine learning is a subset of AI.",
"end_of_message": True,
"end_of_dialog": True,
"error": None,
"thought": "I need to provide information about machine learning.",
"observation": None
},
"Metadata": {
"id": "test-doc-123",


@ -38,7 +38,7 @@ class TestDocumentEmbeddingsRequestContract:
assert request.user == "test_user"
assert request.collection == "test_collection"
def test_request_translator_to_pulsar(self):
def test_request_translator_decode(self):
"""Test request translator converts dict to Pulsar schema"""
translator = DocumentEmbeddingsRequestTranslator()
@ -49,7 +49,7 @@ class TestDocumentEmbeddingsRequestContract:
"collection": "custom_collection"
}
result = translator.to_pulsar(data)
result = translator.decode(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2, 0.3, 0.4]
@ -57,7 +57,7 @@ class TestDocumentEmbeddingsRequestContract:
assert result.user == "custom_user"
assert result.collection == "custom_collection"
def test_request_translator_to_pulsar_with_defaults(self):
def test_request_translator_decode_with_defaults(self):
"""Test request translator uses correct defaults"""
translator = DocumentEmbeddingsRequestTranslator()
@ -66,7 +66,7 @@ class TestDocumentEmbeddingsRequestContract:
# No limit, user, or collection provided
}
result = translator.to_pulsar(data)
result = translator.decode(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vector == [0.1, 0.2]
@ -74,7 +74,7 @@ class TestDocumentEmbeddingsRequestContract:
assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
def test_request_translator_from_pulsar(self):
def test_request_translator_encode(self):
"""Test request translator converts Pulsar schema to dict"""
translator = DocumentEmbeddingsRequestTranslator()
@ -85,7 +85,7 @@ class TestDocumentEmbeddingsRequestContract:
collection="test_collection"
)
result = translator.from_pulsar(request)
result = translator.encode(request)
assert isinstance(result, dict)
assert result["vector"] == [0.5, 0.6]
@ -134,7 +134,7 @@ class TestDocumentEmbeddingsResponseContract:
assert response.error == error
assert response.chunks == []
def test_response_translator_from_pulsar_with_chunks(self):
def test_response_translator_encode_with_chunks(self):
"""Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator()
@ -147,7 +147,7 @@ class TestDocumentEmbeddingsResponseContract:
]
)
result = translator.from_pulsar(response)
result = translator.encode(response)
assert isinstance(result, dict)
assert "chunks" in result
@ -155,7 +155,7 @@ class TestDocumentEmbeddingsResponseContract:
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunks(self):
def test_response_translator_encode_with_empty_chunks(self):
"""Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator()
@ -164,25 +164,25 @@ class TestDocumentEmbeddingsResponseContract:
chunks=[]
)
result = translator.from_pulsar(response)
result = translator.encode(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunks(self):
def test_response_translator_encode_with_none_chunks(self):
"""Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = None
result = translator.from_pulsar(response)
result = translator.encode(response)
assert isinstance(result, dict)
assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self):
def test_response_translator_encode_with_completion(self):
"""Test response translator with completion flag"""
translator = DocumentEmbeddingsResponseTranslator()
@ -194,7 +194,7 @@ class TestDocumentEmbeddingsResponseContract:
]
)
result, is_final = translator.from_response_with_completion(response)
result, is_final = translator.encode_with_completion(response)
assert isinstance(result, dict)
assert "chunks" in result
@ -202,12 +202,12 @@ class TestDocumentEmbeddingsResponseContract:
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self):
"""Test that to_pulsar raises NotImplementedError for responses"""
def test_response_translator_decode_not_implemented(self):
"""Test that decode raises NotImplementedError for responses"""
translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
translator.decode({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility:
@ -225,7 +225,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Convert to Pulsar request
req_translator = DocumentEmbeddingsRequestTranslator()
pulsar_request = req_translator.to_pulsar(request_data)
pulsar_request = req_translator.decode(request_data)
# Simulate service processing and creating response
response = DocumentEmbeddingsResponse(
@ -238,7 +238,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Convert response back to dict
resp_translator = DocumentEmbeddingsResponseTranslator()
response_data = resp_translator.from_pulsar(response)
response_data = resp_translator.encode(response)
# Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
@ -261,7 +261,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Convert response to dict
translator = DocumentEmbeddingsResponseTranslator()
response_data = translator.from_pulsar(response)
response_data = translator.encode(response)
# Verify error handling
assert isinstance(response_data, dict)


@ -212,10 +212,11 @@ class TestAgentMessageContracts:
# Test required fields
response = AgentResponse(**response_data)
assert hasattr(response, 'answer')
assert hasattr(response, 'chunk_type')
assert hasattr(response, 'content')
assert hasattr(response, 'end_of_message')
assert hasattr(response, 'end_of_dialog')
assert hasattr(response, 'error')
assert hasattr(response, 'thought')
assert hasattr(response, 'observation')
def test_agent_step_schema_contract(self):
"""Test AgentStep schema contract"""


@ -0,0 +1,177 @@
"""
Contract tests for orchestrator message schemas.
Verifies that AgentRequest/AgentStep with orchestration fields
serialise and deserialise correctly through the Pulsar schema layer.
"""
import pytest
import json
from trustgraph.schema import AgentRequest, AgentStep, PlanStep
@pytest.mark.contract
class TestOrchestrationFieldContracts:
"""Contract tests for orchestration fields on AgentRequest."""
def test_agent_request_orchestration_fields_roundtrip(self):
req = AgentRequest(
question="Test question",
user="testuser",
collection="default",
correlation_id="corr-123",
parent_session_id="parent-sess",
subagent_goal="What is X?",
expected_siblings=4,
pattern="react",
task_type="research",
framing="Focus on accuracy",
conversation_id="conv-456",
)
assert req.correlation_id == "corr-123"
assert req.parent_session_id == "parent-sess"
assert req.subagent_goal == "What is X?"
assert req.expected_siblings == 4
assert req.pattern == "react"
assert req.task_type == "research"
assert req.framing == "Focus on accuracy"
assert req.conversation_id == "conv-456"
def test_agent_request_orchestration_fields_default_empty(self):
req = AgentRequest(
question="Test question",
user="testuser",
)
assert req.correlation_id == ""
assert req.parent_session_id == ""
assert req.subagent_goal == ""
assert req.expected_siblings == 0
assert req.pattern == ""
assert req.task_type == ""
assert req.framing == ""
@pytest.mark.contract
class TestSubagentCompletionStepContract:
"""Contract tests for subagent-completion step type."""
def test_subagent_completion_step_fields(self):
step = AgentStep(
thought="Subagent completed",
action="complete",
arguments={},
observation="The answer text",
step_type="subagent-completion",
)
assert step.step_type == "subagent-completion"
assert step.observation == "The answer text"
assert step.thought == "Subagent completed"
assert step.action == "complete"
def test_subagent_completion_in_request_history(self):
step = AgentStep(
thought="Subagent completed",
action="complete",
arguments={},
observation="answer",
step_type="subagent-completion",
)
req = AgentRequest(
question="goal",
user="testuser",
correlation_id="corr-123",
history=[step],
)
assert len(req.history) == 1
assert req.history[0].step_type == "subagent-completion"
assert req.history[0].observation == "answer"
@pytest.mark.contract
class TestSynthesisStepContract:
"""Contract tests for synthesis step type with subagent_results."""
def test_synthesis_step_with_results(self):
results = {"goal-a": "answer-a", "goal-b": "answer-b"}
step = AgentStep(
thought="All subagents completed",
action="aggregate",
arguments={},
observation=json.dumps(results),
step_type="synthesise",
subagent_results=results,
)
assert step.step_type == "synthesise"
assert step.subagent_results == results
assert json.loads(step.observation) == results
def test_synthesis_request_matches_supervisor_expectations(self):
"""The synthesis request built by the aggregator must be
recognisable by SupervisorPattern._synthesise()."""
results = {"goal-a": "answer-a", "goal-b": "answer-b"}
step = AgentStep(
thought="All subagents completed",
action="aggregate",
arguments={},
observation=json.dumps(results),
step_type="synthesise",
subagent_results=results,
)
req = AgentRequest(
question="Original question",
user="testuser",
pattern="supervisor",
correlation_id="",
session_id="parent-sess",
history=[step],
)
# SupervisorPattern checks for step_type='synthesise' with
# subagent_results
has_results = bool(
req.history
and any(
getattr(h, 'step_type', '') == 'synthesise'
and getattr(h, 'subagent_results', None)
for h in req.history
)
)
assert has_results
# Pattern must be supervisor
assert req.pattern == "supervisor"
# Correlation ID must be empty (not re-intercepted)
assert req.correlation_id == ""
@pytest.mark.contract
class TestPlanStepContract:
"""Contract tests for plan steps in history."""
def test_plan_step_in_history(self):
plan = [
PlanStep(goal="Step 1", tool_hint="knowledge-query",
depends_on=[], status="completed", result="done"),
PlanStep(goal="Step 2", tool_hint="",
depends_on=[0], status="pending", result=""),
]
step = AgentStep(
thought="Created plan",
action="plan",
step_type="plan",
plan=plan,
)
assert step.step_type == "plan"
assert len(step.plan) == 2
assert step.plan[0].goal == "Step 1"
assert step.plan[0].status == "completed"
assert step.plan[1].depends_on == [0]


@ -0,0 +1,129 @@
"""
Contract tests for the provenance triple wire format. Verifies that
triples built by the provenance library can be parsed by the
explainability API through the wire-format conversion.
"""
import pytest
from trustgraph.schema import IRI, LITERAL
from trustgraph.provenance import (
agent_decomposition_triples,
agent_finding_triples,
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
)
from trustgraph.api.explainability import (
ExplainEntity,
Decomposition,
Finding,
Plan,
StepResult,
Synthesis,
wire_triples_to_tuples,
)
def _triples_to_wire(triples):
"""Convert provenance Triple objects to the wire format dicts
that the gateway/socket client would produce."""
wire = []
for t in triples:
entry = {
"s": _term_to_wire(t.s),
"p": _term_to_wire(t.p),
"o": _term_to_wire(t.o),
}
wire.append(entry)
return wire
def _term_to_wire(term):
"""Convert a Term to wire format dict."""
if term.type == IRI:
return {"t": "i", "i": term.iri}
elif term.type == LITERAL:
return {"t": "l", "v": term.value}
return {"t": "l", "v": str(term)}
def _roundtrip(triples, uri):
"""Convert triples through wire format and parse via from_triples."""
wire = _triples_to_wire(triples)
tuples = wire_triples_to_tuples(wire)
return ExplainEntity.from_triples(uri, tuples)
@pytest.mark.contract
class TestDecompositionWireFormat:
def test_roundtrip(self):
triples = agent_decomposition_triples(
"urn:decompose", "urn:session",
["What is X?", "What is Y?"],
)
entity = _roundtrip(triples, "urn:decompose")
assert isinstance(entity, Decomposition)
assert set(entity.goals) == {"What is X?", "What is Y?"}
@pytest.mark.contract
class TestFindingWireFormat:
def test_roundtrip(self):
triples = agent_finding_triples(
"urn:finding", "urn:decompose", "What is X?",
document_id="urn:doc/finding",
)
entity = _roundtrip(triples, "urn:finding")
assert isinstance(entity, Finding)
assert entity.goal == "What is X?"
assert entity.document == "urn:doc/finding"
@pytest.mark.contract
class TestPlanWireFormat:
def test_roundtrip(self):
triples = agent_plan_triples(
"urn:plan", "urn:session",
["Step 1", "Step 2", "Step 3"],
)
entity = _roundtrip(triples, "urn:plan")
assert isinstance(entity, Plan)
assert set(entity.steps) == {"Step 1", "Step 2", "Step 3"}
@pytest.mark.contract
class TestStepResultWireFormat:
def test_roundtrip(self):
triples = agent_step_result_triples(
"urn:step", "urn:plan", "Define X",
document_id="urn:doc/step",
)
entity = _roundtrip(triples, "urn:step")
assert isinstance(entity, StepResult)
assert entity.step == "Define X"
assert entity.document == "urn:doc/step"
@pytest.mark.contract
class TestSynthesisWireFormat:
def test_roundtrip(self):
triples = agent_synthesis_triples(
"urn:synthesis", "urn:previous",
document_id="urn:doc/synthesis",
)
entity = _roundtrip(triples, "urn:synthesis")
assert isinstance(entity, Synthesis)
assert entity.document == "urn:doc/synthesis"


@ -33,7 +33,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_session=True"
@ -57,7 +57,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_session=False"
@@ -80,7 +80,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False
@@ -103,7 +103,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
@@ -125,7 +125,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_session=True"
@@ -147,7 +147,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
@@ -168,7 +168,7 @@ class TestRAGTranslatorCompletionFlags:
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_stream=False"
@@ -188,20 +188,18 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
answer="4",
error=None,
thought=None,
observation=None,
chunk_type="answer",
content="4",
end_of_message=True,
end_of_dialog=True
end_of_dialog=True,
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_dialog=True"
assert response_dict["answer"] == "4"
assert response_dict["content"] == "4"
assert response_dict["end_of_dialog"] is True
def test_agent_translator_is_final_with_end_of_dialog_false(self):
@@ -212,44 +210,20 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
answer=None,
error=None,
thought="I need to solve this.",
observation=None,
chunk_type="thought",
content="I need to solve this.",
end_of_message=True,
end_of_dialog=False
end_of_dialog=False,
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_dialog=False"
assert response_dict["thought"] == "I need to solve this."
assert response_dict["content"] == "I need to solve this."
assert response_dict["end_of_dialog"] is False
def test_agent_translator_is_final_fallback_with_answer(self):
"""
Test that AgentResponseTranslator returns is_final=True
when answer is present (fallback for legacy responses).
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
# Legacy response without end_of_dialog flag
response = AgentResponse(
answer="4",
error=None,
thought=None,
observation=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when answer is present (legacy fallback)"
assert response_dict["answer"] == "4"
def test_agent_translator_intermediate_message_is_not_final(self):
"""
Test that intermediate messages (thought/observation) return is_final=False.
@@ -259,32 +233,28 @@ class TestAgentTranslatorCompletionFlags:
# Test thought message
thought_response = AgentResponse(
answer=None,
error=None,
thought="Processing...",
observation=None,
chunk_type="thought",
content="Processing...",
end_of_message=True,
end_of_dialog=False
end_of_dialog=False,
)
# Act
thought_dict, thought_is_final = translator.from_response_with_completion(thought_response)
thought_dict, thought_is_final = translator.encode_with_completion(thought_response)
# Assert
assert thought_is_final is False, "Thought message must not be final"
# Test observation message
observation_response = AgentResponse(
answer=None,
error=None,
thought=None,
observation="Result found",
chunk_type="observation",
content="Result found",
end_of_message=True,
end_of_dialog=False
end_of_dialog=False,
)
# Act
obs_dict, obs_is_final = translator.from_response_with_completion(observation_response)
obs_dict, obs_is_final = translator.encode_with_completion(observation_response)
# Assert
assert obs_is_final is False, "Observation message must not be final"
@@ -302,14 +272,10 @@ class TestAgentTranslatorCompletionFlags:
content="",
end_of_message=True,
end_of_dialog=True,
answer=None,
error=None,
thought=None,
observation=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
response_dict, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True, "Streaming format must use end_of_dialog for is_final"
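The renamed `encode_with_completion` contract these tests exercise can be sketched as follows. This is an illustrative stand-in, not the TrustGraph implementation; the `AgentResponse` dataclass here is a hypothetical simplification of the real schema class.

```python
from dataclasses import dataclass


@dataclass
class AgentResponse:
    # Hypothetical stand-in for the streaming-format schema class
    chunk_type: str = ""
    content: str = ""
    end_of_message: bool = False
    end_of_dialog: bool = False


class AgentResponseTranslator:
    def encode_with_completion(self, response):
        # is_final is driven solely by end_of_dialog, so intermediate
        # thought/observation chunks are never treated as final, and
        # end_of_stream / end_of_message never flip it on their own.
        response_dict = {
            "chunk_type": response.chunk_type,
            "content": response.content,
            "end_of_message": response.end_of_message,
            "end_of_dialog": response.end_of_dialog,
        }
        return response_dict, response.end_of_dialog
```

This mirrors the assertions above: an `answer` chunk with `end_of_dialog=True` yields `is_final=True`, while `thought` and `observation` chunks with `end_of_dialog=False` do not.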


@@ -9,7 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, ANY, patch
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
@@ -187,7 +187,7 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
# Verify tool was executed
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="default", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
@@ -272,7 +272,7 @@ Args: {{
# Verify correct service was called
if tool_name == "knowledge_query":
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default")
mock_flow_context("graph-rag-request").rag.assert_called_with("test question", collection="default", explain_callback=ANY, parent_uri=ANY)
elif tool_name == "text_completion":
mock_flow_context("prompt-request").question.assert_called()
@@ -726,7 +726,7 @@ Final Answer: {
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default")
graph_rag_client.rag.assert_called_once_with("What is AI?", collection="default", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_knowledge_query_with_custom_collection(self, mock_flow_context):
@@ -739,7 +739,7 @@ Final Answer: {
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection")
graph_rag_client.rag.assert_called_once_with("What is machine learning?", collection="custom_collection", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_knowledge_query_with_none_collection(self, mock_flow_context):
@@ -752,7 +752,7 @@ Final Answer: {
# Assert
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default")
graph_rag_client.rag.assert_called_once_with("Explain neural networks", collection="default", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
@@ -810,7 +810,7 @@ Args: {
# Verify the custom collection was used
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers")
graph_rag_client.rag.assert_called_once_with("Latest AI research?", collection="research_papers", explain_callback=ANY, parent_uri=ANY)
@pytest.mark.asyncio
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
@@ -840,4 +840,4 @@ Args: {
# Verify correct collection was used
graph_rag_client = mock_flow_context("graph-rag-request")
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection)
graph_rag_client.rag.assert_called_once_with(question, collection=expected_collection, explain_callback=ANY, parent_uri=ANY)


@@ -39,7 +39,7 @@ class TestAgentServiceNonStreaming:
mock_agent_manager_class.return_value = mock_agent_instance
# Mock react to call think and observe callbacks
async def mock_react(question, history, think, observe, answer, context, streaming):
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
await think("I need to solve this.", is_final=True)
await observe("The answer is 4.", is_final=True)
return Final(thought="Final answer", final="4")
@@ -76,22 +76,33 @@ class TestAgentServiceNonStreaming:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 3 responses (thought, observation, answer)
assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}"
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
]
# Should have explain events for session, iteration, observation, and final
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
# Should have 3 content responses (thought, observation, answer)
assert len(content_responses) == 3, f"Expected 3 content responses, got {len(content_responses)}"
# Check thought message
thought_response = sent_responses[0]
thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse)
assert thought_response.thought == "I need to solve this."
assert thought_response.answer is None
assert thought_response.chunk_type == "thought"
assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
# Check observation message
observation_response = sent_responses[1]
observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse)
assert observation_response.observation == "The answer is 4."
assert observation_response.answer is None
assert observation_response.chunk_type == "observation"
assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@@ -120,7 +131,7 @@ class TestAgentServiceNonStreaming:
mock_agent_manager_class.return_value = mock_agent_instance
# Mock react to return Final directly
async def mock_react(question, history, think, observe, answer, context, streaming):
async def mock_react(question, history, think, observe, answer, context, streaming, on_action=None):
return Final(thought="Final answer", final="4")
mock_agent_instance.react = mock_react
@@ -155,15 +166,25 @@ class TestAgentServiceNonStreaming:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: should have 1 response (final answer)
assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}"
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
]
# Should have explain events for session and final
assert len(explain_responses) >= 1, "Expected at least 1 explain event"
# Should have 1 content response (final answer)
assert len(content_responses) == 1, f"Expected 1 content response, got {len(content_responses)}"
# Check final answer message
answer_response = sent_responses[0]
answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse)
assert answer_response.answer == "4"
assert answer_response.thought is None
assert answer_response.observation is None
assert answer_response.chunk_type == "answer"
assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"


@@ -0,0 +1,216 @@
"""
Unit tests for the Aggregator, which tracks fan-out correlations and triggers
synthesis when all subagents complete.
"""
import time
import pytest
from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(question="Test question", user="testuser",
collection="default", streaming=False,
session_id="parent-session", task_type="research",
framing="test framing", conversation_id="conv-1"):
return AgentRequest(
question=question,
user=user,
collection=collection,
streaming=streaming,
session_id=session_id,
task_type=task_type,
framing=framing,
conversation_id=conversation_id,
)
class TestRegisterFanout:
def test_stores_correlation_entry(self):
agg = Aggregator()
agg.register_fanout("corr-1", "parent-1", 3)
assert "corr-1" in agg.correlations
entry = agg.correlations["corr-1"]
assert entry["parent_session_id"] == "parent-1"
assert entry["expected"] == 3
assert entry["results"] == {}
def test_stores_request_template(self):
agg = Aggregator()
template = _make_request()
agg.register_fanout("corr-1", "parent-1", 2,
request_template=template)
entry = agg.correlations["corr-1"]
assert entry["request_template"] is template
def test_records_creation_time(self):
agg = Aggregator()
before = time.time()
agg.register_fanout("corr-1", "parent-1", 2)
after = time.time()
created = agg.correlations["corr-1"]["created_at"]
assert before <= created <= after
class TestRecordCompletion:
def test_returns_false_until_all_done(self):
agg = Aggregator()
agg.register_fanout("corr-1", "parent-1", 3)
assert agg.record_completion("corr-1", "goal-a", "answer-a") is False
assert agg.record_completion("corr-1", "goal-b", "answer-b") is False
assert agg.record_completion("corr-1", "goal-c", "answer-c") is True
def test_returns_none_for_unknown_correlation(self):
agg = Aggregator()
result = agg.record_completion("unknown", "goal", "answer")
assert result is None
def test_stores_results_by_goal(self):
agg = Aggregator()
agg.register_fanout("corr-1", "parent-1", 2)
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.record_completion("corr-1", "goal-b", "answer-b")
results = agg.correlations["corr-1"]["results"]
assert results["goal-a"] == "answer-a"
assert results["goal-b"] == "answer-b"
def test_single_subagent(self):
agg = Aggregator()
agg.register_fanout("corr-1", "parent-1", 1)
assert agg.record_completion("corr-1", "goal-a", "answer") is True
class TestGetOriginalRequest:
def test_peeks_without_consuming(self):
agg = Aggregator()
template = _make_request()
agg.register_fanout("corr-1", "parent-1", 2,
request_template=template)
result = agg.get_original_request("corr-1")
assert result is template
# Entry still exists
assert "corr-1" in agg.correlations
def test_returns_none_for_unknown(self):
agg = Aggregator()
assert agg.get_original_request("unknown") is None
class TestBuildSynthesisRequest:
def test_builds_correct_request(self):
agg = Aggregator()
template = _make_request(
question="Original question",
streaming=True,
task_type="risk-assessment",
framing="Assess risks",
)
agg.register_fanout("corr-1", "parent-1", 2,
request_template=template)
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.record_completion("corr-1", "goal-b", "answer-b")
req = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
assert req.question == "Original question"
assert req.pattern == "supervisor"
assert req.session_id == "parent-1"
assert req.correlation_id == "" # Must be empty
assert req.streaming is True
assert req.task_type == "risk-assessment"
assert req.framing == "Assess risks"
def test_synthesis_step_in_history(self):
agg = Aggregator()
template = _make_request()
agg.register_fanout("corr-1", "parent-1", 2,
request_template=template)
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.record_completion("corr-1", "goal-b", "answer-b")
req = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
)
# Last history step should be the synthesis step
assert len(req.history) >= 1
synth_step = req.history[-1]
assert synth_step.step_type == "synthesise"
assert synth_step.subagent_results == {
"goal-a": "answer-a",
"goal-b": "answer-b",
}
def test_consumes_correlation_entry(self):
agg = Aggregator()
template = _make_request()
agg.register_fanout("corr-1", "parent-1", 1,
request_template=template)
agg.record_completion("corr-1", "goal-a", "answer-a")
agg.build_synthesis_request(
"corr-1", "question", "user", "default",
)
# Entry should be removed
assert "corr-1" not in agg.correlations
def test_raises_for_unknown_correlation(self):
agg = Aggregator()
with pytest.raises(RuntimeError, match="No results"):
agg.build_synthesis_request(
"unknown", "question", "user", "default",
)
class TestCleanupStale:
def test_removes_entries_older_than_timeout(self):
agg = Aggregator(timeout=1)
agg.register_fanout("corr-1", "parent-1", 2)
# Backdate the creation time
agg.correlations["corr-1"]["created_at"] = time.time() - 2
stale = agg.cleanup_stale()
assert "corr-1" in stale
assert "corr-1" not in agg.correlations
def test_keeps_recent_entries(self):
agg = Aggregator(timeout=300)
agg.register_fanout("corr-1", "parent-1", 2)
stale = agg.cleanup_stale()
assert stale == []
assert "corr-1" in agg.correlations
def test_mixed_stale_and_fresh(self):
agg = Aggregator(timeout=1)
agg.register_fanout("stale", "parent-1", 2)
agg.register_fanout("fresh", "parent-2", 2)
agg.correlations["stale"]["created_at"] = time.time() - 2
stale = agg.cleanup_stale()
assert "stale" in stale
assert "stale" not in agg.correlations
assert "fresh" in agg.correlations
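A minimal Aggregator consistent with the behaviour these tests pin down might look like the sketch below. It is not the TrustGraph implementation; `SimpleNamespace` stands in for the real `AgentRequest`/`AgentStep` schema classes, and the `RuntimeError` message is only assumed to start with "No results".

```python
import time
from types import SimpleNamespace


class Aggregator:
    """Sketch: track fan-out correlations, trigger synthesis when done."""

    def __init__(self, timeout=300):
        self.timeout = timeout
        self.correlations = {}

    def register_fanout(self, correlation_id, parent_session_id,
                        expected, request_template=None):
        self.correlations[correlation_id] = {
            "parent_session_id": parent_session_id,
            "expected": expected,
            "results": {},
            "request_template": request_template,
            "created_at": time.time(),
        }

    def record_completion(self, correlation_id, goal, answer):
        entry = self.correlations.get(correlation_id)
        if entry is None:
            return None  # unknown correlation
        entry["results"][goal] = answer
        return len(entry["results"]) >= entry["expected"]

    def get_original_request(self, correlation_id):
        # Peek without consuming the entry
        entry = self.correlations.get(correlation_id)
        return entry["request_template"] if entry else None

    def build_synthesis_request(self, correlation_id, original_question,
                                user, collection):
        # Consumes the correlation entry
        entry = self.correlations.pop(correlation_id, None)
        if entry is None:
            raise RuntimeError("No results for correlation")
        template = entry["request_template"]
        step = SimpleNamespace(step_type="synthesise",
                               subagent_results=dict(entry["results"]))
        return SimpleNamespace(
            question=original_question,
            user=user,
            collection=collection,
            pattern="supervisor",
            session_id=entry["parent_session_id"],
            correlation_id="",  # empty so it is not re-intercepted
            streaming=getattr(template, "streaming", False),
            task_type=getattr(template, "task_type", ""),
            framing=getattr(template, "framing", ""),
            history=[step],
        )

    def cleanup_stale(self):
        now = time.time()
        stale = [cid for cid, e in self.correlations.items()
                 if now - e["created_at"] > self.timeout]
        for cid in stale:
            del self.correlations[cid]
        return stale
```

The empty `correlation_id` on the synthesis request is the key design point: it prevents the supervisor's own synthesis pass from being intercepted as another subagent completion.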


@@ -0,0 +1,122 @@
"""
Tests that streaming callbacks set message_id on AgentResponse.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.pattern_base import PatternBase
from trustgraph.schema import AgentResponse
@pytest.fixture
def pattern():
processor = MagicMock()
return PatternBase(processor)
class TestThinkCallbackMessageId:
@pytest.mark.asyncio
async def test_streaming_think_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/i1/thought"
think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id)
await think("hello", is_final=False)
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "thought"
@pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/i1/thought"
think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id)
await think("hello")
assert responses[0].message_id == msg_id
assert responses[0].end_of_message is True
class TestObserveCallbackMessageId:
@pytest.mark.asyncio
async def test_streaming_observe_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/i1/observation"
observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id)
await observe("result", is_final=True)
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "observation"
class TestAnswerCallbackMessageId:
@pytest.mark.asyncio
async def test_streaming_answer_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/final"
answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id)
await answer("the answer")
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "answer"
@pytest.mark.asyncio
async def test_no_message_id_default(self, pattern):
responses = []
async def capture(r):
responses.append(r)
answer = pattern.make_answer_callback(capture, streaming=True)
await answer("the answer")
assert responses[0].message_id == ""
class TestSendFinalResponseMessageId:
@pytest.mark.asyncio
async def test_streaming_final_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/final"
await pattern.send_final_response(
capture, streaming=True, answer_text="answer",
message_id=msg_id,
)
# Should get content chunk + end-of-dialog marker
assert all(r.message_id == msg_id for r in responses)
@pytest.mark.asyncio
async def test_non_streaming_final_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/final"
await pattern.send_final_response(
capture, streaming=False, answer_text="answer",
message_id=msg_id,
)
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].end_of_dialog is True


@@ -0,0 +1,174 @@
"""
Unit tests for completion dispatch, verifying that agent_request() in the
orchestrator service correctly intercepts subagent completion messages and
routes them to _handle_subagent_completion.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.schema import AgentRequest, AgentStep
from trustgraph.agent.orchestrator.aggregator import Aggregator
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)
return AgentRequest(**defaults)
def _make_completion_request(correlation_id, goal, answer):
"""Build a completion request as emit_subagent_completion would."""
step = AgentStep(
thought="Subagent completed",
action="complete",
arguments={},
observation=answer,
step_type="subagent-completion",
)
return _make_request(
correlation_id=correlation_id,
parent_session_id="parent-sess",
subagent_goal=goal,
expected_siblings=2,
history=[step],
)
class TestCompletionDetection:
"""Test that completion messages are correctly identified."""
def test_is_completion_when_correlation_id_and_step_type(self):
req = _make_completion_request("corr-1", "goal-a", "answer-a")
has_correlation = bool(getattr(req, 'correlation_id', ''))
is_completion = any(
getattr(h, 'step_type', '') == 'subagent-completion'
for h in req.history
)
assert has_correlation
assert is_completion
def test_not_completion_without_correlation_id(self):
step = AgentStep(
step_type="subagent-completion",
observation="answer",
)
req = _make_request(
correlation_id="",
history=[step],
)
has_correlation = bool(getattr(req, 'correlation_id', ''))
assert not has_correlation
def test_not_completion_without_step_type(self):
step = AgentStep(
step_type="react",
observation="answer",
)
req = _make_request(
correlation_id="corr-1",
history=[step],
)
is_completion = any(
getattr(h, 'step_type', '') == 'subagent-completion'
for h in req.history
)
assert not is_completion
def test_not_completion_with_empty_history(self):
req = _make_request(
correlation_id="corr-1",
history=[],
)
assert not req.history
class TestAggregatorIntegration:
"""Test the aggregator flow as used by _handle_subagent_completion."""
def test_full_completion_flow(self):
"""Simulates the flow: register, record completions, build synthesis."""
agg = Aggregator()
template = _make_request(
question="Original question",
streaming=True,
task_type="risk-assessment",
framing="Assess risks",
session_id="parent-sess",
)
# Register fan-out
agg.register_fanout("corr-1", "parent-sess", 2,
request_template=template)
# First completion — not all done
all_done = agg.record_completion(
"corr-1", "goal-a", "answer-a",
)
assert all_done is False
# Second completion — all done
all_done = agg.record_completion(
"corr-1", "goal-b", "answer-b",
)
assert all_done is True
# Peek at template
peeked = agg.get_original_request("corr-1")
assert peeked.question == "Original question"
# Build synthesis request
synth = agg.build_synthesis_request(
"corr-1",
original_question="Original question",
user="testuser",
collection="default",
)
# Verify synthesis request
assert synth.pattern == "supervisor"
assert synth.correlation_id == ""
assert synth.session_id == "parent-sess"
assert synth.streaming is True
# Verify synthesis history has results
synth_steps = [
s for s in synth.history
if getattr(s, 'step_type', '') == 'synthesise'
]
assert len(synth_steps) == 1
assert synth_steps[0].subagent_results == {
"goal-a": "answer-a",
"goal-b": "answer-b",
}
def test_synthesis_request_not_detected_as_completion(self):
"""The synthesis request must not be intercepted as a completion."""
agg = Aggregator()
template = _make_request(session_id="parent-sess")
agg.register_fanout("corr-1", "parent-sess", 1,
request_template=template)
agg.record_completion("corr-1", "goal", "answer")
synth = agg.build_synthesis_request(
"corr-1", "question", "user", "default",
)
# correlation_id must be empty so it's not intercepted
assert synth.correlation_id == ""
# Even if we check for completion step, shouldn't match
is_completion = any(
getattr(h, 'step_type', '') == 'subagent-completion'
for h in synth.history
)
assert not is_completion


@@ -0,0 +1,177 @@
"""
Unit tests for explainability API parsing, verifying that from_triples()
correctly dispatches and parses the new orchestrator entity types.
"""
import pytest
from trustgraph.api.explainability import (
ExplainEntity,
Decomposition,
Finding,
Plan,
StepResult,
Synthesis,
Analysis,
Observation,
Conclusion,
TG_DECOMPOSITION,
TG_FINDING,
TG_PLAN_TYPE,
TG_STEP_RESULT,
TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_OBSERVATION_TYPE,
TG_TOOL_USE,
TG_ANALYSIS,
TG_CONCLUSION,
TG_DOCUMENT,
TG_SUBAGENT_GOAL,
TG_PLAN_STEP,
RDF_TYPE,
)
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"
def _make_triples(uri, types, extras=None):
"""Build a list of (s, p, o) tuples for testing."""
triples = [(uri, RDF_TYPE, t) for t in types]
if extras:
triples.extend((uri, p, o) for p, o in extras)
return triples
class TestFromTriplesDispatch:
def test_dispatches_decomposition(self):
triples = _make_triples("urn:d", [PROV_ENTITY, TG_DECOMPOSITION])
entity = ExplainEntity.from_triples("urn:d", triples)
assert isinstance(entity, Decomposition)
def test_dispatches_finding(self):
triples = _make_triples("urn:f",
[PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
entity = ExplainEntity.from_triples("urn:f", triples)
assert isinstance(entity, Finding)
def test_dispatches_plan(self):
triples = _make_triples("urn:p", [PROV_ENTITY, TG_PLAN_TYPE])
entity = ExplainEntity.from_triples("urn:p", triples)
assert isinstance(entity, Plan)
def test_dispatches_step_result(self):
triples = _make_triples("urn:sr",
[PROV_ENTITY, TG_STEP_RESULT, TG_ANSWER_TYPE])
entity = ExplainEntity.from_triples("urn:sr", triples)
assert isinstance(entity, StepResult)
def test_dispatches_synthesis(self):
triples = _make_triples("urn:s",
[PROV_ENTITY, TG_SYNTHESIS, TG_ANSWER_TYPE])
entity = ExplainEntity.from_triples("urn:s", triples)
assert isinstance(entity, Synthesis)
def test_dispatches_analysis_unchanged(self):
triples = _make_triples("urn:a", [PROV_ENTITY, TG_ANALYSIS])
entity = ExplainEntity.from_triples("urn:a", triples)
assert isinstance(entity, Analysis)
def test_dispatches_analysis_with_tooluse(self):
"""Analysis+ToolUse mixin still dispatches to Analysis."""
triples = _make_triples("urn:a",
[PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE])
entity = ExplainEntity.from_triples("urn:a", triples)
assert isinstance(entity, Analysis)
def test_dispatches_observation(self):
triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE])
entity = ExplainEntity.from_triples("urn:o", triples)
assert isinstance(entity, Observation)
def test_dispatches_conclusion_unchanged(self):
triples = _make_triples("urn:c",
[PROV_ENTITY, TG_CONCLUSION, TG_ANSWER_TYPE])
entity = ExplainEntity.from_triples("urn:c", triples)
assert isinstance(entity, Conclusion)
def test_finding_takes_precedence_over_synthesis(self):
"""Finding has Answer mixin but should dispatch to Finding, not
Synthesis, because Finding is checked first."""
triples = _make_triples("urn:f",
[PROV_ENTITY, TG_FINDING, TG_ANSWER_TYPE])
entity = ExplainEntity.from_triples("urn:f", triples)
assert isinstance(entity, Finding)
assert not isinstance(entity, Synthesis)
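The precedence behaviour these dispatch tests rely on can be sketched as an ordered type check. The IRI values below are invented placeholders; the real `TG_*` constants live in `trustgraph.api.explainability`, and this returns a label rather than constructing an entity class.

```python
# Hypothetical placeholder IRIs -- the real constants differ
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
TG = "http://example.org/tg#"
TG_DECOMPOSITION = TG + "Decomposition"
TG_FINDING = TG + "Finding"
TG_PLAN_TYPE = TG + "Plan"
TG_STEP_RESULT = TG + "StepResult"
TG_SYNTHESIS = TG + "Synthesis"


def dispatch(uri, triples):
    """Sketch of from_triples() dispatch: collect rdf:type objects for
    the URI, then test entity types in precedence order. Finding is
    checked before Synthesis, so the shared Answer mixin cannot
    misroute a Finding to Synthesis."""
    types = {o for s, p, o in triples if s == uri and p == RDF_TYPE}
    order = [
        (TG_DECOMPOSITION, "decomposition"),
        (TG_FINDING, "finding"),
        (TG_PLAN_TYPE, "plan"),
        (TG_STEP_RESULT, "step-result"),
        (TG_SYNTHESIS, "synthesis"),
    ]
    for type_iri, entity_type in order:
        if type_iri in types:
            return entity_type
    return None
```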
class TestDecompositionParsing:
def test_parses_goals(self):
triples = _make_triples("urn:d", [TG_DECOMPOSITION], [
(TG_SUBAGENT_GOAL, "What is X?"),
(TG_SUBAGENT_GOAL, "What is Y?"),
])
entity = Decomposition.from_triples("urn:d", triples)
assert set(entity.goals) == {"What is X?", "What is Y?"}
def test_entity_type_field(self):
triples = _make_triples("urn:d", [TG_DECOMPOSITION])
entity = Decomposition.from_triples("urn:d", triples)
assert entity.entity_type == "decomposition"
def test_empty_goals(self):
triples = _make_triples("urn:d", [TG_DECOMPOSITION])
entity = Decomposition.from_triples("urn:d", triples)
assert entity.goals == []
class TestFindingParsing:
def test_parses_goal_and_document(self):
triples = _make_triples("urn:f", [TG_FINDING, TG_ANSWER_TYPE], [
(TG_SUBAGENT_GOAL, "What is X?"),
(TG_DOCUMENT, "urn:doc/finding"),
])
entity = Finding.from_triples("urn:f", triples)
assert entity.goal == "What is X?"
assert entity.document == "urn:doc/finding"
def test_entity_type_field(self):
triples = _make_triples("urn:f", [TG_FINDING])
entity = Finding.from_triples("urn:f", triples)
assert entity.entity_type == "finding"
class TestPlanParsing:
def test_parses_steps(self):
triples = _make_triples("urn:p", [TG_PLAN_TYPE], [
(TG_PLAN_STEP, "Define X"),
(TG_PLAN_STEP, "Research Y"),
(TG_PLAN_STEP, "Analyse Z"),
])
entity = Plan.from_triples("urn:p", triples)
assert set(entity.steps) == {"Define X", "Research Y", "Analyse Z"}
def test_entity_type_field(self):
triples = _make_triples("urn:p", [TG_PLAN_TYPE])
entity = Plan.from_triples("urn:p", triples)
assert entity.entity_type == "plan"
class TestStepResultParsing:
def test_parses_step_and_document(self):
triples = _make_triples("urn:sr", [TG_STEP_RESULT, TG_ANSWER_TYPE], [
(TG_PLAN_STEP, "Define X"),
(TG_DOCUMENT, "urn:doc/step"),
])
entity = StepResult.from_triples("urn:sr", triples)
assert entity.step == "Define X"
assert entity.document == "urn:doc/step"
def test_entity_type_field(self):
triples = _make_triples("urn:sr", [TG_STEP_RESULT])
entity = StepResult.from_triples("urn:sr", triples)
assert entity.entity_type == "step-result"


@@ -0,0 +1,289 @@
"""
Unit tests for MetaRouter task-type identification and pattern selection.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
def _make_config(patterns=None, task_types=None):
"""Build a config dict as the config service would provide."""
config = {}
if patterns:
config["agent-pattern"] = {
pid: json.dumps(pdata) for pid, pdata in patterns.items()
}
if task_types:
config["agent-task-type"] = {
tid: json.dumps(tdata) for tid, tdata in task_types.items()
}
return config
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
def context(service_name):
return client
return context
SAMPLE_PATTERNS = {
"react": {"name": "react", "description": "ReAct pattern"},
"plan-then-execute": {"name": "plan-then-execute", "description": "Plan pattern"},
"supervisor": {"name": "supervisor", "description": "Supervisor pattern"},
}
SAMPLE_TASK_TYPES = {
"general": {
"name": "general",
"description": "General queries",
"valid_patterns": ["react", "plan-then-execute", "supervisor"],
"framing": "",
},
"research": {
"name": "research",
"description": "Research queries",
"valid_patterns": ["react", "plan-then-execute"],
"framing": "Focus on gathering information.",
},
"summarisation": {
"name": "summarisation",
"description": "Summarisation queries",
"valid_patterns": ["react"],
"framing": "Focus on concise synthesis.",
},
}
class TestMetaRouterInit:
def test_defaults_when_no_config(self):
router = MetaRouter()
assert "react" in router.patterns
assert "general" in router.task_types
def test_loads_patterns_from_config(self):
config = _make_config(patterns=SAMPLE_PATTERNS)
router = MetaRouter(config=config)
assert set(router.patterns.keys()) == {"react", "plan-then-execute", "supervisor"}
def test_loads_task_types_from_config(self):
config = _make_config(task_types=SAMPLE_TASK_TYPES)
router = MetaRouter(config=config)
assert set(router.task_types.keys()) == {"general", "research", "summarisation"}
def test_handles_invalid_json_in_config(self):
config = {
"agent-pattern": {"react": "not valid json"},
}
router = MetaRouter(config=config)
assert "react" in router.patterns
assert router.patterns["react"]["name"] == "react"
class TestIdentifyTaskType:
@pytest.mark.asyncio
async def test_skips_llm_when_single_task_type(self):
router = MetaRouter() # Only "general"
context = _make_context("should not be called")
task_type, framing = await router.identify_task_type(
"test question", context,
)
assert task_type == "general"
@pytest.mark.asyncio
async def test_uses_llm_when_multiple_task_types(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
context = _make_context("research")
task_type, framing = await router.identify_task_type(
"Research the topic", context,
)
assert task_type == "research"
assert framing == "Focus on gathering information."
@pytest.mark.asyncio
async def test_handles_llm_returning_quoted_type(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
context = _make_context('"summarisation"')
task_type, _ = await router.identify_task_type(
"Summarise this", context,
)
assert task_type == "summarisation"
@pytest.mark.asyncio
async def test_falls_back_on_unknown_type(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
context = _make_context("nonexistent-type")
task_type, _ = await router.identify_task_type(
"test question", context,
)
assert task_type == DEFAULT_TASK_TYPE
@pytest.mark.asyncio
async def test_falls_back_on_llm_error(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
client = AsyncMock()
client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
context = lambda name: client
task_type, _ = await router.identify_task_type(
"test question", context,
)
assert task_type == DEFAULT_TASK_TYPE
class TestSelectPattern:
@pytest.mark.asyncio
async def test_skips_llm_when_single_valid_pattern(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
context = _make_context("should not be called")
# summarisation only has ["react"]
pattern = await router.select_pattern(
"Summarise this", "summarisation", context,
)
assert pattern == "react"
@pytest.mark.asyncio
async def test_uses_llm_when_multiple_valid_patterns(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
context = _make_context("plan-then-execute")
# research has ["react", "plan-then-execute"]
pattern = await router.select_pattern(
"Research this", "research", context,
)
assert pattern == "plan-then-execute"
@pytest.mark.asyncio
async def test_respects_valid_patterns_constraint(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
# LLM returns supervisor, but research doesn't allow it
context = _make_context("supervisor")
pattern = await router.select_pattern(
"Research this", "research", context,
)
# Should fall back to first valid pattern
assert pattern == "react"
@pytest.mark.asyncio
async def test_falls_back_on_llm_error(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
client = AsyncMock()
client.prompt = AsyncMock(side_effect=RuntimeError("LLM down"))
context = lambda name: client
# general has ["react", "plan-then-execute", "supervisor"]
pattern = await router.select_pattern(
"test", "general", context,
)
# Falls back to first valid pattern
assert pattern == "react"
@pytest.mark.asyncio
async def test_falls_back_to_default_for_unknown_task_type(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
context = _make_context("react")
# Unknown task type — valid_patterns falls back to all patterns
pattern = await router.select_pattern(
"test", "unknown-type", context,
)
assert pattern == "react"
class TestRoute:
@pytest.mark.asyncio
async def test_full_routing_pipeline(self):
config = _make_config(
patterns=SAMPLE_PATTERNS,
task_types=SAMPLE_TASK_TYPES,
)
router = MetaRouter(config=config)
# Mock context where prompt returns different values per call
client = AsyncMock()
call_count = 0
async def mock_prompt(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return "research" # task type
return "plan-then-execute" # pattern
client.prompt = mock_prompt
context = lambda name: client
pattern, task_type, framing = await router.route(
"Research the relationships", context,
)
assert task_type == "research"
assert pattern == "plan-then-execute"
assert framing == "Focus on gathering information."
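The routing behaviour these tests pin down, skip the LLM when only one option exists, validate the LLM's answer (including quoted answers) against the known set, and fall back to a default on errors, can be sketched as a standalone helper. The `choose` function and `classify` callable are illustrative stand-ins, not the actual MetaRouter internals:

```python
import asyncio

# Sketch of the single-choice / validate / fall-back step the tests
# above exercise. `options` maps an id to its config entry; `classify`
# stands in for the LLM prompt call (hypothetical name).
async def choose(question, options, classify, default):
    ids = list(options)
    if len(ids) == 1:
        return ids[0]              # skip the LLM when there is no choice
    try:
        answer = (await classify(question, ids)).strip().strip('"')
    except Exception:
        return default             # LLM failure: fall back
    return answer if answer in options else default

async def demo():
    opts = {"react": {}, "plan-then-execute": {}}
    async def llm(question, ids):
        return '"plan-then-execute"'   # quoted answers are stripped
    return await choose("Research X", opts, llm, "react")

result = asyncio.run(demo())
```

The same shape covers both stages: `default` is DEFAULT_TASK_TYPE for task-type identification and the first valid pattern for pattern selection.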


@ -0,0 +1,132 @@
"""
Tests that the on_action callback in react() fires after action
selection but before tool execution.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.types import Action, Final, Tool, Argument
class TestOnActionCallback:
@pytest.mark.asyncio
async def test_on_action_called_for_tool_use(self):
"""on_action fires when react() selects a tool (not Final)."""
call_log = []
async def fake_on_action(act):
call_log.append(("on_action", act.name))
# Tool that records when it's invoked
async def tool_invoke(**kwargs):
call_log.append(("tool_invoke",))
return "tool result"
tool_impl = MagicMock()
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
tools = {
"search": Tool(
name="search",
description="Search",
implementation=tool_impl,
arguments=[Argument(name="query", type="string", description="q")],
config={},
),
}
agent = AgentManager(tools=tools)
# Mock reason() to return an Action
action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="")
agent.reason = AsyncMock(return_value=action)
think = AsyncMock()
observe = AsyncMock()
context = MagicMock()
await agent.react(
question="test",
history=[],
think=think,
observe=observe,
context=context,
on_action=fake_on_action,
)
# on_action should fire before tool_invoke
assert len(call_log) == 2
assert call_log[0] == ("on_action", "search")
assert call_log[1] == ("tool_invoke",)
@pytest.mark.asyncio
async def test_on_action_not_called_for_final(self):
"""on_action does not fire when react() returns Final."""
called = []
async def fake_on_action(act):
called.append(act)
agent = AgentManager(tools={})
agent.reason = AsyncMock(
return_value=Final(thought="done", final="answer")
)
think = AsyncMock()
observe = AsyncMock()
context = MagicMock()
result = await agent.react(
question="test",
history=[],
think=think,
observe=observe,
context=context,
on_action=fake_on_action,
)
assert isinstance(result, Final)
assert len(called) == 0
@pytest.mark.asyncio
async def test_on_action_none_accepted(self):
"""react() works fine when on_action is None (default)."""
async def tool_invoke(**kwargs):
return "result"
tool_impl = MagicMock()
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
tools = {
"search": Tool(
name="search",
description="Search",
implementation=tool_impl,
arguments=[],
config={},
),
}
agent = AgentManager(tools=tools)
agent.reason = AsyncMock(
return_value=Action(thought="t", name="search", arguments={}, observation="")
)
think = AsyncMock()
observe = AsyncMock()
context = MagicMock()
result = await agent.react(
question="test",
history=[],
think=think,
observe=observe,
context=context,
# on_action not passed — defaults to None
)
assert isinstance(result, Action)
assert result.observation == "result"
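The ordering asserted above (reason, then on_action, then tool invocation, with Final short-circuiting the hook entirely) can be sketched as a minimal loop body. All names here are illustrative stand-ins for the AgentManager internals:

```python
import asyncio
from dataclasses import dataclass

# Minimal sketch of the control flow under test: the optional
# on_action hook fires after an Action is chosen but before the
# tool runs, and never fires for a Final answer.
@dataclass
class Final:
    final: str

@dataclass
class Action:
    name: str
    arguments: dict
    observation: str = ""

async def react(reason, tools, on_action=None):
    step = await reason()
    if isinstance(step, Final):
        return step                    # on_action skipped for Final
    if on_action is not None:
        await on_action(step)          # fires before tool execution
    step.observation = await tools[step.name](**step.arguments)
    return step

log = []
async def demo():
    async def reason():
        return Action("search", {"query": "x"})
    async def search(query):
        log.append("tool")
        return "result"
    async def hook(act):
        log.append("hook")
    return await react(reason, {"search": search}, hook)

result = asyncio.run(demo())
```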


@ -0,0 +1,74 @@
"""
Tests that _parse_chunk propagates message_id from wire format
to AgentThought, AgentObservation, and AgentAnswer.
"""
import pytest
from trustgraph.api.socket_client import SocketClient
from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer
@pytest.fixture
def client():
# We only need _parse_chunk — don't connect
c = object.__new__(SocketClient)
return c
class TestParseChunkMessageId:
def test_thought_message_id(self, client):
resp = {
"chunk_type": "thought",
"content": "thinking...",
"end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought",
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentThought)
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought"
def test_observation_message_id(self, client):
resp = {
"chunk_type": "observation",
"content": "result",
"end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation",
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentObservation)
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation"
def test_answer_message_id(self, client):
resp = {
"chunk_type": "answer",
"content": "the answer",
"end_of_message": False,
"end_of_dialog": False,
"message_id": "urn:trustgraph:agent:sess/final",
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentAnswer)
assert chunk.message_id == "urn:trustgraph:agent:sess/final"
def test_thought_missing_message_id(self, client):
resp = {
"chunk_type": "thought",
"content": "thinking...",
"end_of_message": False,
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentThought)
assert chunk.message_id == ""
def test_answer_missing_message_id(self, client):
resp = {
"chunk_type": "answer",
"content": "answer",
"end_of_message": True,
"end_of_dialog": True,
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentAnswer)
assert chunk.message_id == ""
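The parsing behaviour these tests pin down (dispatch on `chunk_type`, default `message_id` to the empty string when the wire message omits it) could be implemented along these lines. The dataclasses and field sets are assumptions drawn only from the fields the tests read, not the actual `trustgraph.api.types` definitions:

```python
from dataclasses import dataclass

# Sketch of a chunk parser with the tested properties: dispatch on
# "chunk_type" and default "message_id" to "". Field sets are
# illustrative, not the real type definitions.
@dataclass
class AgentThought:
    content: str
    message_id: str = ""

@dataclass
class AgentAnswer:
    content: str
    message_id: str = ""

_TYPES = {"thought": AgentThought, "answer": AgentAnswer}

def parse_chunk(resp):
    cls = _TYPES[resp["chunk_type"]]
    return cls(
        content=resp["content"],
        message_id=resp.get("message_id", ""),  # absent -> ""
    )

chunk = parse_chunk({"chunk_type": "thought", "content": "hm"})
```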


@ -0,0 +1,144 @@
"""
Unit tests for the PatternBase subagent helpers: is_subagent() and
emit_subagent_completion().
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from dataclasses import dataclass
from trustgraph.schema import AgentRequest
from trustgraph.agent.orchestrator.pattern_base import PatternBase
@dataclass
class MockProcessor:
"""Minimal processor mock for PatternBase."""
pass
def _make_request(**kwargs):
defaults = dict(
question="Test question",
user="testuser",
collection="default",
)
defaults.update(kwargs)
return AgentRequest(**defaults)
def _make_pattern():
return PatternBase(MockProcessor())
class TestIsSubagent:
def test_returns_true_when_correlation_id_set(self):
pattern = _make_pattern()
request = _make_request(correlation_id="corr-123")
assert pattern.is_subagent(request) is True
def test_returns_false_when_correlation_id_empty(self):
pattern = _make_pattern()
request = _make_request(correlation_id="")
assert pattern.is_subagent(request) is False
def test_returns_false_when_correlation_id_missing(self):
pattern = _make_pattern()
request = _make_request()
assert pattern.is_subagent(request) is False
class TestEmitSubagentCompletion:
@pytest.mark.asyncio
async def test_calls_next_with_completion_request(self):
pattern = _make_pattern()
request = _make_request(
correlation_id="corr-123",
parent_session_id="parent-sess",
subagent_goal="What is X?",
expected_siblings=4,
)
next_fn = AsyncMock()
await pattern.emit_subagent_completion(
request, next_fn, "The answer is Y",
)
next_fn.assert_called_once()
completion_req = next_fn.call_args[0][0]
assert isinstance(completion_req, AgentRequest)
@pytest.mark.asyncio
async def test_completion_has_correct_step_type(self):
pattern = _make_pattern()
request = _make_request(
correlation_id="corr-123",
subagent_goal="What is X?",
)
next_fn = AsyncMock()
await pattern.emit_subagent_completion(
request, next_fn, "answer text",
)
completion_req = next_fn.call_args[0][0]
assert len(completion_req.history) == 1
step = completion_req.history[0]
assert step.step_type == "subagent-completion"
@pytest.mark.asyncio
async def test_completion_carries_answer_in_observation(self):
pattern = _make_pattern()
request = _make_request(
correlation_id="corr-123",
subagent_goal="What is X?",
)
next_fn = AsyncMock()
await pattern.emit_subagent_completion(
request, next_fn, "The answer is Y",
)
completion_req = next_fn.call_args[0][0]
step = completion_req.history[0]
assert step.observation == "The answer is Y"
@pytest.mark.asyncio
async def test_completion_preserves_correlation_fields(self):
pattern = _make_pattern()
request = _make_request(
correlation_id="corr-123",
parent_session_id="parent-sess",
subagent_goal="What is X?",
expected_siblings=4,
)
next_fn = AsyncMock()
await pattern.emit_subagent_completion(
request, next_fn, "answer",
)
completion_req = next_fn.call_args[0][0]
assert completion_req.correlation_id == "corr-123"
assert completion_req.parent_session_id == "parent-sess"
assert completion_req.subagent_goal == "What is X?"
assert completion_req.expected_siblings == 4
@pytest.mark.asyncio
async def test_completion_has_empty_pattern(self):
pattern = _make_pattern()
request = _make_request(
correlation_id="corr-123",
subagent_goal="goal",
)
next_fn = AsyncMock()
await pattern.emit_subagent_completion(
request, next_fn, "answer",
)
completion_req = next_fn.call_args[0][0]
assert completion_req.pattern == ""
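Taken together, the assertions above describe the shape of the completion message: a new request whose single history step carries the answer, with the correlation fields preserved and `pattern` left empty. A sketch, using illustrative dataclasses in place of the real schema types:

```python
import asyncio
from dataclasses import dataclass, field

# Sketch of subagent-completion emission: one history step of type
# "subagent-completion" carrying the answer as its observation,
# correlation fields copied through, pattern left empty. The
# dataclasses are stand-ins for the real AgentRequest schema.
@dataclass
class Step:
    step_type: str
    observation: str

@dataclass
class Request:
    correlation_id: str = ""
    parent_session_id: str = ""
    subagent_goal: str = ""
    expected_siblings: int = 0
    pattern: str = ""
    history: list = field(default_factory=list)

async def emit_subagent_completion(request, next_fn, answer):
    completion = Request(
        correlation_id=request.correlation_id,
        parent_session_id=request.parent_session_id,
        subagent_goal=request.subagent_goal,
        expected_siblings=request.expected_siblings,
        pattern="",
        history=[Step("subagent-completion", answer)],
    )
    await next_fn(completion)

sent = []
async def demo():
    req = Request(correlation_id="corr-123", subagent_goal="What is X?")
    async def next_fn(r):
        sent.append(r)
    await emit_subagent_completion(req, next_fn, "The answer is Y")

asyncio.run(demo())
```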


@ -0,0 +1,226 @@
"""
Unit tests for orchestrator provenance triple builders.
"""
import pytest
from trustgraph.provenance import (
agent_decomposition_triples,
agent_finding_triples,
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SYNTHESIS, TG_ANSWER_TYPE, TG_DOCUMENT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
)
def _triple_set(triples):
"""Convert triples to a set of (s_iri, p_iri, o_value) for easy assertion."""
result = set()
for t in triples:
s = t.s.iri
p = t.p.iri
o = t.o.iri if t.o.iri else t.o.value
result.add((s, p, o))
return result
def _has_type(triples, uri, rdf_type):
"""Check if a URI has a given rdf:type in the triples."""
return (uri, RDF_TYPE, rdf_type) in _triple_set(triples)
def _get_values(triples, uri, predicate):
"""Get all object values for a given subject + predicate."""
ts = _triple_set(triples)
return [o for s, p, o in ts if s == uri and p == predicate]
class TestDecompositionTriples:
def test_has_correct_types(self):
triples = agent_decomposition_triples(
"urn:decompose", "urn:session", ["goal-a", "goal-b"],
)
assert _has_type(triples, "urn:decompose", PROV_ENTITY)
assert _has_type(triples, "urn:decompose", TG_DECOMPOSITION)
def test_not_answer_type(self):
triples = agent_decomposition_triples(
"urn:decompose", "urn:session", ["goal-a"],
)
assert not _has_type(triples, "urn:decompose", TG_ANSWER_TYPE)
def test_links_to_session(self):
triples = agent_decomposition_triples(
"urn:decompose", "urn:session", ["goal-a"],
)
ts = _triple_set(triples)
assert ("urn:decompose", PROV_WAS_DERIVED_FROM, "urn:session") in ts
def test_includes_goals(self):
goals = ["What is X?", "What is Y?", "What is Z?"]
triples = agent_decomposition_triples(
"urn:decompose", "urn:session", goals,
)
values = _get_values(triples, "urn:decompose", TG_SUBAGENT_GOAL)
assert set(values) == set(goals)
def test_label_includes_count(self):
triples = agent_decomposition_triples(
"urn:decompose", "urn:session", ["a", "b", "c"],
)
labels = _get_values(triples, "urn:decompose", RDFS_LABEL)
assert any("3" in label for label in labels)
class TestFindingTriples:
def test_has_correct_types(self):
triples = agent_finding_triples(
"urn:finding", "urn:decompose", "What is X?",
)
assert _has_type(triples, "urn:finding", PROV_ENTITY)
assert _has_type(triples, "urn:finding", TG_FINDING)
assert _has_type(triples, "urn:finding", TG_ANSWER_TYPE)
def test_links_to_decomposition(self):
triples = agent_finding_triples(
"urn:finding", "urn:decompose", "What is X?",
)
ts = _triple_set(triples)
assert ("urn:finding", PROV_WAS_DERIVED_FROM, "urn:decompose") in ts
def test_includes_goal(self):
triples = agent_finding_triples(
"urn:finding", "urn:decompose", "What is X?",
)
values = _get_values(triples, "urn:finding", TG_SUBAGENT_GOAL)
assert "What is X?" in values
def test_includes_document_when_provided(self):
triples = agent_finding_triples(
"urn:finding", "urn:decompose", "goal",
document_id="urn:doc/1",
)
values = _get_values(triples, "urn:finding", TG_DOCUMENT)
assert "urn:doc/1" in values
def test_no_document_when_none(self):
triples = agent_finding_triples(
"urn:finding", "urn:decompose", "goal",
)
values = _get_values(triples, "urn:finding", TG_DOCUMENT)
assert values == []
class TestPlanTriples:
def test_has_correct_types(self):
triples = agent_plan_triples(
"urn:plan", "urn:session", ["step-a"],
)
assert _has_type(triples, "urn:plan", PROV_ENTITY)
assert _has_type(triples, "urn:plan", TG_PLAN_TYPE)
def test_not_answer_type(self):
triples = agent_plan_triples(
"urn:plan", "urn:session", ["step-a"],
)
assert not _has_type(triples, "urn:plan", TG_ANSWER_TYPE)
def test_links_to_session(self):
triples = agent_plan_triples(
"urn:plan", "urn:session", ["step-a"],
)
ts = _triple_set(triples)
assert ("urn:plan", PROV_WAS_DERIVED_FROM, "urn:session") in ts
def test_includes_steps(self):
steps = ["Define X", "Research Y", "Analyse Z"]
triples = agent_plan_triples(
"urn:plan", "urn:session", steps,
)
values = _get_values(triples, "urn:plan", TG_PLAN_STEP)
assert set(values) == set(steps)
def test_label_includes_count(self):
triples = agent_plan_triples(
"urn:plan", "urn:session", ["a", "b"],
)
labels = _get_values(triples, "urn:plan", RDFS_LABEL)
assert any("2" in label for label in labels)
class TestStepResultTriples:
def test_has_correct_types(self):
triples = agent_step_result_triples(
"urn:step", "urn:plan", "Define X",
)
assert _has_type(triples, "urn:step", PROV_ENTITY)
assert _has_type(triples, "urn:step", TG_STEP_RESULT)
assert _has_type(triples, "urn:step", TG_ANSWER_TYPE)
def test_links_to_plan(self):
triples = agent_step_result_triples(
"urn:step", "urn:plan", "Define X",
)
ts = _triple_set(triples)
assert ("urn:step", PROV_WAS_DERIVED_FROM, "urn:plan") in ts
def test_includes_goal(self):
triples = agent_step_result_triples(
"urn:step", "urn:plan", "Define X",
)
values = _get_values(triples, "urn:step", TG_PLAN_STEP)
assert "Define X" in values
def test_includes_document_when_provided(self):
triples = agent_step_result_triples(
"urn:step", "urn:plan", "goal",
document_id="urn:doc/step",
)
values = _get_values(triples, "urn:step", TG_DOCUMENT)
assert "urn:doc/step" in values
class TestSynthesisTriples:
def test_has_correct_types(self):
triples = agent_synthesis_triples(
"urn:synthesis", "urn:previous",
)
assert _has_type(triples, "urn:synthesis", PROV_ENTITY)
assert _has_type(triples, "urn:synthesis", TG_SYNTHESIS)
assert _has_type(triples, "urn:synthesis", TG_ANSWER_TYPE)
def test_links_to_previous(self):
triples = agent_synthesis_triples(
"urn:synthesis", "urn:last-finding",
)
ts = _triple_set(triples)
assert ("urn:synthesis", PROV_WAS_DERIVED_FROM,
"urn:last-finding") in ts
def test_includes_document_when_provided(self):
triples = agent_synthesis_triples(
"urn:synthesis", "urn:previous",
document_id="urn:doc/synthesis",
)
values = _get_values(triples, "urn:synthesis", TG_DOCUMENT)
assert "urn:doc/synthesis" in values
def test_label_is_synthesis(self):
triples = agent_synthesis_triples(
"urn:synthesis", "urn:previous",
)
labels = _get_values(triples, "urn:synthesis", RDFS_LABEL)
assert "Synthesis" in labels


@ -0,0 +1,323 @@
"""
Tests for AsyncProcessor config notify pattern:
- register_config_handler with types filtering
- on_config_notify version comparison and type matching
- fetch_config with short-lived client
- fetch_and_apply_config retry logic
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, Mock
from trustgraph.schema import Term, IRI, LITERAL
# Patch heavy dependencies before importing AsyncProcessor
@pytest.fixture
def processor():
"""Create an AsyncProcessor with mocked dependencies."""
with patch('trustgraph.base.async_processor.get_pubsub') as mock_pubsub, \
patch('trustgraph.base.async_processor.Consumer') as mock_consumer, \
patch('trustgraph.base.async_processor.ProcessorMetrics') as mock_pm, \
patch('trustgraph.base.async_processor.ConsumerMetrics') as mock_cm:
mock_pubsub.return_value = MagicMock()
mock_consumer.return_value = MagicMock()
mock_pm.return_value = MagicMock()
mock_cm.return_value = MagicMock()
from trustgraph.base.async_processor import AsyncProcessor
p = AsyncProcessor(
id="test-processor",
taskgroup=AsyncMock(),
)
return p
class TestRegisterConfigHandler:
def test_register_without_types(self, processor):
handler = AsyncMock()
processor.register_config_handler(handler)
assert len(processor.config_handlers) == 1
assert processor.config_handlers[0]["handler"] is handler
assert processor.config_handlers[0]["types"] is None
def test_register_with_types(self, processor):
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
assert processor.config_handlers[0]["types"] == {"prompt"}
def test_register_multiple_types(self, processor):
handler = AsyncMock()
processor.register_config_handler(
handler, types=["schema", "collection"]
)
assert processor.config_handlers[0]["types"] == {
"schema", "collection"
}
def test_register_multiple_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
assert len(processor.config_handlers) == 2
class TestOnConfigNotify:
@pytest.mark.asyncio
async def test_skip_old_version(self, processor):
processor.config_version = 5
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=3, types=["prompt"])
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@pytest.mark.asyncio
async def test_skip_same_version(self, processor):
processor.config_version = 5
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=5, types=["prompt"])
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
@pytest.mark.asyncio
async def test_skip_irrelevant_types(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
msg = Mock()
msg.value.return_value = Mock(version=2, types=["schema"])
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
# Version should still be updated
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_fetch_on_relevant_type(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler, types=["prompt"])
# Mock fetch_config
mock_config = {"prompt": {"key": "value"}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
assert processor.config_version == 2
@pytest.mark.asyncio
async def test_handler_without_types_always_called(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler) # No types = all
mock_config = {"anything": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["whatever"])
await processor.on_config_notify(msg, None, None)
handler.assert_called_once_with(mock_config, 2)
@pytest.mark.asyncio
async def test_mixed_handlers_type_filtering(self, processor):
processor.config_version = 1
prompt_handler = AsyncMock()
schema_handler = AsyncMock()
all_handler = AsyncMock()
processor.register_config_handler(prompt_handler, types=["prompt"])
processor.register_config_handler(schema_handler, types=["schema"])
processor.register_config_handler(all_handler)
mock_config = {"prompt": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
await processor.on_config_notify(msg, None, None)
prompt_handler.assert_called_once()
schema_handler.assert_not_called()
all_handler.assert_called_once()
@pytest.mark.asyncio
async def test_empty_types_invokes_all(self, processor):
"""Empty types list (startup signal) should invoke all handlers."""
processor.config_version = 1
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
mock_config = {}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 2)
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=[])
await processor.on_config_notify(msg, None, None)
h1.assert_called_once()
h2.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_failure_handled(self, processor):
processor.config_version = 1
handler = AsyncMock()
processor.register_config_handler(handler)
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
side_effect=RuntimeError("Connection failed")
):
msg = Mock()
msg.value.return_value = Mock(version=2, types=["prompt"])
# Should not raise
await processor.on_config_notify(msg, None, None)
handler.assert_not_called()
class TestFetchConfig:
@pytest.mark.asyncio
async def test_fetch_returns_config_and_version(self, processor):
mock_resp = Mock()
mock_resp.error = None
mock_resp.config = {"prompt": {"key": "val"}}
mock_resp.version = 42
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
config, version = await processor.fetch_config()
assert config == {"prompt": {"key": "val"}}
assert version == 42
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_raises_on_error_response(self, processor):
mock_resp = Mock()
mock_resp.error = Mock(message="not found")
mock_resp.config = {}
mock_resp.version = 0
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(RuntimeError, match="Config error"):
await processor.fetch_config()
mock_client.stop.assert_called_once()
@pytest.mark.asyncio
async def test_fetch_stops_client_on_exception(self, processor):
mock_client = AsyncMock()
mock_client.request.side_effect = TimeoutError("timeout")
with patch.object(
processor, '_create_config_client', return_value=mock_client
):
with pytest.raises(TimeoutError):
await processor.fetch_config()
mock_client.stop.assert_called_once()
class TestFetchAndApplyConfig:
@pytest.mark.asyncio
async def test_applies_config_to_all_handlers(self, processor):
h1 = AsyncMock()
h2 = AsyncMock()
processor.register_config_handler(h1, types=["prompt"])
processor.register_config_handler(h2, types=["schema"])
mock_config = {"prompt": {}, "schema": {}}
with patch.object(
processor, 'fetch_config',
new_callable=AsyncMock,
return_value=(mock_config, 10)
):
await processor.fetch_and_apply_config()
# On startup, all handlers are invoked regardless of type
h1.assert_called_once_with(mock_config, 10)
h2.assert_called_once_with(mock_config, 10)
assert processor.config_version == 10
@pytest.mark.asyncio
async def test_retries_on_failure(self, processor):
call_count = 0
mock_config = {"prompt": {}}
async def mock_fetch():
nonlocal call_count
call_count += 1
if call_count < 3:
raise RuntimeError("not ready")
return mock_config, 5
with patch.object(processor, 'fetch_config', side_effect=mock_fetch), \
patch('asyncio.sleep', new_callable=AsyncMock):
await processor.fetch_and_apply_config()
assert call_count == 3
assert processor.config_version == 5
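The notify logic these tests pin down can be summarised in one sketch: stale versions are ignored, an irrelevant notify still records the version, an empty type list (the startup signal) invokes every handler, and a handler registered without types always matches. `fetch_config` and the handler registry are simplified stand-ins for the AsyncProcessor machinery:

```python
import asyncio

# Sketch of config-notify handling with the tested properties.
class Notify:
    def __init__(self, handlers, fetch_config):
        self.handlers = handlers          # list of {"handler", "types"}
        self.fetch_config = fetch_config
        self.config_version = 0

    async def on_config_notify(self, version, types):
        if version <= self.config_version:
            return                        # stale or duplicate notify
        matching = [
            h for h in self.handlers
            if h["types"] is None         # no filter: always relevant
            or not types                  # empty list: startup signal
            or h["types"] & set(types)
        ]
        if not matching:
            self.config_version = version # still record the version
            return
        try:
            config, version = await self.fetch_config()
        except Exception:
            return                        # fetch failed: skip handlers
        for h in matching:
            await h["handler"](config, version)
        self.config_version = version

calls = []
async def demo():
    async def fetch():
        return {"prompt": {}}, 2
    async def on_prompt(cfg, v):
        calls.append(("prompt", v))
    async def on_schema(cfg, v):
        calls.append(("schema", v))
    n = Notify(
        [{"handler": on_prompt, "types": {"prompt"}},
         {"handler": on_schema, "types": {"schema"}}],
        fetch,
    )
    n.config_version = 1
    await n.on_config_notify(2, ["prompt"])
    return n.config_version

ver = asyncio.run(demo())
```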


@ -35,7 +35,9 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_async_init.assert_called_once()
# Verify register_config_handler was called with the correct handler
mock_register_config.assert_called_once_with(processor.on_configure_flows)
mock_register_config.assert_called_once_with(
processor.on_configure_flows, types=["active-flow"]
)
# Verify FlowProcessor-specific initialization
assert hasattr(processor, 'flows')


@ -61,23 +61,21 @@ async def test_subscriber_deferred_acknowledgment_success():
max_size=10,
backpressure_strategy="block"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
# Create queue for subscription
queue = await subscriber.subscribe("test-queue")
# Create mock message with matching queue name
msg = create_mock_message("test-queue", {"data": "test"})
# Process message
await subscriber._process_message(msg)
# Should acknowledge successful delivery
mock_consumer.acknowledge.assert_called_once_with(msg)
mock_consumer.negative_acknowledge.assert_not_called()
# Message should be in queue
assert not queue.empty()
received_msg = await queue.get()
@ -108,9 +106,7 @@ async def test_subscriber_dropped_message_still_acks():
max_size=1, # Very small queue
backpressure_strategy="drop_new"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
# Create queue and fill it
queue = await subscriber.subscribe("test-queue")
@ -151,9 +147,7 @@ async def test_subscriber_orphaned_message_acks():
max_size=10,
backpressure_strategy="block"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
# Don't create any queues - message will be orphaned
# This simulates a response arriving after the waiter has unsubscribed
@ -189,9 +183,7 @@ async def test_subscriber_backpressure_strategies():
max_size=2,
backpressure_strategy="drop_oldest"
)
# Start subscriber to initialize consumer
await subscriber.start()
subscriber.consumer = mock_consumer
queue = await subscriber.subscribe("test-queue")


@ -24,8 +24,8 @@ class MockAsyncProcessor:
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
"""Test Recursive chunker functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
"""Test basic processor initialization"""
@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-size parameter override"""
@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 2000 # Should use overridden value
assert chunk_overlap == 100 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-overlap parameter override"""
@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 200 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 1500 # Should use overridden value
assert chunk_overlap == 150 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
# Mock message with TextDocument
mock_message = MagicMock()
@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 1500,
"chunk-overlap": 150,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(param)
}.get(name)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
mock_flow.parameters.get.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
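The chunker tests above mock `flow.parameters.get(...)` to return either an override or `None`. The lookup they exercise can be sketched as follows — defaults here match the recursive chunker's asserted values (1000/100); the helper name is an assumption for illustration:

```python
def resolve_chunk_params(flow, default_size=1000, default_overlap=100):
    """Illustrative sketch: a flow parameter overrides the configured
    default only when it is set (non-None); otherwise the default wins."""
    size = flow.parameters.get("chunk-size")
    overlap = flow.parameters.get("chunk-overlap")
    return (size if size is not None else default_size,
            overlap if overlap is not None else default_overlap)
```

This is why the no-override test simply sets `mock_flow.parameters.get.return_value = None` and expects both defaults back.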


@ -24,8 +24,8 @@ class MockAsyncProcessor:
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
"""Test Token chunker functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
"""Test basic processor initialization"""
@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-size parameter override"""
@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 400 # Should use overridden value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-overlap parameter override"""
@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 25 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 350 # Should use overridden value
assert chunk_overlap == 30 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid librarian producer interactions
processor.save_child_document = AsyncMock(return_value="chunk-id")
processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
# Mock message with TextDocument
mock_message = MagicMock()
@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
mock_flow.parameters.get.side_effect = lambda param: {
"chunk-size": 400,
"chunk-overlap": 40,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(param)
}.get(name)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
mock_flow.parameters.get.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
"""Test that token chunker has different defaults than recursive chunker"""


@ -21,17 +21,15 @@ class TestSyncDocumentEmbeddingsClient:
# Act
client = DocumentEmbeddingsClient(
log_level=1,
subscriber="test-subscriber",
input_queue="test-input",
output_queue="test-output",
pulsar_host="pulsar://test:6650",
pulsar_api_key="test-key"
)
# Assert
mock_base_init.assert_called_once_with(
log_level=1,
subscriber="test-subscriber",
input_queue="test-input",
output_queue="test-output",


@ -81,9 +81,8 @@ class TestTaskGroupConcurrency:
# Track how many consume_from_queue calls are made
call_count = 0
original_running = True
async def mock_consume():
async def mock_consume(backend_consumer, executor=None):
nonlocal call_count
call_count += 1
# Wait a bit to let all tasks start, then signal stop
@ -107,7 +106,7 @@ class TestTaskGroupConcurrency:
consumer = _make_consumer(concurrency=1)
call_count = 0
async def mock_consume():
async def mock_consume(backend_consumer, executor=None):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
@ -147,7 +146,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
assert call_count == 2
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@ -166,7 +165,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
consumer.consumer.acknowledge.assert_not_called()
@ -185,7 +184,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
assert call_count == 1
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
@ -197,7 +196,7 @@ class TestRateLimitRetry:
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@ -219,7 +218,7 @@ class TestMetricsIntegration:
mock_metrics.record_time.return_value.__exit__ = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
mock_metrics.process.assert_called_once_with("success")
@ -235,7 +234,7 @@ class TestMetricsIntegration:
mock_metrics = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
mock_metrics.process.assert_called_once_with("error")
@ -261,7 +260,7 @@ class TestMetricsIntegration:
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
mock_metrics.rate_limit.assert_called_once()
@ -294,9 +293,8 @@ class TestPollTimeout:
raise type('Timeout', (Exception,), {})("timeout")
mock_pulsar_consumer.receive = capture_receive
consumer.consumer = mock_pulsar_consumer
await consumer.consume_from_queue()
await consumer.consume_from_queue(mock_pulsar_consumer)
assert received_kwargs.get("timeout_millis") == 100
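The new `executor` parameter threaded through `consume_from_queue` and `handle_one_from_queue` above reflects the thread-safety fix in the commit message: pika's `BlockingConnection` is not thread-safe, so each consumer gets a dedicated single-thread executor and all blocking client calls are funnelled through it. A sketch of that wrapper — class and method names are assumptions, not the trustgraph API:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

class ThreadSafeConsumer:
    """Illustrative sketch: every blocking client call runs on one
    dedicated executor thread, so a client that is not thread-safe
    (like pika's BlockingConnection) only ever sees a single thread."""

    def __init__(self, client):
        self.client = client
        # max_workers=1 guarantees all calls share exactly one thread
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def _call(self, fn, *args):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, fn, *args)

    async def receive(self):
        return await self._call(self.client.receive)

    async def acknowledge(self, msg):
        return await self._call(self.client.acknowledge, msg)

    async def negative_acknowledge(self, msg):
        return await self._call(self.client.negative_acknowledge, msg)
```

Passing the backend consumer and executor explicitly into `handle_one_from_queue(msg, consumer)` keeps create, receive, and ack on the same thread for the lifetime of that consumer.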


@ -25,8 +25,8 @@ class MockAsyncProcessor:
class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
"""Test Mistral OCR processor functionality"""
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization_with_api_key(
@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization_without_api_key(
self, mock_producer, mock_consumer
@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
with pytest.raises(RuntimeError, match="Mistral API key not specified"):
Processor(**config)
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_ocr_single_chunk(
@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
)
mock_mistral.ocr.process.assert_called_once()
@patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
@patch('trustgraph.decoding.mistral_ocr.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(
@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
]
# Mock save_child_document
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
with patch.object(processor, 'ocr', return_value=ocr_result):
await processor.on_message(mock_msg, None, mock_flow)


@ -24,12 +24,10 @@ class MockAsyncProcessor:
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"""Test PDF decoder processor functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_processor_initialization(self, mock_producer, mock_consumer):
"""Test PDF decoder processor initialization"""
config = {
'id': 'test-pdf-decoder',
@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
pdf_content = b"fake pdf content"
@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow)
@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
mock_output_flow.send.assert_not_called()
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow)


@ -142,8 +142,8 @@ class TestPageBasedFormats:
class TestUniversalProcessor(IsolatedAsyncioTestCase):
"""Test universal decoder processor."""
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(
self, mock_producer, mock_consumer
@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_custom_strategy(
self, mock_producer, mock_consumer
@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert processor.partition_strategy == "hi_res"
assert processor.section_strategy_name == "heading"
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_group_by_page(self, mock_producer, mock_consumer):
"""Test page grouping of elements."""
@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert result[1][0] == 2
assert len(result[1][1]) == 1
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_inline_non_page(
@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
}.get(name))
# Mock save_child_document and magic
processor.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "text/markdown"
@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert call_args.document_id.startswith("urn:section:")
assert call_args.text == b""
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_page_based(
@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow,
}.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf"
@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
call_args = mock_output_flow.send.call_args_list[0][0][0]
assert call_args.document_id.startswith("urn:page:")
@patch('trustgraph.decoding.universal.processor.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer')
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_images_stored_not_emitted(
@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow,
}.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf"
@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert mock_triples_flow.send.call_count == 2
# save_child_document called twice (page + image)
assert processor.save_child_document.call_count == 2
assert processor.librarian.save_child_document.call_count == 2
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args):


@ -5,7 +5,7 @@ Tests for Gateway Config Receiver
import pytest
import asyncio
import json
from unittest.mock import Mock, patch, Mock, MagicMock
from unittest.mock import Mock, patch, MagicMock, AsyncMock
import uuid
from trustgraph.gateway.config.receiver import ConfigReceiver
@ -23,174 +23,237 @@ class TestConfigReceiver:
def test_config_receiver_initialization(self):
"""Test ConfigReceiver initialization"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
assert config_receiver.backend == mock_backend
assert config_receiver.flow_handlers == []
assert config_receiver.flows == {}
assert config_receiver.config_version == 0
def test_add_handler(self):
"""Test adding flow handlers"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
handler1 = Mock()
handler2 = Mock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
assert len(config_receiver.flow_handlers) == 2
assert handler1 in config_receiver.flow_handlers
assert handler2 in config_receiver.flow_handlers
@pytest.mark.asyncio
async def test_on_config_with_new_flows(self):
"""Test on_config method with new flows"""
async def test_on_config_notify_new_version(self):
"""Test on_config_notify triggers fetch for newer version"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Track calls manually instead of using AsyncMock
start_flow_calls = []
async def mock_start_flow(*args):
start_flow_calls.append(args)
config_receiver.start_flow = mock_start_flow
# Create mock message with flows
config_receiver.config_version = 1
# Mock fetch_and_apply
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with newer version
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={
"flow": {
"flow1": '{"name": "test_flow_1", "steps": []}',
"flow2": '{"name": "test_flow_2", "steps": []}'
}
}
)
await config_receiver.on_config(mock_msg, None, None)
# Verify flows were added
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
# Verify start_flow was called for each new flow
assert len(start_flow_calls) == 2
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
mock_msg.value.return_value = Mock(version=2, types=["flow"])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 1
@pytest.mark.asyncio
async def test_on_config_with_removed_flows(self):
"""Test on_config method with removed flows"""
async def test_on_config_notify_old_version_ignored(self):
"""Test on_config_notify ignores older versions"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1", "steps": []},
"flow2": {"name": "test_flow_2", "steps": []}
}
# Track calls manually instead of using AsyncMock
stop_flow_calls = []
async def mock_stop_flow(*args):
stop_flow_calls.append(args)
config_receiver.stop_flow = mock_stop_flow
# Create mock message with only flow1 (flow2 removed)
config_receiver.config_version = 5
fetch_calls = []
async def mock_fetch(**kwargs):
fetch_calls.append(kwargs)
config_receiver.fetch_and_apply = mock_fetch
# Create notify message with older version
mock_msg = Mock()
mock_msg.value.return_value = Mock(
version="1.0",
config={
"flow": {
"flow1": '{"name": "test_flow_1", "steps": []}'
}
}
)
await config_receiver.on_config(mock_msg, None, None)
# Verify flow2 was removed
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
# Verify stop_flow was called for removed flow
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
mock_msg.value.return_value = Mock(version=3, types=["flow"])
await config_receiver.on_config_notify(mock_msg, None, None)
assert len(fetch_calls) == 0
@pytest.mark.asyncio
-async def test_on_config_with_no_flows(self):
-"""Test on_config method with no flows in config"""
+async def test_on_config_notify_irrelevant_types_ignored(self):
+"""Test on_config_notify ignores types the gateway doesn't care about"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
-# Mock the start_flow and stop_flow methods with async functions
-async def mock_start_flow(*args):
-pass
-async def mock_stop_flow(*args):
-pass
-config_receiver.start_flow = mock_start_flow
-config_receiver.stop_flow = mock_stop_flow
-# Create mock message without flows
+config_receiver.config_version = 1
+fetch_calls = []
+async def mock_fetch(**kwargs):
+fetch_calls.append(kwargs)
+config_receiver.fetch_and_apply = mock_fetch
+# Create notify message with non-flow type
mock_msg = Mock()
-mock_msg.value.return_value = Mock(
-version="1.0",
-config={}
-)
-await config_receiver.on_config(mock_msg, None, None)
-# Verify no flows were added
-assert config_receiver.flows == {}
-# Since no flows were in the config, the flow methods shouldn't be called
-# (We can't easily assert this with simple async functions, but the test
-# passes if no exceptions are thrown)
+mock_msg.value.return_value = Mock(version=2, types=["prompt"])
+await config_receiver.on_config_notify(mock_msg, None, None)
+# Version should be updated but no fetch
+assert len(fetch_calls) == 0
+assert config_receiver.config_version == 2
@pytest.mark.asyncio
-async def test_on_config_exception_handling(self):
-"""Test on_config method handles exceptions gracefully"""
+async def test_on_config_notify_flow_type_triggers_fetch(self):
+"""Test on_config_notify fetches for flow-related types"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
-# Create mock message that will cause an exception
+config_receiver.config_version = 1
+fetch_calls = []
+async def mock_fetch(**kwargs):
+fetch_calls.append(kwargs)
+config_receiver.fetch_and_apply = mock_fetch
+for type_name in ["flow", "active-flow"]:
+fetch_calls.clear()
+config_receiver.config_version = 1
+mock_msg = Mock()
+mock_msg.value.return_value = Mock(version=2, types=[type_name])
+await config_receiver.on_config_notify(mock_msg, None, None)
+assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
@pytest.mark.asyncio
+async def test_on_config_notify_exception_handling(self):
+"""Test on_config_notify handles exceptions gracefully"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
+# Create notify message that causes an exception
mock_msg = Mock()
mock_msg.value.side_effect = Exception("Test exception")
-# This should not raise an exception
-await config_receiver.on_config(mock_msg, None, None)
-# Verify flows remain empty
+# Should not raise
+await config_receiver.on_config_notify(mock_msg, None, None)
@pytest.mark.asyncio
async def test_fetch_and_apply_with_new_flows(self):
"""Test fetch_and_apply starts new flows"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock config_client
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}',
"flow2": '{"name": "test_flow_2"}'
}
}
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
config_receiver.config_client = mock_client
start_flow_calls = []
async def mock_start_flow(id, flow):
start_flow_calls.append((id, flow))
config_receiver.start_flow = mock_start_flow
await config_receiver.fetch_and_apply()
assert config_receiver.config_version == 5
assert "flow1" in config_receiver.flows
assert "flow2" in config_receiver.flows
assert len(start_flow_calls) == 2
@pytest.mark.asyncio
async def test_fetch_and_apply_with_removed_flows(self):
"""Test fetch_and_apply stops removed flows"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Pre-populate with existing flows
config_receiver.flows = {
"flow1": {"name": "test_flow_1"},
"flow2": {"name": "test_flow_2"}
}
# Config now only has flow1
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 5
mock_resp.config = {
"flow": {
"flow1": '{"name": "test_flow_1"}'
}
}
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
config_receiver.config_client = mock_client
stop_flow_calls = []
async def mock_stop_flow(id, flow):
stop_flow_calls.append((id, flow))
config_receiver.stop_flow = mock_stop_flow
await config_receiver.fetch_and_apply()
assert "flow1" in config_receiver.flows
assert "flow2" not in config_receiver.flows
assert len(stop_flow_calls) == 1
assert stop_flow_calls[0][0] == "flow2"
@pytest.mark.asyncio
async def test_fetch_and_apply_with_no_flows(self):
"""Test fetch_and_apply with empty config"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
mock_resp = Mock()
mock_resp.error = None
mock_resp.version = 1
mock_resp.config = {}
mock_client = AsyncMock()
mock_client.request.return_value = mock_resp
config_receiver.config_client = mock_client
await config_receiver.fetch_and_apply()
assert config_receiver.flows == {}
assert config_receiver.config_version == 1
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handlers
handler1 = Mock()
handler1.start_flow = Mock()
handler2 = Mock()
handler2.start_flow = Mock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.start_flow("flow1", flow_data)
# Verify all handlers were called
handler1.start_flow.assert_called_once_with("flow1", flow_data)
handler2.start_flow.assert_called_once_with("flow1", flow_data)
@@ -199,19 +262,17 @@ class TestConfigReceiver:
"""Test start_flow method handles handler exceptions"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handler that raises exception
handler = Mock()
handler.start_flow = Mock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
-# This should not raise an exception
+# Should not raise
await config_receiver.start_flow("flow1", flow_data)
# Verify handler was called
handler.start_flow.assert_called_once_with("flow1", flow_data)
@pytest.mark.asyncio
@@ -219,21 +280,19 @@ class TestConfigReceiver:
"""Test stop_flow method with multiple handlers"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handlers
handler1 = Mock()
handler1.stop_flow = Mock()
handler2 = Mock()
handler2.stop_flow = Mock()
config_receiver.add_handler(handler1)
config_receiver.add_handler(handler2)
flow_data = {"name": "test_flow", "steps": []}
await config_receiver.stop_flow("flow1", flow_data)
# Verify all handlers were called
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
@@ -242,167 +301,77 @@ class TestConfigReceiver:
"""Test stop_flow method handles handler exceptions"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Add mock handler that raises exception
handler = Mock()
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
config_receiver.add_handler(handler)
flow_data = {"name": "test_flow", "steps": []}
-# This should not raise an exception
+# Should not raise
await config_receiver.stop_flow("flow1", flow_data)
# Verify handler was called
handler.stop_flow.assert_called_once_with("flow1", flow_data)
@pytest.mark.asyncio
async def test_config_loader_creates_consumer(self):
"""Test config_loader method creates Pulsar consumer"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Temporarily restore the real config_loader for this test
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
# Mock Consumer class
with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \
patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value = "test-uuid"
mock_consumer = Mock()
async def mock_start():
pass
mock_consumer.start = mock_start
mock_consumer_class.return_value = mock_consumer
# Create a task that will complete quickly
async def quick_task():
await config_receiver.config_loader()
# Run the task with a timeout to prevent hanging
try:
await asyncio.wait_for(quick_task(), timeout=0.1)
except asyncio.TimeoutError:
# This is expected since the method runs indefinitely
pass
# Verify Consumer was created with correct parameters
mock_consumer_class.assert_called_once()
call_args = mock_consumer_class.call_args
assert call_args[1]['backend'] == mock_backend
assert call_args[1]['subscriber'] == "gateway-test-uuid"
assert call_args[1]['handler'] == config_receiver.on_config
assert call_args[1]['start_of_messages'] is True
@patch('asyncio.create_task')
@pytest.mark.asyncio
async def test_start_creates_config_loader_task(self, mock_create_task):
"""Test start method creates config loader task"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
# Mock create_task to avoid actually creating tasks with real coroutines
mock_task = Mock()
mock_create_task.return_value = mock_task
await config_receiver.start()
# Verify task was created
mock_create_task.assert_called_once()
# Verify the argument passed to create_task is a coroutine
call_args = mock_create_task.call_args[0]
assert len(call_args) == 1 # Should have one argument (the coroutine)
@pytest.mark.asyncio
-async def test_on_config_mixed_flow_operations(self):
-"""Test on_config with mixed add/remove operations"""
+async def test_fetch_and_apply_mixed_flow_operations(self):
+"""Test fetch_and_apply with mixed add/remove operations"""
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
-# Pre-populate with existing flows
+# Pre-populate
config_receiver.flows = {
-"flow1": {"name": "test_flow_1", "steps": []},
-"flow2": {"name": "test_flow_2", "steps": []}
+"flow1": {"name": "test_flow_1"},
+"flow2": {"name": "test_flow_2"}
}
-# Track calls manually instead of using Mock
-start_flow_calls = []
-stop_flow_calls = []
-async def mock_start_flow(*args):
-start_flow_calls.append(args)
-async def mock_stop_flow(*args):
-stop_flow_calls.append(args)
-# Directly assign to avoid patch.object detecting async methods
-original_start_flow = config_receiver.start_flow
-original_stop_flow = config_receiver.stop_flow
+# Config removes flow1, keeps flow2, adds flow3
+mock_resp = Mock()
+mock_resp.error = None
+mock_resp.version = 5
+mock_resp.config = {
+"flow": {
+"flow2": '{"name": "test_flow_2"}',
+"flow3": '{"name": "test_flow_3"}'
+}
+}
+mock_client = AsyncMock()
+mock_client.request.return_value = mock_resp
+config_receiver.config_client = mock_client
+start_calls = []
+stop_calls = []
+async def mock_start_flow(id, flow):
+start_calls.append((id, flow))
+async def mock_stop_flow(id, flow):
+stop_calls.append((id, flow))
config_receiver.start_flow = mock_start_flow
config_receiver.stop_flow = mock_stop_flow
-try:
-# Create mock message with flow1 removed and flow3 added
-mock_msg = Mock()
-mock_msg.value.return_value = Mock(
-version="1.0",
-config={
-"flow": {
-"flow2": '{"name": "test_flow_2", "steps": []}',
-"flow3": '{"name": "test_flow_3", "steps": []}'
-}
-}
-)
-await config_receiver.on_config(mock_msg, None, None)
-# Verify final state
-assert "flow1" not in config_receiver.flows
-assert "flow2" in config_receiver.flows
-assert "flow3" in config_receiver.flows
-# Verify operations
-assert len(start_flow_calls) == 1
-assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
-assert len(stop_flow_calls) == 1
-assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
-finally:
-# Restore original methods
-config_receiver.start_flow = original_start_flow
-config_receiver.stop_flow = original_stop_flow
-@pytest.mark.asyncio
-async def test_on_config_invalid_json_flow_data(self):
-"""Test on_config handles invalid JSON in flow data"""
-mock_backend = Mock()
-config_receiver = ConfigReceiver(mock_backend)
-# Mock the start_flow method with an async function
-async def mock_start_flow(*args):
-pass
-config_receiver.start_flow = mock_start_flow
-# Create mock message with invalid JSON
-mock_msg = Mock()
-mock_msg.value.return_value = Mock(
-version="1.0",
-config={
-"flow": {
-"flow1": '{"invalid": json}', # Invalid JSON
-"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
-}
-}
-)
-# This should handle the exception gracefully
-await config_receiver.on_config(mock_msg, None, None)
-# The entire operation should fail due to JSON parsing error
-# So no flows should be added
-assert config_receiver.flows == {}
+await config_receiver.fetch_and_apply()
+assert "flow1" not in config_receiver.flows
+assert "flow2" in config_receiver.flows
+assert "flow3" in config_receiver.flows
+assert len(start_calls) == 1
+assert start_calls[0][0] == "flow3"
+assert len(stop_calls) == 1
+assert stop_calls[0][0] == "flow1"
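The fetch_and_apply tests above all exercise one reconcile pattern: diff the currently running flows against the flows returned by the config service, stop what disappeared, start what is new, leave the rest running. A minimal standalone sketch of that pattern (the `reconcile_flows` helper and its signature are illustrative, not the actual ConfigReceiver code):

```python
import asyncio

async def reconcile_flows(current, desired, start_flow, stop_flow):
    # Stop flows that are no longer present in the fetched config.
    for fid in list(current):
        if fid not in desired:
            await stop_flow(fid, current.pop(fid))
    # Start flows that are newly present; flows in both are left running.
    for fid, flow in desired.items():
        if fid not in current:
            current[fid] = flow
            await start_flow(fid, flow)

calls = []
async def start(fid, flow): calls.append(("start", fid))
async def stop(fid, flow): calls.append(("stop", fid))

running = {"flow1": {}, "flow2": {}}
asyncio.run(reconcile_flows(running, {"flow2": {}, "flow3": {}}, start, stop))
```

After the run, flow1 has been stopped and flow3 started, which is exactly what the mixed-operations assertions check.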


@@ -49,7 +49,7 @@ class TestConfigRequestor:
mock_translator_registry.get_response_translator.return_value = Mock()
# Setup translator response
-mock_request_translator.to_pulsar.return_value = "translated_request"
+mock_request_translator.decode.return_value = "translated_request"
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
with patch.object(ServiceRequestor, 'start', return_value=None), \
@@ -64,7 +64,7 @@ class TestConfigRequestor:
result = requestor.to_request({"test": "body"})
# Verify translator was called correctly
-mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
+mock_request_translator.decode.assert_called_once_with({"test": "body"})
assert result == "translated_request"
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
@@ -76,7 +76,7 @@ class TestConfigRequestor:
mock_translator_registry.get_response_translator.return_value = mock_response_translator
# Setup translator response
-mock_response_translator.from_response_with_completion.return_value = "translated_response"
+mock_response_translator.encode_with_completion.return_value = "translated_response"
requestor = ConfigRequestor(
backend=Mock(),
@@ -89,5 +89,5 @@ class TestConfigRequestor:
result = requestor.from_response(mock_message)
# Verify translator was called correctly
-mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
+mock_response_translator.encode_with_completion.assert_called_once_with(mock_message)
assert result == "translated_response"


@@ -0,0 +1,359 @@
"""
Tests for inline explainability triples in response translators
and ProvenanceEvent parsing.
"""
import pytest
from trustgraph.schema import (
GraphRagResponse, DocumentRagResponse, AgentResponse,
Term, Triple, IRI, LITERAL, Error,
)
from trustgraph.messaging.translators.retrieval import (
GraphRagResponseTranslator,
DocumentRagResponseTranslator,
)
from trustgraph.messaging.translators.agent import (
AgentResponseTranslator,
)
from trustgraph.api.types import ProvenanceEvent
# --- Helpers ---
def make_triple(s_iri, p_iri, o_value, o_type=LITERAL):
"""Create a Triple with IRI subject/predicate and typed object."""
o = Term(type=IRI, iri=o_value) if o_type == IRI else Term(type=LITERAL, value=o_value)
return Triple(
s=Term(type=IRI, iri=s_iri),
p=Term(type=IRI, iri=p_iri),
o=o,
)
def sample_triples():
"""A few provenance triples for a question entity."""
return [
make_triple(
"urn:trustgraph:question:abc123",
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
"https://trustgraph.ai/ns/GraphRagQuestion",
o_type=IRI,
),
make_triple(
"urn:trustgraph:question:abc123",
"https://trustgraph.ai/ns/query",
"What is the internet?",
),
make_triple(
"urn:trustgraph:question:abc123",
"http://www.w3.org/ns/prov#startedAtTime",
"2026-04-07T09:00:00Z",
),
]
# --- GraphRag Translator ---
class TestGraphRagExplainTriples:
def test_explain_triples_encoded(self):
translator = GraphRagResponseTranslator()
triples = sample_triples()
response = GraphRagResponse(
message_type="explain",
explain_id="urn:trustgraph:question:abc123",
explain_graph="urn:graph:retrieval",
explain_triples=triples,
)
result = translator.encode(response)
assert "explain_triples" in result
assert len(result["explain_triples"]) == 3
# Check first triple is properly encoded
t = result["explain_triples"][0]
assert t["s"]["t"] == "i"
assert t["s"]["i"] == "urn:trustgraph:question:abc123"
assert t["p"]["t"] == "i"
def test_explain_triples_empty_not_included(self):
translator = GraphRagResponseTranslator()
response = GraphRagResponse(
message_type="chunk",
response="Some answer text",
)
result = translator.encode(response)
assert "explain_triples" not in result
def test_explain_with_completion_returns_not_final(self):
translator = GraphRagResponseTranslator()
response = GraphRagResponse(
message_type="explain",
explain_id="urn:trustgraph:question:abc123",
explain_triples=sample_triples(),
end_of_session=False,
)
result, is_final = translator.encode_with_completion(response)
assert is_final is False
def test_explain_id_and_graph_included(self):
translator = GraphRagResponseTranslator()
response = GraphRagResponse(
message_type="explain",
explain_id="urn:trustgraph:question:abc123",
explain_graph="urn:graph:retrieval",
explain_triples=sample_triples(),
)
result = translator.encode(response)
assert result["explain_id"] == "urn:trustgraph:question:abc123"
assert result["explain_graph"] == "urn:graph:retrieval"
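The assertions above pin down the wire shape of each triple term: IRIs travel as `{"t": "i", "i": …}` and literals as `{"t": "l", "v": …}`. A small sketch of that encoding (the helper names here are hypothetical; the real translators produce this shape internally):

```python
def term_to_wire(is_iri: bool, value: str) -> dict:
    # IRI terms: {"t": "i", "i": <iri>}; literal terms: {"t": "l", "v": <value>}
    if is_iri:
        return {"t": "i", "i": value}
    return {"t": "l", "v": value}

def triple_to_wire(s_iri: str, p_iri: str, o_value: str, o_is_iri: bool = False) -> dict:
    # Subjects and predicates are always IRIs; objects may be IRI or literal.
    return {
        "s": term_to_wire(True, s_iri),
        "p": term_to_wire(True, p_iri),
        "o": term_to_wire(o_is_iri, o_value),
    }
```

Encoding the query triple from `sample_triples()` this way yields exactly the dicts the tests assert on.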
# --- DocumentRag Translator ---
class TestDocumentRagExplainTriples:
def test_explain_triples_encoded(self):
translator = DocumentRagResponseTranslator()
response = DocumentRagResponse(
response=None,
message_type="explain",
explain_id="urn:trustgraph:docrag:abc123",
explain_graph="urn:graph:retrieval",
explain_triples=sample_triples(),
)
result = translator.encode(response)
assert "explain_triples" in result
assert len(result["explain_triples"]) == 3
def test_explain_triples_empty_not_included(self):
translator = DocumentRagResponseTranslator()
response = DocumentRagResponse(
response="Answer text",
message_type="chunk",
)
result = translator.encode(response)
assert "explain_triples" not in result
# --- Agent Translator ---
class TestAgentExplainTriples:
def test_explain_triples_encoded(self):
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
content="",
explain_id="urn:trustgraph:agent:session:abc123",
explain_graph="urn:graph:retrieval",
explain_triples=sample_triples(),
)
result = translator.encode(response)
assert "explain_triples" in result
assert len(result["explain_triples"]) == 3
t = result["explain_triples"][1]
assert t["p"]["i"] == "https://trustgraph.ai/ns/query"
assert t["o"]["t"] == "l"
assert t["o"]["v"] == "What is the internet?"
def test_explain_triples_empty_not_included(self):
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="thought",
content="I need to think...",
)
result = translator.encode(response)
assert "explain_triples" not in result
def test_explain_with_completion_not_final(self):
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
explain_id="urn:trustgraph:agent:session:abc123",
explain_triples=sample_triples(),
end_of_dialog=False,
)
result, is_final = translator.encode_with_completion(response)
assert is_final is False
def test_explain_with_completion_final(self):
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="answer",
content="The answer is...",
end_of_dialog=True,
)
result, is_final = translator.encode_with_completion(response)
assert is_final is True
# --- ProvenanceEvent ---
class TestProvenanceEvent:
def test_question_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:question:abc123",
)
assert event.event_type == "question"
def test_exploration_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:exploration:abc123",
)
assert event.event_type == "exploration"
def test_focus_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:focus:abc123",
)
assert event.event_type == "focus"
def test_synthesis_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:synthesis:abc123",
)
assert event.event_type == "synthesis"
def test_grounding_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:grounding:abc123",
)
assert event.event_type == "grounding"
def test_session_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:session:abc123",
)
assert event.event_type == "session"
def test_iteration_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:iteration:abc123:1",
)
assert event.event_type == "iteration"
def test_observation_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:observation:abc123:1",
)
assert event.event_type == "observation"
def test_conclusion_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:conclusion:abc123",
)
assert event.event_type == "conclusion"
def test_decomposition_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:decomposition:abc123",
)
assert event.event_type == "decomposition"
def test_finding_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:finding:abc123:0",
)
assert event.event_type == "finding"
def test_plan_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:plan:abc123",
)
assert event.event_type == "plan"
def test_step_result_event_type(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:agent:step-result:abc123:0",
)
assert event.event_type == "step-result"
def test_defaults(self):
event = ProvenanceEvent(
explain_id="urn:trustgraph:question:abc123",
)
assert event.entity is None
assert event.triples == []
assert event.explain_graph == ""
def test_with_triples(self):
raw = [{"s": {"t": "i", "i": "urn:x"}, "p": {"t": "i", "i": "urn:y"}, "o": {"t": "l", "v": "z"}}]
event = ProvenanceEvent(
explain_id="urn:trustgraph:question:abc123",
triples=raw,
)
assert len(event.triples) == 1
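The event_type cases above all follow one naming convention: the type is the URN segment after `urn:trustgraph:`, with agent-scoped ids nesting one segment deeper under `agent:`. A plausible sketch of that derivation (an illustration of the convention the tests encode, not necessarily ProvenanceEvent's actual implementation):

```python
def event_type_from_explain_id(explain_id: str) -> str:
    # "urn:trustgraph:question:abc123"            -> "question"
    # "urn:trustgraph:agent:step-result:abc123:0" -> "step-result"
    parts = explain_id.split(":")
    if parts[:2] != ["urn", "trustgraph"] or len(parts) < 3:
        raise ValueError(f"unexpected explain_id: {explain_id}")
    # Agent-scoped ids carry the event type one segment deeper.
    if parts[2] == "agent" and len(parts) > 3:
        return parts[3]
    return parts[2]
```

Trailing segments (the opaque id and any iteration index) are ignored, so ids like `…:observation:abc123:1` still map to their event type.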
# --- Build ProvenanceEvent with entity parsing ---
class TestBuildProvenanceEvent:
def _make_client(self):
"""Create a minimal WebSocketClient-like object with _build_provenance_event."""
from trustgraph.api.socket_client import WebSocketClient
# We can't instantiate WebSocketClient easily, so test the method logic directly
return None
def test_entity_parsed_from_wire_triples(self):
"""Test that wire-format triples are parsed into an ExplainEntity."""
from trustgraph.api.explainability import ExplainEntity
wire_triples = [
{
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
"p": {"t": "i", "i": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"},
"o": {"t": "i", "i": "https://trustgraph.ai/ns/GraphRagQuestion"},
},
{
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/query"},
"o": {"t": "l", "v": "What is the internet?"},
},
]
# Parse triples the same way _build_provenance_event does
parsed = []
for t in wire_triples:
s = t.get("s", {}).get("i", "")
p = t.get("p", {}).get("i", "")
o_term = t.get("o", {})
if o_term.get("t") == "i":
o = o_term.get("i", "")
else:
o = o_term.get("v", "")
parsed.append((s, p, o))
entity = ExplainEntity.from_triples(
"urn:trustgraph:question:abc123", parsed
)
assert entity.entity_type == "question"
assert entity.query == "What is the internet?"
assert entity.question_type == "graph-rag"


@@ -25,7 +25,7 @@ from trustgraph.schema import (
class TestGraphRagResponseTranslator:
"""Test GraphRagResponseTranslator streaming behavior"""
-def test_from_pulsar_with_empty_response(self):
+def test_encode_with_empty_response(self):
"""Test that empty response strings are preserved"""
# Arrange
translator = GraphRagResponseTranslator()
@@ -36,14 +36,14 @@ class TestGraphRagResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert - Empty string should be included in result
assert "response" in result
assert result["response"] == ""
assert result["end_of_stream"] is True
-def test_from_pulsar_with_non_empty_response(self):
+def test_encode_with_non_empty_response(self):
"""Test that non-empty responses work correctly"""
# Arrange
translator = GraphRagResponseTranslator()
@@ -54,13 +54,13 @@ class TestGraphRagResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert result["response"] == "Some text"
assert result["end_of_stream"] is False
-def test_from_pulsar_with_none_response(self):
+def test_encode_with_none_response(self):
"""Test that None response is handled correctly"""
# Arrange
translator = GraphRagResponseTranslator()
@@ -71,14 +71,14 @@ class TestGraphRagResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert - None should not be included
assert "response" not in result
assert result["end_of_stream"] is True
-def test_from_response_with_completion_returns_correct_flag(self):
-"""Test that from_response_with_completion returns correct is_final flag"""
+def test_encode_with_completion_returns_correct_flag(self):
+"""Test that encode_with_completion returns correct is_final flag"""
# Arrange
translator = GraphRagResponseTranslator()
@@ -90,7 +90,7 @@ class TestGraphRagResponseTranslator:
)
# Act
-result, is_final = translator.from_response_with_completion(response_chunk)
+result, is_final = translator.encode_with_completion(response_chunk)
# Assert
assert is_final is False
@@ -105,7 +105,7 @@ class TestGraphRagResponseTranslator:
)
# Act
-result, is_final = translator.from_response_with_completion(final_response)
+result, is_final = translator.encode_with_completion(final_response)
# Assert - is_final is based on end_of_session, not end_of_stream
assert is_final is True
@@ -116,7 +116,7 @@
class TestDocumentRagResponseTranslator:
"""Test DocumentRagResponseTranslator streaming behavior"""
-def test_from_pulsar_with_empty_response(self):
+def test_encode_with_empty_response(self):
"""Test that empty response strings are preserved"""
# Arrange
translator = DocumentRagResponseTranslator()
@@ -127,14 +127,14 @@ class TestDocumentRagResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert "response" in result
assert result["response"] == ""
assert result["end_of_stream"] is True
-def test_from_pulsar_with_non_empty_response(self):
+def test_encode_with_non_empty_response(self):
"""Test that non-empty responses work correctly"""
# Arrange
translator = DocumentRagResponseTranslator()
@@ -145,7 +145,7 @@ class TestDocumentRagResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert result["response"] == "Document content"
@@ -155,7 +155,7 @@
class TestPromptResponseTranslator:
"""Test PromptResponseTranslator streaming behavior"""
-def test_from_pulsar_with_empty_text(self):
+def test_encode_with_empty_text(self):
"""Test that empty text strings are preserved"""
# Arrange
translator = PromptResponseTranslator()
@@ -167,14 +167,14 @@ class TestPromptResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert "text" in result
assert result["text"] == ""
assert result["end_of_stream"] is True
-def test_from_pulsar_with_non_empty_text(self):
+def test_encode_with_non_empty_text(self):
"""Test that non-empty text works correctly"""
# Arrange
translator = PromptResponseTranslator()
@@ -186,13 +186,13 @@ class TestPromptResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert result["text"] == "Some prompt response"
assert result["end_of_stream"] is False
-def test_from_pulsar_with_none_text(self):
+def test_encode_with_none_text(self):
"""Test that None text is handled correctly"""
# Arrange
translator = PromptResponseTranslator()
@@ -204,14 +204,14 @@ class TestPromptResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert "text" not in result
assert "object" in result
assert result["end_of_stream"] is True
-def test_from_pulsar_includes_end_of_stream(self):
+def test_encode_includes_end_of_stream(self):
"""Test that end_of_stream flag is always included"""
# Arrange
translator = PromptResponseTranslator()
@@ -225,7 +225,7 @@ class TestPromptResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert "end_of_stream" in result
@@ -235,7 +235,7 @@
class TestTextCompletionResponseTranslator:
"""Test TextCompletionResponseTranslator streaming behavior"""
-def test_from_pulsar_always_includes_response(self):
+def test_encode_always_includes_response(self):
"""Test that response field is always included, even if empty"""
# Arrange
translator = TextCompletionResponseTranslator()
@@ -249,13 +249,13 @@ class TestTextCompletionResponseTranslator:
)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert - Response should always be present
assert "response" in result
assert result["response"] == ""
-def test_from_response_with_completion_with_empty_final(self):
+def test_encode_with_completion_with_empty_final(self):
"""Test that empty final response is handled correctly"""
# Arrange
translator = TextCompletionResponseTranslator()
@@ -269,7 +269,7 @@ class TestTextCompletionResponseTranslator:
)
# Act
-result, is_final = translator.from_response_with_completion(response)
+result, is_final = translator.encode_with_completion(response)
# Assert
assert is_final is True
@@ -297,7 +297,7 @@ class TestStreamingProtocolCompliance:
response = response_class(**kwargs)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
@@ -320,7 +320,7 @@ class TestStreamingProtocolCompliance:
response = response_class(**kwargs)
# Act
-result = translator.from_pulsar(response)
+result = translator.encode(response)
# Assert
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"


@@ -0,0 +1,54 @@
"""
Unit tests for text document gateway translation compatibility.
"""
import base64
from trustgraph.messaging.translators.document_loading import TextDocumentTranslator
class TestTextDocumentTranslator:
def test_decode_decodes_base64_text(self):
translator = TextDocumentTranslator()
payload = "Cancer survival: 2.74× higher hazard ratio"
msg = translator.decode(
{
"id": "doc-1",
"user": "alice",
"collection": "research",
"charset": "utf-8",
"text": base64.b64encode(payload.encode("utf-8")).decode("ascii"),
}
)
assert msg.metadata.id == "doc-1"
assert msg.metadata.user == "alice"
assert msg.metadata.collection == "research"
assert msg.text == payload.encode("utf-8")
def test_decode_accepts_raw_utf8_text(self):
translator = TextDocumentTranslator()
payload = "Cancer survival: 2.74× higher hazard ratio"
msg = translator.decode(
{
"charset": "utf-8",
"text": payload,
}
)
assert msg.text == payload.encode("utf-8")
def test_decode_falls_back_to_raw_non_base64_ascii(self):
translator = TextDocumentTranslator()
payload = "plain-text payload"
msg = translator.decode(
{
"charset": "utf-8",
"text": payload,
}
)
assert msg.text == payload.encode("utf-8")
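The three cases above imply a decode strategy: attempt strict base64 first, and fall back to treating the field as raw text when it does not validate. A sketch of that fallback (a hypothetical helper; the real TextDocumentTranslator may differ, and note that raw text which happens to be valid base64 is inherently ambiguous under this scheme — the tests only use raw strings containing characters outside the base64 alphabet):

```python
import base64

def decode_text_field(text: str, charset: str = "utf-8") -> bytes:
    # Strict validation rejects characters outside the base64 alphabet,
    # so ordinary prose (spaces, punctuation, non-ASCII) falls through
    # to being encoded as raw bytes in the declared charset.
    try:
        return base64.b64decode(text, validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
        return text.encode(charset)
```

`"plain-text payload"` contains a space and a hyphen, neither of which is in the base64 alphabet, so it takes the raw-text branch, while a properly encoded field round-trips through `b64decode`.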


@@ -10,16 +10,19 @@ from trustgraph.schema import Triple, Term, IRI, LITERAL
from trustgraph.provenance.agent import (
agent_session_triples,
agent_iteration_triples,
agent_observation_triples,
agent_final_triples,
agent_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
-PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
-PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
-TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
+PROV_ENTITY, PROV_WAS_DERIVED_FROM,
+PROV_STARTED_AT_TIME,
+TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_TOOL_USE, TG_SYNTHESIS,
TG_AGENT_QUESTION,
)
@@ -63,7 +66,7 @@ class TestAgentSessionTriples:
triples = agent_session_triples(
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
)
-assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
+assert has_type(triples, self.SESSION_URI, PROV_ENTITY)
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
@ -103,6 +106,25 @@ class TestAgentSessionTriples:
)
assert len(triples) == 6
def test_session_parent_uri(self):
"""Subagent sessions derive from a parent entity (e.g. Decomposition)."""
parent = "urn:trustgraph:agent:parent/decompose"
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z",
parent_uri=parent,
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
assert derived is not None
assert derived.o.iri == parent
def test_session_no_parent_uri(self):
"""Top-level sessions have no wasDerivedFrom."""
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
assert derived is None
# ---------------------------------------------------------------------------
# agent_iteration_triples
@@ -121,19 +143,17 @@ class TestAgentIterationTriples:
)
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
assert has_type(triples, self.ITER_URI, TG_TOOL_USE)
def test_first_iteration_generated_by_question(self):
"""First iteration uses wasGeneratedBy to link to question activity."""
def test_first_iteration_derived_from_question(self):
"""First iteration uses wasDerivedFrom to link to question entity."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
# Should NOT have wasDerivedFrom
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is None
assert derived is not None
assert derived.o.iri == self.SESSION_URI
def test_subsequent_iteration_derived_from_previous(self):
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
@@ -144,9 +164,6 @@ class TestAgentIterationTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None
assert derived.o.iri == self.PREV_URI
# Should NOT have wasGeneratedBy
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is None
def test_iteration_label_includes_action(self):
triples = agent_iteration_triples(
@@ -174,40 +191,24 @@ class TestAgentIterationTriples:
# Thought has correct types
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
# Thought was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Thought was derived from iteration
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, thought_uri)
assert derived is not None
assert derived.o.iri == self.ITER_URI
# Thought has document reference
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
assert doc is not None
assert doc.o.iri == thought_doc
def test_iteration_observation_sub_entity(self):
"""Observation is a sub-entity with Reflection and Observation types."""
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
obs_doc = "urn:doc:obs-1"
def test_iteration_no_observation_sub_entity(self):
"""Iteration no longer embeds observation — it's a separate entity."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
observation_uri=obs_uri,
observation_document_id=obs_doc,
)
# Iteration links to observation sub-entity
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_link is not None
assert obs_link.o.iri == obs_uri
# Observation has correct types
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
# Observation was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Observation has document reference
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
assert doc is not None
assert doc.o.iri == obs_doc
# No TG_OBSERVATION predicate on the iteration
for t in triples:
assert "observation" not in t.p.iri.lower()
def test_iteration_action_recorded(self):
triples = agent_iteration_triples(
@@ -240,19 +241,17 @@ class TestAgentIterationTriples:
parsed = json.loads(arguments.o.value)
assert parsed == {}
def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples."""
def test_iteration_no_thought(self):
"""Minimal iteration with just action — no thought triples."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="noop",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert thought is None
assert obs is None
def test_iteration_chaining(self):
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
"""Both first and second iterations use wasDerivedFrom."""
iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2"
@@ -263,13 +262,62 @@ class TestAgentIterationTriples:
iter2_uri, previous_uri=iter1_uri, action="step2",
)
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
assert gen1.o.iri == self.SESSION_URI
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
assert derived1.o.iri == self.SESSION_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri
# ---------------------------------------------------------------------------
# agent_observation_triples
# ---------------------------------------------------------------------------
class TestAgentObservationTriples:
OBS_URI = "urn:trustgraph:agent:test-session/i1/observation"
ITER_URI = "urn:trustgraph:agent:test-session/i1"
def test_observation_types(self):
triples = agent_observation_triples(
self.OBS_URI, self.ITER_URI,
)
assert has_type(triples, self.OBS_URI, PROV_ENTITY)
assert has_type(triples, self.OBS_URI, TG_OBSERVATION_TYPE)
def test_observation_derived_from_iteration(self):
triples = agent_observation_triples(
self.OBS_URI, self.ITER_URI,
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.OBS_URI)
assert derived is not None
assert derived.o.iri == self.ITER_URI
def test_observation_label(self):
triples = agent_observation_triples(
self.OBS_URI, self.ITER_URI,
)
label = find_triple(triples, RDFS_LABEL, self.OBS_URI)
assert label is not None
assert label.o.value == "Observation"
def test_observation_document(self):
doc_id = "urn:doc:obs-1"
triples = agent_observation_triples(
self.OBS_URI, self.ITER_URI, document_id=doc_id,
)
doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI)
assert doc is not None
assert doc.o.iri == doc_id
def test_observation_no_document(self):
triples = agent_observation_triples(
self.OBS_URI, self.ITER_URI,
)
doc = find_triple(triples, TG_DOCUMENT, self.OBS_URI)
assert doc is None
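A minimal sketch of the builder these tests pin down. The `Term`/`Triple` dataclasses and every namespace IRI below are stand-ins for illustration, not TrustGraph's real schema:

```python
from dataclasses import dataclass

# Stand-in namespace IRIs (assumed values, for illustration only)
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
PROV_ENTITY = "http://www.w3.org/ns/prov#Entity"
PROV_WAS_DERIVED_FROM = "http://www.w3.org/ns/prov#wasDerivedFrom"
TG_OBSERVATION_TYPE = "urn:trustgraph:ns#Observation"  # placeholder
TG_DOCUMENT = "urn:trustgraph:ns#document"             # placeholder

@dataclass
class Term:
    iri: str = ""
    value: str = ""

@dataclass
class Triple:
    s: Term
    p: Term
    o: Term

def agent_observation_triples(obs_uri, iteration_uri, document_id=None):
    # An observation is a prov:Entity derived from its iteration; the
    # document link is optional, matching the tests above
    subject = Term(iri=obs_uri)
    triples = [
        Triple(subject, Term(iri=RDF_TYPE), Term(iri=PROV_ENTITY)),
        Triple(subject, Term(iri=RDF_TYPE), Term(iri=TG_OBSERVATION_TYPE)),
        Triple(subject, Term(iri=RDFS_LABEL), Term(value="Observation")),
        Triple(subject, Term(iri=PROV_WAS_DERIVED_FROM),
               Term(iri=iteration_uri)),
    ]
    if document_id:
        triples.append(
            Triple(subject, Term(iri=TG_DOCUMENT), Term(iri=document_id)))
    return triples
```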
# ---------------------------------------------------------------------------
# agent_final_triples
# ---------------------------------------------------------------------------
@@ -296,19 +344,15 @@ class TestAgentFinalTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None
assert derived.o.iri == self.PREV_URI
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is None
def test_final_generated_by_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasGeneratedBy."""
def test_final_derived_from_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasDerivedFrom to question."""
triples = agent_final_triples(
self.FINAL_URI, question_uri=self.SESSION_URI,
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is None
assert derived is not None
assert derived.o.iri == self.SESSION_URI
def test_final_label(self):
triples = agent_final_triples(
@@ -334,3 +378,59 @@ class TestAgentFinalTriples:
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is None
# ---------------------------------------------------------------------------
# agent_synthesis_triples
# ---------------------------------------------------------------------------
class TestAgentSynthesisTriples:
SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis"
FINDING_0 = "urn:trustgraph:agent:test-session/finding/0"
FINDING_1 = "urn:trustgraph:agent:test-session/finding/1"
FINDING_2 = "urn:trustgraph:agent:test-session/finding/2"
def test_synthesis_types(self):
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
assert has_type(triples, self.SYNTH_URI, PROV_ENTITY)
assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS)
assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE)
def test_synthesis_single_parent_string(self):
"""Single parent passed as string."""
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
assert len(derived) == 1
assert derived[0].o.iri == self.FINDING_0
def test_synthesis_multiple_parents(self):
"""Multiple parents for supervisor fan-in."""
parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2]
triples = agent_synthesis_triples(self.SYNTH_URI, parents)
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
assert len(derived) == 3
derived_uris = {t.o.iri for t in derived}
assert derived_uris == set(parents)
def test_synthesis_single_parent_as_list(self):
"""Single parent passed as list."""
triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0])
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
assert len(derived) == 1
assert derived[0].o.iri == self.FINDING_0
def test_synthesis_document(self):
triples = agent_synthesis_triples(
self.SYNTH_URI, self.FINDING_0,
document_id="urn:doc:synth",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI)
assert doc is not None
assert doc.o.iri == "urn:doc:synth"
def test_synthesis_label(self):
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI)
assert label is not None
assert label.o.value == "Synthesis"

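The single-string versus list handling the synthesis tests exercise amounts to a one-step normalisation before emitting one wasDerivedFrom triple per parent. A hypothetical helper (the real function's internals may differ):

```python
def normalize_parents(parents):
    """Accept one parent URI (str) or several (list/tuple) and always
    return a list, so supervisor fan-in and the single-parent case go
    through the same triple-emission loop. Sketch only."""
    if isinstance(parents, str):
        return [parents]
    return list(parents)
```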

@@ -16,6 +16,7 @@ from trustgraph.api.explainability import (
Synthesis,
Reflection,
Analysis,
Observation,
Conclusion,
parse_edge_selection_triples,
extract_term_value,
@@ -23,12 +24,12 @@ from trustgraph.api.explainability import (
ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM,
RDF_TYPE, RDFS_LABEL,
)
@@ -180,14 +181,30 @@ class TestExplainEntityFromTriples:
("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis)
assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}'
assert entity.thought == "urn:ref:thought-1"
assert entity.observation == "urn:ref:obs-1"
def test_observation(self):
triples = [
("urn:obs:1", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:obs:1", TG_DOCUMENT, "urn:doc:obs-content"),
]
entity = ExplainEntity.from_triples("urn:obs:1", triples)
assert isinstance(entity, Observation)
assert entity.document == "urn:doc:obs-content"
assert entity.entity_type == "observation"
def test_observation_no_document(self):
triples = [
("urn:obs:2", RDF_TYPE, TG_OBSERVATION_TYPE),
]
entity = ExplainEntity.from_triples("urn:obs:2", triples)
assert isinstance(entity, Observation)
assert entity.document == ""
def test_conclusion_with_document(self):
triples = [
@@ -541,3 +558,96 @@ class TestExplainabilityClientDetectSessionType:
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
class TestChainWalkerFollowsSubTraceTerminal:
"""Test that _follow_provenance_chain continues from a sub-trace's
Synthesis to find downstream entities like Observation."""
def test_observation_found_via_subtrace_synthesis(self):
"""
DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation
The walker should find Analysis, the sub-trace, then follow from
Synthesis to discover Observation.
"""
# Entity triples (s, p, o)
entity_data = {
"urn:agent:q": [
("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION),
("urn:agent:q", TG_QUERY, "test"),
],
"urn:agent:analysis": [
("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS),
("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"),
],
"urn:graphrag:q": [
("urn:graphrag:q", RDF_TYPE, TG_QUESTION),
("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:graphrag:q", TG_QUERY, "test"),
("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"),
],
"urn:graphrag:synth": [
("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS),
("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"),
],
"urn:agent:obs": [
("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"),
],
"urn:agent:conclusion": [
("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION),
("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"),
],
}
# Build a mock flow that answers triples queries
# Query by s= returns that entity's triples
# Query by p=wasDerivedFrom, o=X returns entities derived from X
def mock_triples_query(s=None, p=None, o=None, **kwargs):
if s and not p:
# Fetch entity triples
tuples = entity_data.get(s, [])
return _make_wire_triples(tuples)
elif p == PROV_WAS_DERIVED_FROM and o:
# Find entities derived from o
results = []
for uri, tuples in entity_data.items():
for _, pred, obj in tuples:
if pred == PROV_WAS_DERIVED_FROM and obj == o:
results.append((uri, pred, obj))
return _make_wire_triples(results)
return []
mock_flow = MagicMock()
mock_flow.triples_query.side_effect = mock_triples_query
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
# Mock fetch_graphrag_trace to return a trace with a synthesis
synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis")
client.fetch_graphrag_trace = MagicMock(return_value={
"question": Question(uri="urn:graphrag:q", entity_type="question",
question_type="graph-rag"),
"synthesis": synth_entity,
})
trace = client.fetch_agent_trace(
"urn:agent:q",
graph="urn:graph:retrieval",
)
# Should have found all steps
step_types = [
type(s).__name__ if not isinstance(s, dict) else s.get("type")
for s in trace["steps"]
]
assert "Analysis" in step_types, f"Missing Analysis in {step_types}"
assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}"
assert "Observation" in step_types, f"Missing Observation in {step_types}"
assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}"
# Observation should come after the sub-trace
subtrace_idx = step_types.index("sub-trace")
obs_idx = step_types.index("Observation")
assert obs_idx > subtrace_idx, "Observation should appear after sub-trace"
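The chain-walking idea this test exercises — follow prov:wasDerivedFrom edges forward from a known node, including past a sub-trace's terminal — reduces to a breadth-first walk over the derivation graph. A toy sketch over an in-memory index (names are hypothetical, not the client's real method):

```python
def follow_derivation_chain(start_uri, derived_index):
    # derived_index maps a URI to the URIs derived from it (i.e. the
    # reverse direction of wasDerivedFrom); walk breadth-first,
    # recording discovery order and guarding against cycles with `seen`
    order = []
    frontier = [start_uri]
    seen = {start_uri}
    while frontier:
        uri = frontier.pop(0)
        for child in derived_index.get(uri, []):
            if child not in seen:
                seen.add(child)
                order.append(child)
                frontier.append(child)
    return order
```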


@@ -500,7 +500,7 @@ class TestQuestionTriples:
def test_question_types(self):
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, PROV_ENTITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
@@ -543,11 +543,11 @@ class TestGroundingTriples:
assert has_type(triples, self.GND_URI, PROV_ENTITY)
assert has_type(triples, self.GND_URI, TG_GROUNDING)
def test_grounding_generated_by_question(self):
def test_grounding_derived_from_question(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
assert gen is not None
assert gen.o.iri == self.Q_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.GND_URI)
assert derived is not None
assert derived.o.iri == self.Q_URI
def test_grounding_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
@@ -730,7 +730,7 @@ class TestDocRagQuestionTriples:
def test_docrag_question_types(self):
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, PROV_ENTITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)


@@ -0,0 +1,164 @@
"""
Tests for queue naming and topic mapping.
"""
import pytest
import argparse
from trustgraph.schema.core.topic import queue
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
from trustgraph.base.pulsar_backend import PulsarBackend
class TestQueueFunction:
def test_flow_default(self):
assert queue('text-completion-request') == 'flow:tg:text-completion-request'
def test_request_class(self):
assert queue('config', cls='request') == 'request:tg:config'
def test_response_class(self):
assert queue('config', cls='response') == 'response:tg:config'
def test_state_class(self):
assert queue('config', cls='state') == 'state:tg:config'
def test_custom_topicspace(self):
assert queue('config', cls='request', topicspace='prod') == 'request:prod:config'
def test_default_class_is_flow(self):
result = queue('something')
assert result.startswith('flow:')
class TestPulsarMapTopic:
@pytest.fixture
def backend(self):
"""Create a PulsarBackend without connecting."""
b = object.__new__(PulsarBackend)
return b
def test_flow_maps_to_persistent(self, backend):
assert backend.map_topic('flow:tg:text-completion-request') == \
'persistent://tg/flow/text-completion-request'
def test_state_maps_to_persistent(self, backend):
assert backend.map_topic('state:tg:config') == \
'persistent://tg/state/config'
def test_request_maps_to_non_persistent(self, backend):
assert backend.map_topic('request:tg:config') == \
'non-persistent://tg/request/config'
def test_response_maps_to_non_persistent(self, backend):
assert backend.map_topic('response:tg:librarian') == \
'non-persistent://tg/response/librarian'
def test_passthrough_pulsar_uri(self, backend):
uri = 'persistent://tg/flow/something'
assert backend.map_topic(uri) == uri
def test_invalid_format_raises(self, backend):
with pytest.raises(ValueError, match="Invalid queue format"):
backend.map_topic('bad-format')
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid queue class"):
backend.map_topic('unknown:tg:topic')
def test_custom_topicspace(self, backend):
assert backend.map_topic('flow:prod:my-queue') == \
'persistent://prod/flow/my-queue'
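The naming scheme and Pulsar mapping these tests describe can be sketched in a few lines. This is a hypothetical re-implementation consistent with the assertions above, not the library code:

```python
# flow/state survive restarts; request/response are transient RPC traffic
PERSISTENT_CLASSES = {"flow", "state"}
TRANSIENT_CLASSES = {"request", "response"}

def queue(name: str, cls: str = "flow", topicspace: str = "tg") -> str:
    # Queue names are "<class>:<topicspace>:<name>"; the class defaults
    # to the persistent "flow" class
    return f"{cls}:{topicspace}:{name}"

def map_topic(q: str) -> str:
    # Fully-qualified Pulsar URIs pass through untouched
    if q.startswith("persistent://") or q.startswith("non-persistent://"):
        return q
    parts = q.split(":", 2)
    if len(parts) != 3:
        raise ValueError(f"Invalid queue format: {q}")
    cls, space, name = parts
    if cls in PERSISTENT_CLASSES:
        return f"persistent://{space}/{cls}/{name}"
    if cls in TRANSIENT_CLASSES:
        return f"non-persistent://{space}/{cls}/{name}"
    raise ValueError(f"Invalid queue class: {cls}")
```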
class TestGetPubsubDispatch:
def test_unknown_backend_raises(self):
with pytest.raises(ValueError, match="Unknown pub/sub backend"):
get_pubsub(pubsub_backend='redis')
class TestAddPubsubArgs:
def test_standalone_defaults_to_localhost(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)
args = parser.parse_args([])
assert args.pulsar_host == 'pulsar://localhost:6650'
assert args.pulsar_listener == 'localhost'
def test_non_standalone_defaults_to_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=False)
args = parser.parse_args([])
assert 'pulsar:6650' in args.pulsar_host
assert args.pulsar_listener is None
def test_cli_override_respected(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)
args = parser.parse_args(['--pulsar-host', 'pulsar://custom:6650'])
assert args.pulsar_host == 'pulsar://custom:6650'
def test_pubsub_backend_default(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.pubsub_backend == 'pulsar'
class TestAddPubsubArgsRabbitMQ:
def test_rabbitmq_args_present(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([
'--pubsub-backend', 'rabbitmq',
'--rabbitmq-host', 'myhost',
'--rabbitmq-port', '5673',
])
assert args.pubsub_backend == 'rabbitmq'
assert args.rabbitmq_host == 'myhost'
assert args.rabbitmq_port == 5673
def test_rabbitmq_defaults_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.rabbitmq_host == 'rabbitmq'
assert args.rabbitmq_port == 5672
assert args.rabbitmq_username == 'guest'
assert args.rabbitmq_password == 'guest'
assert args.rabbitmq_vhost == '/'
def test_rabbitmq_standalone_defaults_to_localhost(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)
args = parser.parse_args([])
assert args.rabbitmq_host == 'localhost'
class TestQueueDefinitions:
"""Verify the actual queue constants produce correct names."""
def test_config_request(self):
from trustgraph.schema.services.config import config_request_queue
assert config_request_queue == 'request:tg:config'
def test_config_response(self):
from trustgraph.schema.services.config import config_response_queue
assert config_response_queue == 'response:tg:config'
def test_config_push(self):
from trustgraph.schema.services.config import config_push_queue
assert config_push_queue == 'flow:tg:config'
def test_librarian_request(self):
from trustgraph.schema.services.library import librarian_request_queue
assert librarian_request_queue == 'request:tg:librarian'
def test_knowledge_request(self):
from trustgraph.schema.knowledge.knowledge import knowledge_request_queue
assert knowledge_request_queue == 'request:tg:knowledge'


@@ -0,0 +1,107 @@
"""
Unit tests for RabbitMQ backend queue name mapping and factory dispatch.
Does not require a running RabbitMQ instance.
"""
import pytest
import argparse
pika = pytest.importorskip("pika", reason="pika not installed")
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestRabbitMQMapQueueName:
@pytest.fixture
def backend(self):
b = object.__new__(RabbitMQBackend)
return b
def test_flow_is_durable(self, backend):
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
assert durable is True
assert name == 'tg.flow.text-completion-request'
def test_state_is_durable(self, backend):
name, durable = backend.map_queue_name('state:tg:config')
assert durable is True
assert name == 'tg.state.config'
def test_request_is_not_durable(self, backend):
name, durable = backend.map_queue_name('request:tg:config')
assert durable is False
assert name == 'tg.request.config'
def test_response_is_not_durable(self, backend):
name, durable = backend.map_queue_name('response:tg:librarian')
assert durable is False
assert name == 'tg.response.librarian'
def test_custom_topicspace(self, backend):
name, durable = backend.map_queue_name('flow:prod:my-queue')
assert name == 'prod.flow.my-queue'
assert durable is True
def test_no_colon_defaults_to_flow(self, backend):
name, durable = backend.map_queue_name('simple-queue')
assert name == 'tg.simple-queue'
assert durable is False
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid queue class"):
backend.map_queue_name('unknown:tg:topic')
def test_flow_with_flow_suffix(self, backend):
"""Queue names with flow suffix (e.g. :default) are preserved."""
name, durable = backend.map_queue_name('request:tg:prompt:default')
assert name == 'tg.request.prompt:default'
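The RabbitMQ-side mapping follows the same class/topicspace split but flattens to dotted queue names plus a durability flag. A sketch consistent with the tests above (assumed shape, not the backend's actual code):

```python
DURABLE_BY_CLASS = {"flow": True, "state": True,
                    "request": False, "response": False}

def map_queue_name(q: str):
    # Bare names (no colon) fall into the default "tg" topicspace and
    # are treated as non-durable, per the tests above
    if ":" not in q:
        return f"tg.{q}", False
    parts = q.split(":", 2)
    if len(parts) != 3:
        raise ValueError(f"Invalid queue format: {q}")
    cls, space, name = parts
    if cls not in DURABLE_BY_CLASS:
        raise ValueError(f"Invalid queue class: {cls}")
    # Splitting on only the first two colons preserves flow suffixes
    # such as "prompt:default" inside the name component
    return f"{space}.{cls}.{name}", DURABLE_BY_CLASS[cls]
```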
class TestGetPubsubRabbitMQ:
def test_factory_creates_rabbitmq_backend(self):
backend = get_pubsub(pubsub_backend='rabbitmq')
assert isinstance(backend, RabbitMQBackend)
def test_factory_passes_config(self):
backend = get_pubsub(
pubsub_backend='rabbitmq',
rabbitmq_host='myhost',
rabbitmq_port=5673,
rabbitmq_username='user',
rabbitmq_password='pass',
rabbitmq_vhost='/test',
)
assert isinstance(backend, RabbitMQBackend)
# Verify connection params were set
params = backend._connection_params
assert params.host == 'myhost'
assert params.port == 5673
assert params.virtual_host == '/test'
class TestAddPubsubArgsRabbitMQ:
def test_rabbitmq_args_present(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([
'--pubsub-backend', 'rabbitmq',
'--rabbitmq-host', 'myhost',
'--rabbitmq-port', '5673',
])
assert args.pubsub_backend == 'rabbitmq'
assert args.rabbitmq_host == 'myhost'
assert args.rabbitmq_port == 5673
def test_rabbitmq_defaults_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.rabbitmq_host == 'rabbitmq'
assert args.rabbitmq_port == 5672
assert args.rabbitmq_username == 'guest'
assert args.rabbitmq_password == 'guest'
assert args.rabbitmq_vhost == '/'


@@ -0,0 +1,424 @@
"""
Tests for SPARQL FILTER expression evaluator.
"""
import pytest
from trustgraph.schema import Term, IRI, LITERAL, BLANK
from trustgraph.query.sparql.expressions import (
evaluate_expression, _effective_boolean, _to_string, _to_numeric,
_comparable_value,
)
# --- Helpers ---
def iri(v):
return Term(type=IRI, iri=v)
def lit(v, datatype="", language=""):
return Term(type=LITERAL, value=v, datatype=datatype, language=language)
def blank(v):
return Term(type=BLANK, id=v)
XSD = "http://www.w3.org/2001/XMLSchema#"
class TestEvaluateExpression:
"""Test expression evaluation with rdflib algebra nodes."""
def test_variable_bound(self):
from rdflib.term import Variable
result = evaluate_expression(Variable("x"), {"x": lit("hello")})
assert result.value == "hello"
def test_variable_unbound(self):
from rdflib.term import Variable
result = evaluate_expression(Variable("x"), {})
assert result is None
def test_uriref_constant(self):
from rdflib import URIRef
result = evaluate_expression(
URIRef("http://example.com/a"), {}
)
assert result.type == IRI
assert result.iri == "http://example.com/a"
def test_literal_constant(self):
from rdflib import Literal
result = evaluate_expression(Literal("hello"), {})
assert result.type == LITERAL
assert result.value == "hello"
def test_boolean_constant(self):
assert evaluate_expression(True, {}) is True
assert evaluate_expression(False, {}) is False
def test_numeric_constant(self):
assert evaluate_expression(42, {}) == 42
assert evaluate_expression(3.14, {}) == 3.14
def test_none_returns_true(self):
assert evaluate_expression(None, {}) is True
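FILTER truthiness in SPARQL follows the spec's effective boolean value (EBV) rules. A simplified sketch over plain Python values — the real `_effective_boolean` presumably operates on `Term` objects and typed literals, so this is an assumption-laden illustration of the rules, not its implementation:

```python
def effective_boolean(value):
    """Simplified EBV: booleans pass through, numbers are true when
    non-zero (and not NaN), strings when non-empty; unbound (None) and
    non-literal terms, which are type errors in SPARQL, are modelled
    here as False."""
    if value is None:
        return False
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        # NaN is false under the spec; NaN != NaN catches it
        return value != 0 and value == value
    if isinstance(value, str):
        return len(value) > 0
    return False
```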
class TestRelationalExpressions:
"""Test comparison operators via CompValue nodes."""
def _make_relational(self, left, op, right):
from rdflib.plugins.sparql.parserutils import CompValue
return CompValue("RelationalExpression",
expr=left, op=op, other=right)
def test_equal_literals(self):
from rdflib import Literal
expr = self._make_relational(Literal("a"), "=", Literal("a"))
assert evaluate_expression(expr, {}) is True
def test_not_equal_literals(self):
from rdflib import Literal
expr = self._make_relational(Literal("a"), "!=", Literal("b"))
assert evaluate_expression(expr, {}) is True
def test_less_than(self):
from rdflib import Literal
expr = self._make_relational(Literal("a"), "<", Literal("b"))
assert evaluate_expression(expr, {}) is True
def test_greater_than(self):
from rdflib import Literal
expr = self._make_relational(Literal("b"), ">", Literal("a"))
assert evaluate_expression(expr, {}) is True
def test_equal_with_variables(self):
from rdflib.term import Variable
expr = self._make_relational(Variable("x"), "=", Variable("y"))
sol = {"x": lit("same"), "y": lit("same")}
assert evaluate_expression(expr, sol) is True
def test_unequal_with_variables(self):
from rdflib.term import Variable
expr = self._make_relational(Variable("x"), "=", Variable("y"))
sol = {"x": lit("one"), "y": lit("two")}
assert evaluate_expression(expr, sol) is False
def test_none_operand_returns_false(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_relational(Variable("x"), "=", Literal("a"))
assert evaluate_expression(expr, {}) is False
class TestLogicalExpressions:
def _make_and(self, exprs):
from rdflib.plugins.sparql.parserutils import CompValue
return CompValue("ConditionalAndExpression",
expr=exprs[0], other=exprs[1:])
def _make_or(self, exprs):
from rdflib.plugins.sparql.parserutils import CompValue
return CompValue("ConditionalOrExpression",
expr=exprs[0], other=exprs[1:])
def _make_not(self, expr):
from rdflib.plugins.sparql.parserutils import CompValue
return CompValue("UnaryNot", expr=expr)
def test_and_true_true(self):
result = evaluate_expression(self._make_and([True, True]), {})
assert result is True
def test_and_true_false(self):
result = evaluate_expression(self._make_and([True, False]), {})
assert result is False
def test_or_false_true(self):
result = evaluate_expression(self._make_or([False, True]), {})
assert result is True
def test_or_false_false(self):
result = evaluate_expression(self._make_or([False, False]), {})
assert result is False
def test_not_true(self):
result = evaluate_expression(self._make_not(True), {})
assert result is False
def test_not_false(self):
result = evaluate_expression(self._make_not(False), {})
assert result is True
class TestBuiltinFunctions:
def _make_builtin(self, name, **kwargs):
from rdflib.plugins.sparql.parserutils import CompValue
return CompValue(f"Builtin_{name}", **kwargs)
def test_bound_true(self):
from rdflib.term import Variable
expr = self._make_builtin("BOUND", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("hi")}) is True
def test_bound_false(self):
from rdflib.term import Variable
expr = self._make_builtin("BOUND", arg=Variable("x"))
assert evaluate_expression(expr, {}) is False
def test_isiri_true(self):
from rdflib.term import Variable
expr = self._make_builtin("isIRI", arg=Variable("x"))
assert evaluate_expression(expr, {"x": iri("http://x")}) is True
def test_isiri_false(self):
from rdflib.term import Variable
expr = self._make_builtin("isIRI", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("hello")}) is False
def test_isliteral_true(self):
from rdflib.term import Variable
expr = self._make_builtin("isLITERAL", arg=Variable("x"))
assert evaluate_expression(expr, {"x": lit("hello")}) is True
def test_isliteral_false(self):
from rdflib.term import Variable
expr = self._make_builtin("isLITERAL", arg=Variable("x"))
assert evaluate_expression(expr, {"x": iri("http://x")}) is False
def test_isblank_true(self):
from rdflib.term import Variable
expr = self._make_builtin("isBLANK", arg=Variable("x"))
assert evaluate_expression(expr, {"x": blank("b1")}) is True
def test_isblank_false(self):
from rdflib.term import Variable
expr = self._make_builtin("isBLANK", arg=Variable("x"))
assert evaluate_expression(expr, {"x": iri("http://x")}) is False
def test_str(self):
from rdflib.term import Variable
expr = self._make_builtin("STR", arg=Variable("x"))
result = evaluate_expression(expr, {"x": iri("http://example.com/a")})
assert result.type == LITERAL
assert result.value == "http://example.com/a"
def test_lang(self):
from rdflib.term import Variable
expr = self._make_builtin("LANG", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("hello", language="en")}
)
assert result.value == "en"
def test_lang_no_tag(self):
from rdflib.term import Variable
expr = self._make_builtin("LANG", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.value == ""
def test_datatype(self):
from rdflib.term import Variable
expr = self._make_builtin("DATATYPE", arg=Variable("x"))
result = evaluate_expression(
expr, {"x": lit("42", datatype=XSD + "integer")}
)
assert result.type == IRI
assert result.iri == XSD + "integer"
def test_strlen(self):
from rdflib.term import Variable
expr = self._make_builtin("STRLEN", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result == 5
def test_ucase(self):
from rdflib.term import Variable
expr = self._make_builtin("UCASE", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("hello")})
assert result.value == "HELLO"
def test_lcase(self):
from rdflib.term import Variable
expr = self._make_builtin("LCASE", arg=Variable("x"))
result = evaluate_expression(expr, {"x": lit("HELLO")})
assert result.value == "hello"
def test_contains_true(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("CONTAINS",
arg1=Variable("x"), arg2=Literal("ell"))
assert evaluate_expression(expr, {"x": lit("hello")}) is True
def test_contains_false(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("CONTAINS",
arg1=Variable("x"), arg2=Literal("xyz"))
assert evaluate_expression(expr, {"x": lit("hello")}) is False
def test_strstarts_true(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRSTARTS",
arg1=Variable("x"), arg2=Literal("hel"))
assert evaluate_expression(expr, {"x": lit("hello")}) is True
def test_strends_true(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("STRENDS",
arg1=Variable("x"), arg2=Literal("llo"))
assert evaluate_expression(expr, {"x": lit("hello")}) is True
def test_regex_match(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REGEX",
text=Variable("x"),
pattern=Literal("^hel"),
flags=None)
assert evaluate_expression(expr, {"x": lit("hello")}) is True
def test_regex_case_insensitive(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REGEX",
text=Variable("x"),
pattern=Literal("HELLO"),
flags=Literal("i"))
assert evaluate_expression(expr, {"x": lit("hello")}) is True
def test_regex_no_match(self):
from rdflib.term import Variable
from rdflib import Literal
expr = self._make_builtin("REGEX",
text=Variable("x"),
pattern=Literal("^world"),
flags=None)
assert evaluate_expression(expr, {"x": lit("hello")}) is False
class TestEffectiveBoolean:
def test_true(self):
assert _effective_boolean(True) is True
def test_false(self):
assert _effective_boolean(False) is False
def test_none(self):
assert _effective_boolean(None) is False
def test_nonzero_int(self):
assert _effective_boolean(42) is True
def test_zero_int(self):
assert _effective_boolean(0) is False
def test_nonempty_string(self):
assert _effective_boolean("hello") is True
def test_empty_string(self):
assert _effective_boolean("") is False
def test_iri_term(self):
assert _effective_boolean(iri("http://x")) is True
def test_nonempty_literal(self):
assert _effective_boolean(lit("hello")) is True
def test_empty_literal(self):
assert _effective_boolean(lit("")) is False
def test_boolean_literal_true(self):
assert _effective_boolean(
lit("true", datatype=XSD + "boolean")
) is True
def test_boolean_literal_false(self):
assert _effective_boolean(
lit("false", datatype=XSD + "boolean")
) is False
def test_numeric_literal_nonzero(self):
assert _effective_boolean(
lit("42", datatype=XSD + "integer")
) is True
def test_numeric_literal_zero(self):
assert _effective_boolean(
lit("0", datatype=XSD + "integer")
) is False
class TestToString:
def test_none(self):
assert _to_string(None) == ""
def test_string(self):
assert _to_string("hello") == "hello"
def test_iri_term(self):
assert _to_string(iri("http://example.com")) == "http://example.com"
def test_literal_term(self):
assert _to_string(lit("hello")) == "hello"
def test_blank_term(self):
assert _to_string(blank("b1")) == "b1"
class TestToNumeric:
def test_none(self):
assert _to_numeric(None) is None
def test_int(self):
assert _to_numeric(42) == 42
def test_float(self):
assert _to_numeric(3.14) == 3.14
def test_integer_literal(self):
assert _to_numeric(lit("42")) == 42
def test_decimal_literal(self):
assert _to_numeric(lit("3.14")) == 3.14
def test_non_numeric_literal(self):
assert _to_numeric(lit("hello")) is None
def test_numeric_string(self):
assert _to_numeric("42") == 42
def test_non_numeric_string(self):
assert _to_numeric("abc") is None
class TestComparableValue:
def test_none(self):
assert _comparable_value(None) == (0, "")
def test_int(self):
assert _comparable_value(42) == (2, 42)
def test_iri(self):
assert _comparable_value(iri("http://x")) == (4, "http://x")
def test_literal(self):
assert _comparable_value(lit("hello")) == (3, "hello")
def test_numeric_literal(self):
assert _comparable_value(lit("42")) == (2, 42)
def test_ordering(self):
vals = [lit("b"), lit("a"), lit("c")]
sorted_vals = sorted(vals, key=_comparable_value)
assert sorted_vals[0].value == "a"
assert sorted_vals[1].value == "b"
assert sorted_vals[2].value == "c"
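For context, the effective-boolean-value (EBV) rules that `TestEffectiveBoolean` above pins down can be sketched in isolation. This is an illustrative stand-in only: `Lit` and the dispatch below are simplified assumptions, not the trustgraph `Term` schema or the real `_effective_boolean` implementation.

```python
# Sketch of SPARQL effective-boolean-value rules as exercised by the
# tests: booleans pass through, None is false, numbers/strings use
# truthiness, literals consult their datatype, and IRIs count as true.
from dataclasses import dataclass

XSD = "http://www.w3.org/2001/XMLSchema#"

@dataclass
class Lit:
    # Simplified stand-in for a literal Term (illustrative, not the schema)
    value: str
    datatype: str = ""

def effective_boolean(v):
    if v is None:
        return False
    if isinstance(v, bool):          # check bool before int: bool is an int
        return v
    if isinstance(v, (int, float)):
        return v != 0
    if isinstance(v, str):
        return v != ""
    if isinstance(v, Lit):
        if v.datatype == XSD + "boolean":
            return v.value == "true"
        if v.datatype == XSD + "integer":
            try:
                return int(v.value) != 0
            except ValueError:
                return False
        return v.value != ""         # plain literal: non-empty is true
    return True                      # IRIs and blank nodes count as true
```

The `isinstance(v, bool)` check must precede the numeric check, since Python's `bool` is a subclass of `int`.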


@@ -0,0 +1,205 @@
"""
Tests for the SPARQL parser module.
"""
import pytest
from trustgraph.query.sparql.parser import (
parse_sparql, ParseError, rdflib_term_to_term, term_to_rdflib,
)
from trustgraph.schema import Term, IRI, LITERAL, BLANK
class TestParseSparql:
"""Tests for parse_sparql function."""
def test_select_query_type(self):
parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
assert parsed.query_type == "select"
def test_select_variables(self):
parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
assert parsed.variables == ["s", "p", "o"]
def test_select_subset_variables(self):
parsed = parse_sparql("SELECT ?s ?o WHERE { ?s ?p ?o }")
assert parsed.variables == ["s", "o"]
def test_ask_query_type(self):
parsed = parse_sparql(
"ASK { <http://example.com/a> ?p ?o }"
)
assert parsed.query_type == "ask"
def test_ask_no_variables(self):
parsed = parse_sparql(
"ASK { <http://example.com/a> ?p ?o }"
)
assert parsed.variables == []
def test_construct_query_type(self):
parsed = parse_sparql(
"CONSTRUCT { ?s <http://example.com/knows> ?o } "
"WHERE { ?s <http://example.com/friendOf> ?o }"
)
assert parsed.query_type == "construct"
def test_describe_query_type(self):
parsed = parse_sparql(
"DESCRIBE <http://example.com/alice>"
)
assert parsed.query_type == "describe"
def test_select_with_limit(self):
parsed = parse_sparql(
"SELECT ?s WHERE { ?s ?p ?o } LIMIT 10"
)
assert parsed.query_type == "select"
assert parsed.variables == ["s"]
def test_select_with_distinct(self):
parsed = parse_sparql(
"SELECT DISTINCT ?s WHERE { ?s ?p ?o }"
)
assert parsed.query_type == "select"
assert parsed.variables == ["s"]
def test_select_with_filter(self):
parsed = parse_sparql(
'SELECT ?s ?label WHERE { '
' ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label . '
' FILTER(CONTAINS(STR(?label), "test")) '
'}'
)
assert parsed.query_type == "select"
assert parsed.variables == ["s", "label"]
def test_select_with_optional(self):
parsed = parse_sparql(
"SELECT ?s ?p ?o ?label WHERE { "
" ?s ?p ?o . "
" OPTIONAL { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
"}"
)
assert parsed.query_type == "select"
assert set(parsed.variables) == {"s", "p", "o", "label"}
def test_select_with_union(self):
parsed = parse_sparql(
"SELECT ?s ?label WHERE { "
" { ?s <http://example.com/name> ?label } "
" UNION "
" { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
"}"
)
assert parsed.query_type == "select"
def test_select_with_order_by(self):
parsed = parse_sparql(
"SELECT ?s ?label WHERE { ?s <http://www.w3.org/2000/01/rdf-schema#label> ?label } "
"ORDER BY ?label"
)
assert parsed.query_type == "select"
def test_select_with_group_by(self):
parsed = parse_sparql(
"SELECT ?p (COUNT(?o) AS ?count) WHERE { ?s ?p ?o } "
"GROUP BY ?p ORDER BY DESC(?count)"
)
assert parsed.query_type == "select"
def test_select_with_prefixes(self):
parsed = parse_sparql(
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
"SELECT ?s ?label WHERE { ?s rdfs:label ?label }"
)
assert parsed.query_type == "select"
assert parsed.variables == ["s", "label"]
def test_algebra_not_none(self):
parsed = parse_sparql("SELECT ?s WHERE { ?s ?p ?o }")
assert parsed.algebra is not None
def test_parse_error_invalid_sparql(self):
with pytest.raises(ParseError):
parse_sparql("NOT VALID SPARQL AT ALL")
def test_parse_error_incomplete_query(self):
with pytest.raises(ParseError):
parse_sparql("SELECT ?s WHERE {")
def test_parse_error_message(self):
with pytest.raises(ParseError, match="SPARQL parse error"):
parse_sparql("GIBBERISH")
class TestRdflibTermToTerm:
"""Tests for rdflib-to-Term conversion."""
def test_uriref_to_term(self):
from rdflib import URIRef
term = rdflib_term_to_term(URIRef("http://example.com/alice"))
assert term.type == IRI
assert term.iri == "http://example.com/alice"
def test_literal_to_term(self):
from rdflib import Literal
term = rdflib_term_to_term(Literal("hello"))
assert term.type == LITERAL
assert term.value == "hello"
def test_typed_literal_to_term(self):
from rdflib import Literal, URIRef
term = rdflib_term_to_term(
Literal("42", datatype=URIRef("http://www.w3.org/2001/XMLSchema#integer"))
)
assert term.type == LITERAL
assert term.value == "42"
assert term.datatype == "http://www.w3.org/2001/XMLSchema#integer"
def test_lang_literal_to_term(self):
from rdflib import Literal
term = rdflib_term_to_term(Literal("hello", lang="en"))
assert term.type == LITERAL
assert term.value == "hello"
assert term.language == "en"
def test_bnode_to_term(self):
from rdflib import BNode
term = rdflib_term_to_term(BNode("b1"))
assert term.type == BLANK
assert term.id == "b1"
class TestTermToRdflib:
"""Tests for Term-to-rdflib conversion."""
def test_iri_term_to_uriref(self):
from rdflib import URIRef
result = term_to_rdflib(Term(type=IRI, iri="http://example.com/x"))
assert isinstance(result, URIRef)
assert str(result) == "http://example.com/x"
def test_literal_term_to_literal(self):
from rdflib import Literal
result = term_to_rdflib(Term(type=LITERAL, value="hello"))
assert isinstance(result, Literal)
assert str(result) == "hello"
def test_typed_literal_roundtrip(self):
from rdflib import URIRef
original = Term(
type=LITERAL, value="42",
datatype="http://www.w3.org/2001/XMLSchema#integer"
)
rdflib_term = term_to_rdflib(original)
assert rdflib_term.datatype == URIRef("http://www.w3.org/2001/XMLSchema#integer")
def test_lang_literal_roundtrip(self):
original = Term(type=LITERAL, value="bonjour", language="fr")
rdflib_term = term_to_rdflib(original)
assert rdflib_term.language == "fr"
def test_blank_term_to_bnode(self):
from rdflib import BNode
result = term_to_rdflib(Term(type=BLANK, id="b1"))
assert isinstance(result, BNode)


@@ -0,0 +1,345 @@
"""
Tests for SPARQL solution sequence operations.
"""
import pytest
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.solutions import (
hash_join, left_join, union, project, distinct,
order_by, slice_solutions, _terms_equal, _compatible,
)
# --- Test helpers ---
def iri(v):
return Term(type=IRI, iri=v)
def lit(v):
return Term(type=LITERAL, value=v)
# --- Fixtures ---
@pytest.fixture
def alice():
return iri("http://example.com/alice")
@pytest.fixture
def bob():
return iri("http://example.com/bob")
@pytest.fixture
def carol():
return iri("http://example.com/carol")
@pytest.fixture
def knows():
return iri("http://example.com/knows")
@pytest.fixture
def name_alice():
return lit("Alice")
@pytest.fixture
def name_bob():
return lit("Bob")
class TestTermsEqual:
def test_equal_iris(self):
assert _terms_equal(iri("http://x.com/a"), iri("http://x.com/a"))
def test_unequal_iris(self):
assert not _terms_equal(iri("http://x.com/a"), iri("http://x.com/b"))
def test_equal_literals(self):
assert _terms_equal(lit("hello"), lit("hello"))
def test_unequal_literals(self):
assert not _terms_equal(lit("hello"), lit("world"))
def test_iri_vs_literal(self):
assert not _terms_equal(iri("hello"), lit("hello"))
def test_none_none(self):
assert _terms_equal(None, None)
def test_none_vs_term(self):
assert not _terms_equal(None, iri("http://x.com/a"))
class TestCompatible:
def test_no_shared_variables(self):
assert _compatible({"a": iri("http://x")}, {"b": iri("http://y")})
def test_shared_variable_same_value(self, alice):
assert _compatible({"s": alice, "x": lit("1")}, {"s": alice, "y": lit("2")})
def test_shared_variable_different_value(self, alice, bob):
assert not _compatible({"s": alice}, {"s": bob})
def test_empty_solutions(self):
assert _compatible({}, {})
def test_empty_vs_nonempty(self, alice):
assert _compatible({}, {"s": alice})
class TestHashJoin:
def test_join_on_shared_variable(self, alice, bob, name_alice, name_bob):
left = [
{"s": alice, "p": iri("http://example.com/knows"), "o": bob},
{"s": bob, "p": iri("http://example.com/knows"), "o": alice},
]
right = [
{"s": alice, "label": name_alice},
{"s": bob, "label": name_bob},
]
result = hash_join(left, right)
assert len(result) == 2
# Check that joined solutions have all variables
for sol in result:
assert "s" in sol
assert "p" in sol
assert "o" in sol
assert "label" in sol
def test_join_no_shared_variables_cross_product(self, alice, bob):
left = [{"a": alice}]
right = [{"b": bob}, {"b": alice}]
result = hash_join(left, right)
assert len(result) == 2
def test_join_no_matches(self, alice, bob):
left = [{"s": alice}]
right = [{"s": bob}]
result = hash_join(left, right)
assert len(result) == 0
def test_join_empty_left(self, alice):
result = hash_join([], [{"s": alice}])
assert len(result) == 0
def test_join_empty_right(self, alice):
result = hash_join([{"s": alice}], [])
assert len(result) == 0
def test_join_multiple_matches(self, alice, name_alice):
left = [
{"s": alice, "p": iri("http://e.com/a")},
{"s": alice, "p": iri("http://e.com/b")},
]
right = [{"s": alice, "label": name_alice}]
result = hash_join(left, right)
assert len(result) == 2
def test_join_preserves_values(self, alice, name_alice):
left = [{"s": alice, "x": lit("1")}]
right = [{"s": alice, "y": lit("2")}]
result = hash_join(left, right)
assert len(result) == 1
assert result[0]["x"].value == "1"
assert result[0]["y"].value == "2"
class TestLeftJoin:
def test_left_join_with_matches(self, alice, bob, name_alice):
left = [{"s": alice}, {"s": bob}]
right = [{"s": alice, "label": name_alice}]
result = left_join(left, right)
assert len(result) == 2
# Alice has label
alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"]
assert len(alice_sols) == 1
assert "label" in alice_sols[0]
# Bob preserved without label
bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"]
assert len(bob_sols) == 1
assert "label" not in bob_sols[0]
def test_left_join_no_matches(self, alice, bob):
left = [{"s": alice}]
right = [{"s": bob, "label": lit("Bob")}]
result = left_join(left, right)
assert len(result) == 1
assert result[0]["s"].iri == "http://example.com/alice"
assert "label" not in result[0]
def test_left_join_empty_right(self, alice):
left = [{"s": alice}]
result = left_join(left, [])
assert len(result) == 1
def test_left_join_empty_left(self):
result = left_join([], [{"s": iri("http://x")}])
assert len(result) == 0
def test_left_join_with_filter(self, alice, bob):
left = [{"s": alice}, {"s": bob}]
right = [
{"s": alice, "val": lit("yes")},
{"s": bob, "val": lit("no")},
]
# Filter: only keep joins where val == "yes"
result = left_join(
left, right,
filter_fn=lambda sol: sol.get("val") and sol["val"].value == "yes"
)
assert len(result) == 2
# Alice matches filter
alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"]
assert "val" in alice_sols[0]
assert alice_sols[0]["val"].value == "yes"
# Bob doesn't match filter, preserved without val
bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"]
assert "val" not in bob_sols[0]
class TestUnion:
def test_union_concatenates(self, alice, bob):
left = [{"s": alice}]
right = [{"s": bob}]
result = union(left, right)
assert len(result) == 2
def test_union_preserves_order(self, alice, bob):
left = [{"s": alice}]
right = [{"s": bob}]
result = union(left, right)
assert result[0]["s"].iri == "http://example.com/alice"
assert result[1]["s"].iri == "http://example.com/bob"
def test_union_empty_left(self, alice):
result = union([], [{"s": alice}])
assert len(result) == 1
def test_union_both_empty(self):
result = union([], [])
assert len(result) == 0
def test_union_allows_duplicates(self, alice):
result = union([{"s": alice}], [{"s": alice}])
assert len(result) == 2
class TestProject:
def test_project_keeps_selected(self, alice, name_alice):
solutions = [{"s": alice, "label": name_alice, "extra": lit("x")}]
result = project(solutions, ["s", "label"])
assert len(result) == 1
assert "s" in result[0]
assert "label" in result[0]
assert "extra" not in result[0]
def test_project_missing_variable(self, alice):
solutions = [{"s": alice}]
result = project(solutions, ["s", "missing"])
assert len(result) == 1
assert "s" in result[0]
assert "missing" not in result[0]
def test_project_empty(self):
result = project([], ["s"])
assert len(result) == 0
class TestDistinct:
def test_removes_duplicates(self, alice):
solutions = [{"s": alice}, {"s": alice}, {"s": alice}]
result = distinct(solutions)
assert len(result) == 1
def test_keeps_different(self, alice, bob):
solutions = [{"s": alice}, {"s": bob}]
result = distinct(solutions)
assert len(result) == 2
def test_empty(self):
result = distinct([])
assert len(result) == 0
def test_multi_variable_distinct(self, alice, bob):
solutions = [
{"s": alice, "o": bob},
{"s": alice, "o": bob},
{"s": alice, "o": alice},
]
result = distinct(solutions)
assert len(result) == 2
class TestOrderBy:
def test_order_by_ascending(self):
solutions = [
{"label": lit("Charlie")},
{"label": lit("Alice")},
{"label": lit("Bob")},
]
key_fns = [(lambda sol: sol.get("label"), True)]
result = order_by(solutions, key_fns)
assert result[0]["label"].value == "Alice"
assert result[1]["label"].value == "Bob"
assert result[2]["label"].value == "Charlie"
def test_order_by_descending(self):
solutions = [
{"label": lit("Alice")},
{"label": lit("Charlie")},
{"label": lit("Bob")},
]
key_fns = [(lambda sol: sol.get("label"), False)]
result = order_by(solutions, key_fns)
assert result[0]["label"].value == "Charlie"
assert result[1]["label"].value == "Bob"
assert result[2]["label"].value == "Alice"
def test_order_by_empty(self):
result = order_by([], [(lambda sol: sol.get("x"), True)])
assert len(result) == 0
def test_order_by_no_keys(self, alice):
solutions = [{"s": alice}]
result = order_by(solutions, [])
assert len(result) == 1
class TestSlice:
def test_limit(self, alice, bob, carol):
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
result = slice_solutions(solutions, limit=2)
assert len(result) == 2
def test_offset(self, alice, bob, carol):
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
result = slice_solutions(solutions, offset=1)
assert len(result) == 2
assert result[0]["s"].iri == "http://example.com/bob"
def test_offset_and_limit(self, alice, bob, carol):
solutions = [{"s": alice}, {"s": bob}, {"s": carol}]
result = slice_solutions(solutions, offset=1, limit=1)
assert len(result) == 1
assert result[0]["s"].iri == "http://example.com/bob"
def test_limit_zero(self, alice):
result = slice_solutions([{"s": alice}], limit=0)
assert len(result) == 0
def test_offset_beyond_length(self, alice):
result = slice_solutions([{"s": alice}], offset=10)
assert len(result) == 0
def test_no_slice(self, alice, bob):
solutions = [{"s": alice}, {"s": bob}]
result = slice_solutions(solutions)
assert len(result) == 2
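The join semantics these tests pin down can be sketched with plain dicts as solutions (values here are simple hashables rather than Term objects, and the joins are written as nested loops for clarity; the real `hash_join` presumably buckets rows on the shared variables):

```python
# Two solutions are compatible when every shared variable binds equally;
# joining compatible rows merges their bindings.
def compatible(a, b):
    return all(a[v] == b[v] for v in a.keys() & b.keys())

def join(left, right):
    # No shared variables => every pair is compatible => cross product.
    return [{**l, **r} for l in left for r in right if compatible(l, r)]

def left_join(left, right, filter_fn=None):
    out = []
    for l in left:
        matched = False
        for r in right:
            if compatible(l, r):
                merged = {**l, **r}
                if filter_fn is None or filter_fn(merged):
                    out.append(merged)
                    matched = True
        if not matched:
            out.append(dict(l))   # preserve left row without right bindings
    return out
```

This also explains the OPTIONAL-with-filter test: a right-side row that joins but fails the filter contributes nothing, so the left row survives unextended.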


@ -28,21 +28,21 @@ def triple_tx():
class TestTermTranslatorIri:
def test_iri_to_pulsar(self, term_tx):
def test_iri_decode(self, term_tx):
data = {"t": "i", "i": "http://example.org/Alice"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == IRI
assert term.iri == "http://example.org/Alice"
def test_iri_from_pulsar(self, term_tx):
def test_iri_encode(self, term_tx):
term = Term(type=IRI, iri="http://example.org/Bob")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire == {"t": "i", "i": "http://example.org/Bob"}
def test_iri_round_trip(self, term_tx):
original = Term(type=IRI, iri="http://example.org/round")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
@ -52,21 +52,21 @@ class TestTermTranslatorIri:
class TestTermTranslatorBlank:
def test_blank_to_pulsar(self, term_tx):
def test_blank_decode(self, term_tx):
data = {"t": "b", "d": "_:b42"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == BLANK
assert term.id == "_:b42"
def test_blank_from_pulsar(self, term_tx):
def test_blank_encode(self, term_tx):
term = Term(type=BLANK, id="_:node1")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire == {"t": "b", "d": "_:node1"}
def test_blank_round_trip(self, term_tx):
original = Term(type=BLANK, id="_:x")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
@ -76,29 +76,29 @@ class TestTermTranslatorBlank:
class TestTermTranslatorTypedLiteral:
def test_plain_literal_to_pulsar(self, term_tx):
def test_plain_literal_decode(self, term_tx):
data = {"t": "l", "v": "hello"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == LITERAL
assert term.value == "hello"
assert term.datatype == ""
assert term.language == ""
def test_xsd_integer_to_pulsar(self, term_tx):
def test_xsd_integer_decode(self, term_tx):
data = {
"t": "l", "v": "42",
"dt": "http://www.w3.org/2001/XMLSchema#integer",
}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.value == "42"
assert term.datatype.endswith("#integer")
def test_typed_literal_from_pulsar(self, term_tx):
def test_typed_literal_encode(self, term_tx):
term = Term(
type=LITERAL, value="3.14",
datatype="http://www.w3.org/2001/XMLSchema#double",
)
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire["t"] == "l"
assert wire["v"] == "3.14"
assert wire["dt"] == "http://www.w3.org/2001/XMLSchema#double"
@ -109,13 +109,13 @@ class TestTermTranslatorTypedLiteral:
type=LITERAL, value="true",
datatype="http://www.w3.org/2001/XMLSchema#boolean",
)
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
def test_plain_literal_omits_dt_and_ln(self, term_tx):
term = Term(type=LITERAL, value="x")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert "dt" not in wire
assert "ln" not in wire
@ -126,22 +126,22 @@ class TestTermTranslatorTypedLiteral:
class TestTermTranslatorLangLiteral:
def test_language_tag_to_pulsar(self, term_tx):
def test_language_tag_decode(self, term_tx):
data = {"t": "l", "v": "bonjour", "ln": "fr"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.value == "bonjour"
assert term.language == "fr"
def test_language_tag_from_pulsar(self, term_tx):
def test_language_tag_encode(self, term_tx):
term = Term(type=LITERAL, value="colour", language="en-GB")
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire["ln"] == "en-GB"
assert "dt" not in wire # No datatype
def test_language_tag_round_trip(self, term_tx):
original = Term(type=LITERAL, value="hola", language="es")
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored == original
@ -151,7 +151,7 @@ class TestTermTranslatorLangLiteral:
class TestTermTranslatorQuotedTriple:
def test_quoted_triple_to_pulsar(self, term_tx):
def test_quoted_triple_decode(self, term_tx):
data = {
"t": "t",
"tr": {
@ -160,20 +160,20 @@ class TestTermTranslatorQuotedTriple:
"o": {"t": "i", "i": "http://example.org/Bob"},
},
}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == TRIPLE
assert term.triple is not None
assert term.triple.s.iri == "http://example.org/Alice"
assert term.triple.o.iri == "http://example.org/Bob"
def test_quoted_triple_from_pulsar(self, term_tx):
def test_quoted_triple_encode(self, term_tx):
inner = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="val"),
)
term = Term(type=TRIPLE, triple=inner)
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire["t"] == "t"
assert "tr" in wire
assert wire["tr"]["s"]["i"] == "http://example.org/s"
@ -186,18 +186,18 @@ class TestTermTranslatorQuotedTriple:
o=Term(type=LITERAL, value="C", language="en"),
)
original = Term(type=TRIPLE, triple=inner)
wire = term_tx.from_pulsar(original)
restored = term_tx.to_pulsar(wire)
wire = term_tx.encode(original)
restored = term_tx.decode(wire)
assert restored.type == TRIPLE
assert restored.triple.s == original.triple.s
assert restored.triple.o == original.triple.o
def test_quoted_triple_none_triple(self, term_tx):
term = Term(type=TRIPLE, triple=None)
wire = term_tx.from_pulsar(term)
wire = term_tx.encode(term)
assert wire == {"t": "t"}
# And back
restored = term_tx.to_pulsar(wire)
restored = term_tx.decode(wire)
assert restored.type == TRIPLE
assert restored.triple is None
@ -210,7 +210,7 @@ class TestTermTranslatorQuotedTriple:
"o": {"t": "l", "v": "A feeling of expectation"},
},
}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.triple.o.type == LITERAL
assert term.triple.o.value == "A feeling of expectation"
@ -223,22 +223,22 @@ class TestTermTranslatorEdgeCases:
def test_unknown_type(self, term_tx):
data = {"t": "z"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == "z"
def test_empty_type(self, term_tx):
data = {}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.type == ""
def test_missing_iri_field(self, term_tx):
data = {"t": "i"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.iri == ""
def test_missing_literal_fields(self, term_tx):
data = {"t": "l"}
term = term_tx.to_pulsar(data)
term = term_tx.decode(data)
assert term.value == ""
assert term.datatype == ""
assert term.language == ""
@ -250,24 +250,24 @@ class TestTermTranslatorEdgeCases:
class TestTripleTranslator:
def test_triple_to_pulsar(self, triple_tx):
def test_triple_decode(self, triple_tx):
data = {
"s": {"t": "i", "i": "http://example.org/s"},
"p": {"t": "i", "i": "http://example.org/p"},
"o": {"t": "l", "v": "object"},
}
triple = triple_tx.to_pulsar(data)
triple = triple_tx.decode(data)
assert triple.s.iri == "http://example.org/s"
assert triple.o.value == "object"
assert triple.g is None
def test_triple_from_pulsar(self, triple_tx):
def test_triple_encode(self, triple_tx):
triple = Triple(
s=Term(type=IRI, iri="http://example.org/A"),
p=Term(type=IRI, iri="http://example.org/B"),
o=Term(type=LITERAL, value="C"),
)
wire = triple_tx.from_pulsar(triple)
wire = triple_tx.encode(triple)
assert wire["s"]["t"] == "i"
assert wire["o"]["v"] == "C"
assert "g" not in wire
@ -279,17 +279,17 @@ class TestTripleTranslator:
"o": {"t": "l", "v": "val"},
"g": "urn:graph:source",
}
quad = triple_tx.to_pulsar(data)
quad = triple_tx.decode(data)
assert quad.g == "urn:graph:source"
def test_quad_from_pulsar_includes_graph(self, triple_tx):
def test_quad_encode_includes_graph(self, triple_tx):
quad = Triple(
s=Term(type=IRI, iri="http://example.org/s"),
p=Term(type=IRI, iri="http://example.org/p"),
o=Term(type=LITERAL, value="v"),
g="urn:graph:retrieval",
)
wire = triple_tx.from_pulsar(quad)
wire = triple_tx.encode(quad)
assert wire["g"] == "urn:graph:retrieval"
def test_quad_round_trip(self, triple_tx):
@ -299,8 +299,8 @@ class TestTripleTranslator:
o=Term(type=LITERAL, value="v"),
g="urn:graph:source",
)
wire = triple_tx.from_pulsar(original)
restored = triple_tx.to_pulsar(wire)
wire = triple_tx.encode(original)
restored = triple_tx.decode(wire)
assert restored == original
def test_none_graph_omitted_from_wire(self, triple_tx):
@ -310,12 +310,12 @@ class TestTripleTranslator:
o=Term(type=LITERAL, value="v"),
g=None,
)
wire = triple_tx.from_pulsar(triple)
wire = triple_tx.encode(triple)
assert "g" not in wire
def test_missing_terms_handled(self, triple_tx):
data = {}
triple = triple_tx.to_pulsar(data)
triple = triple_tx.decode(data)
assert triple.s is None
assert triple.p is None
assert triple.o is None
@ -342,16 +342,16 @@ class TestSubgraphTranslator:
g="urn:graph:source",
),
]
wire_list = tx.from_pulsar(triples)
wire_list = tx.encode(triples)
assert len(wire_list) == 2
assert wire_list[1]["g"] == "urn:graph:source"
restored = tx.to_pulsar(wire_list)
restored = tx.decode(wire_list)
assert len(restored) == 2
assert restored[0] == triples[0]
assert restored[1] == triples[1]
def test_empty_subgraph(self):
tx = SubgraphTranslator()
assert tx.to_pulsar([]) == []
assert tx.from_pulsar([]) == []
assert tx.decode([]) == []
assert tx.encode([]) == []
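The compact wire format these translator tests describe ("t" for the term kind, with "i", "d", "v", "dt", "ln" payload keys) can be sketched as follows. Dicts stand in for the real trustgraph `Term` class, and the real translator's internals may differ:

```python
# Encode/decode for the compact term wire format; optional literal
# fields ("dt", "ln") are omitted when empty, matching the tests.
def encode_term(term):
    kind = term["type"]
    if kind == "i":
        return {"t": "i", "i": term["iri"]}
    if kind == "b":
        return {"t": "b", "d": term["id"]}
    wire = {"t": "l", "v": term["value"]}
    if term.get("datatype"):
        wire["dt"] = term["datatype"]     # omitted for plain literals
    if term.get("language"):
        wire["ln"] = term["language"]     # omitted when no language tag
    return wire

def decode_term(wire):
    t = wire.get("t", "")
    if t == "i":
        return {"type": "i", "iri": wire.get("i", "")}
    if t == "b":
        return {"type": "b", "id": wire.get("d", "")}
    return {"type": "l", "value": wire.get("v", ""),
            "datatype": wire.get("dt", ""), "language": wire.get("ln", "")}
```

Decoding tolerates missing fields by defaulting to empty strings, which is what the edge-case tests (`{"t": "i"}`, `{"t": "l"}`, `{}`) rely on.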


@ -35,7 +35,7 @@ class TestDocumentMetadataTranslator:
"parent-id": "doc-100",
"document-type": "page",
}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert obj.id == "doc-123"
assert obj.time == 1710000000
assert obj.kind == "application/pdf"
@ -45,14 +45,14 @@ class TestDocumentMetadataTranslator:
assert obj.parent_id == "doc-100"
assert obj.document_type == "page"
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert wire["id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["parent-id"] == "doc-100"
assert wire["document-type"] == "page"
def test_defaults_for_missing_fields(self):
obj = self.tx.to_pulsar({})
obj = self.tx.decode({})
assert obj.parent_id == ""
assert obj.document_type == "source"
@ -63,25 +63,25 @@ class TestDocumentMetadataTranslator:
"o": {"t": "i", "i": "http://example.org/o"},
}]
data = {"metadata": triple_wire}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert len(obj.metadata) == 1
assert obj.metadata[0].s.iri == "http://example.org/s"
def test_none_metadata_handled(self):
data = {"metadata": None}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert obj.metadata == []
def test_empty_tags_preserved(self):
data = {"tags": []}
obj = self.tx.to_pulsar(data)
wire = self.tx.from_pulsar(obj)
obj = self.tx.decode(data)
wire = self.tx.encode(obj)
assert wire["tags"] == []
def test_falsy_fields_omitted_from_wire(self):
"""Empty string fields should be omitted from wire format."""
obj = DocumentMetadata(id="", time=0, user="")
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert "id" not in wire
assert "user" not in wire
@ -105,7 +105,7 @@ class TestProcessingMetadataTranslator:
"collection": "my-collection",
"tags": ["tag1"],
}
obj = self.tx.to_pulsar(data)
obj = self.tx.decode(data)
assert obj.id == "proc-1"
assert obj.document_id == "doc-123"
assert obj.flow == "default"
@ -113,32 +113,32 @@ class TestProcessingMetadataTranslator:
assert obj.collection == "my-collection"
assert obj.tags == ["tag1"]
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert wire["id"] == "proc-1"
assert wire["document-id"] == "doc-123"
assert wire["user"] == "alice"
assert wire["collection"] == "my-collection"
def test_missing_fields_use_defaults(self):
obj = self.tx.to_pulsar({})
obj = self.tx.decode({})
assert obj.id is None
assert obj.user is None
assert obj.collection is None
def test_tags_none_omitted(self):
obj = ProcessingMetadata(tags=None)
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert "tags" not in wire
def test_tags_empty_list_preserved(self):
obj = ProcessingMetadata(tags=[])
wire = self.tx.from_pulsar(obj)
wire = self.tx.encode(obj)
assert wire["tags"] == []
def test_user_and_collection_preserved(self):
"""Core pipeline routing fields must survive round-trip."""
data = {"user": "bob", "collection": "research"}
obj = self.tx.to_pulsar(data)
wire = self.tx.from_pulsar(obj)
obj = self.tx.decode(data)
wire = self.tx.encode(obj)
assert wire["user"] == "bob"
assert wire["collection"] == "research"

View file

@ -28,7 +28,7 @@ class TestRequestTranslation:
}
# Translate to Pulsar
pulsar_msg = translator.to_pulsar(api_data)
pulsar_msg = translator.decode(api_data)
assert pulsar_msg.operation == "schema-selection"
assert pulsar_msg.sample == "test data sample"
@ -46,7 +46,7 @@ class TestRequestTranslation:
"options": {"delimiter": ","}
}
pulsar_msg = translator.to_pulsar(api_data)
pulsar_msg = translator.decode(api_data)
assert pulsar_msg.operation == "generate-descriptor"
assert pulsar_msg.sample == "csv data"
@ -70,7 +70,7 @@ class TestResponseTranslation:
)
# Translate to API format
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "schema-selection"
assert api_data["schema-matches"] == ["products", "inventory", "catalog"]
@ -86,7 +86,7 @@ class TestResponseTranslation:
error=None
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "schema-selection"
assert api_data["schema-matches"] == []
@ -103,7 +103,7 @@ class TestResponseTranslation:
error=None
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "detect-type"
assert api_data["detected-type"] == "xml"
@ -123,7 +123,7 @@ class TestResponseTranslation:
)
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "schema-selection"
# Error objects are typically handled separately by the gateway
@ -146,7 +146,7 @@ class TestResponseTranslation:
error=None
)
api_data = translator.from_pulsar(pulsar_response)
api_data = translator.encode(pulsar_response)
assert api_data["operation"] == "diagnose"
assert api_data["detected-type"] == "csv"
@ -165,7 +165,7 @@ class TestResponseTranslation:
error=None
)
api_data, is_final = translator.from_response_with_completion(pulsar_response)
api_data, is_final = translator.encode_with_completion(pulsar_response)
assert is_final is True # Structured-diag responses are always final
assert api_data["operation"] == "schema-selection"

View file
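The tests above exercise the translator rename (`to_pulsar`/`from_pulsar` becoming `decode`/`encode`). A minimal standalone sketch of the convention the tests assert, with a hypothetical subset of fields: falsy strings are dropped on encode, an empty `tags` list is preserved, and `None` tags are omitted.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ProcessingMetadata:
    # Hypothetical subset of the real schema object
    id: Optional[str] = None
    user: Optional[str] = None
    tags: Optional[List[str]] = None

class ProcessingMetadataTranslator:
    def decode(self, data: dict) -> ProcessingMetadata:
        # Wire dict -> schema object; missing fields fall back to defaults
        return ProcessingMetadata(
            id=data.get("id"),
            user=data.get("user"),
            tags=data.get("tags"),
        )

    def encode(self, obj: ProcessingMetadata) -> dict:
        # Schema object -> wire dict; falsy strings are omitted, but an
        # empty tags list is kept (None vs [] is significant)
        wire = {}
        if obj.id:
            wire["id"] = obj.id
        if obj.user:
            wire["user"] = obj.user
        if obj.tags is not None:
            wire["tags"] = obj.tags
        return wire

tx = ProcessingMetadataTranslator()
assert tx.encode(tx.decode({"user": "bob"})) == {"user": "bob"}
assert tx.encode(ProcessingMetadata(tags=[]))["tags"] == []
assert "tags" not in tx.encode(ProcessingMetadata(tags=None))
assert "id" not in tx.encode(ProcessingMetadata(id=""))
```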

@ -14,6 +14,7 @@ dependencies = [
"prometheus-client",
"requests",
"python-logging-loki",
"pika",
]
classifiers = [
"Programming Language :: Python :: 3",

View file
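The new `pika` dependency backs the RabbitMQ consumer work; `pika.BlockingConnection` is not thread-safe, so the consumer changes pin all connection operations onto a dedicated per-consumer thread. A sketch of that pinning pattern in isolation (hypothetical class name, no real broker involved):

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

class PinnedConnection:
    """Run every operation on one dedicated worker thread, the pattern
    needed for thread-unsafe clients such as pika.BlockingConnection."""

    def __init__(self):
        # max_workers=1 guarantees a single, stable worker thread
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def call(self, fn, *args):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, fn, *args)

async def main():
    conn = PinnedConnection()
    # Every call observes the same worker thread ident
    t1 = await conn.call(threading.get_ident)
    t2 = await conn.call(threading.get_ident)
    assert t1 == t2
    assert t1 != threading.get_ident()  # not the event-loop thread
    return True

asyncio.run(main())
```

In the real consumer, create, receive, acknowledge, and negative-acknowledge would all go through `call`, so pika only ever sees one thread.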

@ -81,7 +81,12 @@ from .explainability import (
Synthesis,
Reflection,
Analysis,
Observation,
Conclusion,
Decomposition,
Finding,
Plan,
StepResult,
EdgeSelection,
wire_triples_to_tuples,
extract_term_value,
@ -160,6 +165,7 @@ __all__ = [
"Focus",
"Synthesis",
"Analysis",
"Observation",
"Conclusion",
"EdgeSelection",
"wire_triples_to_tuples",

View file

@ -40,15 +40,25 @@ TG_ANSWER_TYPE = TG + "Answer"
TG_REFLECTION_TYPE = TG + "Reflection"
TG_THOUGHT_TYPE = TG + "Thought"
TG_OBSERVATION_TYPE = TG + "Observation"
TG_TOOL_USE = TG + "ToolUse"
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
TG_AGENT_QUESTION = TG + "AgentQuestion"
# Orchestrator entity types
TG_DECOMPOSITION = TG + "Decomposition"
TG_FINDING = TG + "Finding"
TG_PLAN_TYPE = TG + "Plan"
TG_STEP_RESULT = TG + "StepResult"
# Orchestrator predicates
TG_SUBAGENT_GOAL = TG + "subagentGoal"
TG_PLAN_STEP = TG + "planStep"
# PROV-O predicates
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
PROV_WAS_GENERATED_BY = PROV + "wasGeneratedBy"
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
@ -82,8 +92,18 @@ class ExplainEntity:
return Exploration.from_triples(uri, triples)
elif TG_FOCUS in types:
return Focus.from_triples(uri, triples)
elif TG_DECOMPOSITION in types:
return Decomposition.from_triples(uri, triples)
elif TG_FINDING in types:
return Finding.from_triples(uri, triples)
elif TG_PLAN_TYPE in types:
return Plan.from_triples(uri, triples)
elif TG_STEP_RESULT in types:
return StepResult.from_triples(uri, triples)
elif TG_SYNTHESIS in types:
return Synthesis.from_triples(uri, triples)
elif TG_OBSERVATION_TYPE in types and TG_REFLECTION_TYPE not in types:
return Observation.from_triples(uri, triples)
elif TG_REFLECTION_TYPE in types:
return Reflection.from_triples(uri, triples)
elif TG_ANALYSIS in types:
@ -261,18 +281,16 @@ class Reflection(ExplainEntity):
@dataclass
class Analysis(ExplainEntity):
"""Analysis entity - one think/act/observe cycle (Agent only)."""
"""Analysis+ToolUse entity - decision + tool call (Agent only)."""
action: str = ""
arguments: str = "" # JSON string
thought: str = ""
observation: str = ""
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis":
action = ""
arguments = ""
thought = ""
observation = ""
for s, p, o in triples:
if p == TG_ACTION:
@ -281,8 +299,6 @@ class Analysis(ExplainEntity):
arguments = o
elif p == TG_THOUGHT:
thought = o
elif p == TG_OBSERVATION:
observation = o
return cls(
uri=uri,
@ -290,7 +306,26 @@ class Analysis(ExplainEntity):
action=action,
arguments=arguments,
thought=thought,
observation=observation
)
@dataclass
class Observation(ExplainEntity):
"""Observation entity - standalone tool result (Agent only)."""
document: str = ""
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Observation":
document = ""
for s, p, o in triples:
if p == TG_DOCUMENT:
document = o
return cls(
uri=uri,
entity_type="observation",
document=document,
)
@ -314,6 +349,70 @@ class Conclusion(ExplainEntity):
)
@dataclass
class Decomposition(ExplainEntity):
"""Decomposition entity - supervisor broke question into sub-goals."""
goals: List[str] = field(default_factory=list)
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Decomposition":
goals = []
for s, p, o in triples:
if p == TG_SUBAGENT_GOAL:
goals.append(o)
return cls(uri=uri, entity_type="decomposition", goals=goals)
@dataclass
class Finding(ExplainEntity):
"""Finding entity - a subagent's result."""
goal: str = ""
document: str = ""
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Finding":
goal = ""
document = ""
for s, p, o in triples:
if p == TG_SUBAGENT_GOAL:
goal = o
elif p == TG_DOCUMENT:
document = o
return cls(uri=uri, entity_type="finding", goal=goal, document=document)
@dataclass
class Plan(ExplainEntity):
"""Plan entity - a structured plan of steps."""
steps: List[str] = field(default_factory=list)
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Plan":
steps = []
for s, p, o in triples:
if p == TG_PLAN_STEP:
steps.append(o)
return cls(uri=uri, entity_type="plan", steps=steps)
@dataclass
class StepResult(ExplainEntity):
"""StepResult entity - a plan step's result."""
step: str = ""
document: str = ""
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "StepResult":
step = ""
document = ""
for s, p, o in triples:
if p == TG_PLAN_STEP:
step = o
elif p == TG_DOCUMENT:
document = o
return cls(uri=uri, entity_type="step-result", step=step, document=document)
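The orchestrator entities above all follow the same shape: scan the triples for known predicates and collect matching objects into a dataclass. A self-contained sketch of that parse for `Decomposition` (the namespace value here is hypothetical; the real `TG` constant is defined elsewhere in the module):

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple

TG = "http://trustgraph.ai/ns/provenance#"  # hypothetical namespace value
TG_SUBAGENT_GOAL = TG + "subagentGoal"

@dataclass
class Decomposition:
    uri: str
    goals: List[str] = field(default_factory=list)

    @classmethod
    def from_triples(cls, uri: str,
                     triples: List[Tuple[str, str, Any]]) -> "Decomposition":
        # Collect every subagentGoal object; other predicates are ignored
        goals = [o for s, p, o in triples if p == TG_SUBAGENT_GOAL]
        return cls(uri=uri, goals=goals)

triples = [
    ("urn:x:decomp:1", TG_SUBAGENT_GOAL, "find suppliers"),
    ("urn:x:decomp:1", TG_SUBAGENT_GOAL, "rank by risk"),
    ("urn:x:decomp:1", "http://www.w3.org/2000/01/rdf-schema#label", "step 1"),
]
d = Decomposition.from_triples("urn:x:decomp:1", triples)
assert d.goals == ["find suppliers", "rank by risk"]
```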
def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSelection:
"""Parse triples for an edge selection entity."""
uri = triples[0][0] if triples else ""
@ -675,9 +774,9 @@ class ExplainabilityClient:
return trace
trace["question"] = question
# Find grounding: ?grounding prov:wasGeneratedBy question_uri
# Find grounding: ?grounding prov:wasDerivedFrom question_uri
grounding_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY,
p=PROV_WAS_DERIVED_FROM,
o=question_uri,
g=graph,
user=user,
@ -812,9 +911,9 @@ class ExplainabilityClient:
return trace
trace["question"] = question
# Find grounding: ?grounding prov:wasGeneratedBy question_uri
# Find grounding: ?grounding prov:wasDerivedFrom question_uri
grounding_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY,
p=PROV_WAS_DERIVED_FROM,
o=question_uri,
g=graph,
user=user,
@ -895,7 +994,10 @@ class ExplainabilityClient:
"""
Fetch the complete Agent trace starting from a session URI.
Follows the provenance chain: Question -> Analysis(s) -> Conclusion
Follows the provenance chain for all patterns:
- ReAct: Question -> Analysis(s) -> Conclusion
- Supervisor: Question -> Decomposition -> Finding(s) -> Synthesis
- Plan-then-Execute: Question -> Plan -> StepResult(s) -> Synthesis
Args:
session_uri: The agent session/question URI
@ -906,15 +1008,14 @@ class ExplainabilityClient:
max_content: Maximum content length for conclusion
Returns:
Dict with question, iterations (Analysis list), conclusion entities
Dict with question, steps (mixed entity list), conclusion/synthesis
"""
if graph is None:
graph = "urn:graph:retrieval"
trace = {
"question": None,
"iterations": [],
"conclusion": None,
"steps": [],
}
# Fetch question/session
@ -923,65 +1024,89 @@ class ExplainabilityClient:
return trace
trace["question"] = question
# Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after
current_uri = session_uri
is_first = True
max_iterations = 50 # Safety limit
for _ in range(max_iterations):
# First hop uses wasGeneratedBy (entity←activity),
# subsequent hops use wasDerivedFrom (entity←entity)
if is_first:
derived_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
# Fall back to wasDerivedFrom for backwards compatibility
if not derived_triples:
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
is_first = False
else:
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
if not derived_triples:
break
derived_uri = extract_term_value(derived_triples[0].get("s", {}))
if not derived_uri:
break
entity = self.fetch_entity(derived_uri, graph, user, collection)
if isinstance(entity, Analysis):
trace["iterations"].append(entity)
current_uri = derived_uri
elif isinstance(entity, Conclusion):
trace["conclusion"] = entity
break
else:
# Unknown entity type, stop
break
# Follow the provenance chain from the question
self._follow_provenance_chain(
session_uri, trace, graph, user, collection,
max_depth=50,
)
return trace
def _follow_provenance_chain(
self, current_uri, trace, graph, user, collection,
max_depth=50,
):
"""Recursively follow the provenance chain, handling branches."""
if max_depth <= 0:
return
# Find entities derived from current_uri
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph, user=user, collection=collection,
limit=20
)
if not derived_triples:
return
derived_uris = [
extract_term_value(t.get("s", {}))
for t in derived_triples
]
for derived_uri in derived_uris:
if not derived_uri:
continue
entity = self.fetch_entity(derived_uri, graph, user, collection)
if entity is None:
continue
if isinstance(entity, (Analysis, Observation, Decomposition,
Finding, Plan, StepResult)):
trace["steps"].append(entity)
# Continue following from this entity
self._follow_provenance_chain(
derived_uri, trace, graph, user, collection,
max_depth=max_depth - 1,
)
elif isinstance(entity, Question):
# Sub-trace: a RAG session linked to this agent step.
# Fetch the full sub-trace and embed it.
if entity.question_type == "graph-rag":
sub_trace = self.fetch_graphrag_trace(
derived_uri, graph, user, collection,
)
elif entity.question_type == "document-rag":
sub_trace = self.fetch_docrag_trace(
derived_uri, graph, user, collection,
)
else:
sub_trace = None
if sub_trace:
trace["steps"].append({
"type": "sub-trace",
"question": entity,
"trace": sub_trace,
})
# Continue from the sub-trace's terminal entity
# (Observation may derive from Synthesis)
terminal = sub_trace.get("synthesis")
if terminal:
self._follow_provenance_chain(
terminal.uri, trace, graph, user, collection,
max_depth=max_depth - 1,
)
elif isinstance(entity, (Conclusion, Synthesis)):
trace["steps"].append(entity)
def list_sessions(
self,
graph: Optional[str] = None,
@ -1021,10 +1146,25 @@ class ExplainabilityClient:
if isinstance(entity, Question):
questions.append(entity)
# Sort by timestamp (newest first)
questions.sort(key=lambda q: q.timestamp or "", reverse=True)
# Filter out sub-traces: sessions that have a wasDerivedFrom link
# (they are child sessions linked to a parent agent iteration)
top_level = []
for q in questions:
parent_triples = self.flow.triples_query(
s=q.uri,
p=PROV_WAS_DERIVED_FROM,
g=graph,
user=user,
collection=collection,
limit=1
)
if not parent_triples:
top_level.append(q)
return questions
# Sort by timestamp (newest first)
top_level.sort(key=lambda q: q.timestamp or "", reverse=True)
return top_level
def detect_session_type(
self,
@ -1066,23 +1206,14 @@ class ExplainabilityClient:
limit=5
)
generated_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY,
o=session_uri,
g=graph,
user=user,
collection=collection,
limit=5
)
all_child_uris = [
extract_term_value(t.get("s", {}))
for t in (derived_triples + generated_triples)
for t in derived_triples
]
for child_uri in all_child_uris:
entity = self.fetch_entity(child_uri, graph, user, collection)
if isinstance(entity, Analysis):
if isinstance(entity, (Analysis, Decomposition, Plan)):
return "agent"
if isinstance(entity, Exploration):
return "graphrag"

View file
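`_follow_provenance_chain` above is a depth-bounded walk over `prov:wasDerivedFrom` edges: intermediate steps are recorded and expanded, while terminal entities (Conclusion, Synthesis) are recorded but not followed further. The core traversal can be sketched against an in-memory edge map (hypothetical URIs; the real code issues `triples_query` calls per hop):

```python
# Reverse edge map: uri -> entities derived from it via wasDerivedFrom
derived_from = {
    "urn:q": ["urn:decomp"],
    "urn:decomp": ["urn:finding1", "urn:finding2"],
    "urn:finding1": ["urn:synthesis"],
}

TERMINAL = {"urn:synthesis"}  # Conclusion/Synthesis stop the recursion

def follow(uri, steps, max_depth=50):
    if max_depth <= 0:
        return
    for child in derived_from.get(uri, []):
        steps.append(child)
        if child in TERMINAL:
            continue  # terminal entities are recorded but not expanded
        follow(child, steps, max_depth - 1)

steps = []
follow("urn:q", steps)
# Depth-first: decomposition, first finding and its synthesis, then
# the sibling finding on the other branch
assert steps == ["urn:decomp", "urn:finding1", "urn:synthesis", "urn:finding2"]
```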

@ -1122,6 +1122,45 @@ class FlowInstance:
return result
def sparql_query(
self, query, user="trustgraph", collection="default",
limit=10000
):
"""
Execute a SPARQL query against the knowledge graph.
Args:
query: SPARQL 1.1 query string
user: User/keyspace identifier (default: "trustgraph")
collection: Collection identifier (default: "default")
limit: Safety limit on results (default: 10000)
Returns:
dict with query results. Structure depends on query type:
- SELECT: {"query-type": "select", "variables": [...], "bindings": [...]}
- ASK: {"query-type": "ask", "ask-result": bool}
- CONSTRUCT/DESCRIBE: {"query-type": "construct", "triples": [...]}
Raises:
ProtocolException: If an error occurs
"""
input = {
"query": query,
"user": user,
"collection": collection,
"limit": limit,
}
response = self.request("service/sparql", input)
if "error" in response and response["error"]:
error_type = response["error"].get("type", "unknown")
error_message = response["error"].get("message", "Unknown error")
raise ProtocolException(f"{error_type}: {error_message}")
return response
def nlp_query(self, question, max_results=100):
"""
Convert a natural language question to a GraphQL query.

View file
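`sparql_query` above returns one of three result shapes keyed by `query-type`. A hedged sketch of consuming those shapes (synthetic response dicts, not a live call), mirroring the client's error check:

```python
def summarise(response: dict) -> str:
    # Raise on error, mirroring the client's own error handling
    if response.get("error"):
        err = response["error"]
        raise RuntimeError(f"{err.get('type')}: {err.get('message')}")
    qt = response["query-type"]
    if qt == "select":
        return f"select: {len(response['bindings'])} rows"
    if qt == "ask":
        return f"ask: {response['ask-result']}"
    # CONSTRUCT/DESCRIBE both report as "construct"
    return f"{qt}: {len(response['triples'])} triples"

assert summarise({"query-type": "ask", "ask-result": True}) == "ask: True"
assert summarise(
    {"query-type": "select", "variables": ["s"], "bindings": [{"s": "x"}]}
) == "select: 1 rows"
assert summarise(
    {"query-type": "construct", "triples": []}
) == "construct: 0 triples"
```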

@ -22,8 +22,9 @@ logger = logging.getLogger(__name__)
# Lower threshold provides progress feedback and resumability on slower connections
CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024
# Default chunk size (5MB - S3 multipart minimum)
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
# Default chunk size (3MB - stays under broker message size limits
# after base64 encoding ~4MB)
DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024
def to_value(x):

View file
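The 3MB default can be checked directly: base64 expands payloads by 4/3, so a 3MiB chunk encodes to exactly 4MiB before any envelope overhead.

```python
import base64

DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024

chunk = b"\x00" * DEFAULT_CHUNK_SIZE
encoded = base64.b64encode(chunk)

# Base64 length is ceil(n / 3) * 4; 3 MiB is divisible by 3,
# so there is no padding slack
assert len(encoded) == 4 * 1024 * 1024
```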

@ -366,59 +366,39 @@ class SocketClient:
# Handle GraphRAG/DocRAG message format with message_type
if message_type == "explain":
if include_provenance:
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
return self._build_provenance_event(resp)
return None
# Handle Agent message format with chunk_type="explain"
if chunk_type == "explain":
if include_provenance:
return ProvenanceEvent(
explain_id=resp.get("explain_id", ""),
explain_graph=resp.get("explain_graph", "")
)
return self._build_provenance_event(resp)
return None
if chunk_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""),
)
elif chunk_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""),
)
elif chunk_type == "answer" or chunk_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False)
end_of_dialog=resp.get("end_of_dialog", False),
message_id=resp.get("message_id", ""),
)
elif chunk_type == "action":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
# Non-streaming agent format: chunk_type is empty but has thought/observation/answer fields
elif resp.get("thought"):
return AgentThought(
content=resp.get("thought", ""),
end_of_message=resp.get("end_of_message", False)
)
elif resp.get("observation"):
return AgentObservation(
content=resp.get("observation", ""),
end_of_message=resp.get("end_of_message", False)
)
elif resp.get("answer"):
return AgentAnswer(
content=resp.get("answer", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False)
)
else:
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk(
@ -427,6 +407,42 @@ class SocketClient:
error=None
)
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
"""Build a ProvenanceEvent from a response dict, parsing inline triples
into an ExplainEntity if available."""
explain_id = resp.get("explain_id", "")
explain_graph = resp.get("explain_graph", "")
raw_triples = resp.get("explain_triples", [])
entity = None
if raw_triples:
try:
from .explainability import ExplainEntity
# Convert wire-format triple dicts to (s, p, o) tuples
parsed = []
for t in raw_triples:
s = t.get("s", {}).get("i", "") if t.get("s") else ""
p = t.get("p", {}).get("i", "") if t.get("p") else ""
o_term = t.get("o", {})
if o_term:
if o_term.get("t") == "i":
o = o_term.get("i", "")
else:
o = o_term.get("v", "")
else:
o = ""
parsed.append((s, p, o))
entity = ExplainEntity.from_triples(explain_id, parsed)
except Exception:
pass
return ProvenanceEvent(
explain_id=explain_id,
explain_graph=explain_graph,
entity=entity,
triples=raw_triples,
)
def close(self) -> None:
"""Close the persistent WebSocket connection."""
if self._loop and not self._loop.is_closed():
@ -826,6 +842,31 @@ class SocketFlowInstance:
else:
yield response
def sparql_query_stream(
self,
query: str,
user: str = "trustgraph",
collection: str = "default",
limit: int = 10000,
batch_size: int = 20,
**kwargs: Any
) -> Iterator[Dict[str, Any]]:
"""Execute a SPARQL query with streaming batches."""
request = {
"query": query,
"user": user,
"collection": collection,
"limit": limit,
"streaming": True,
"batch-size": batch_size,
}
request.update(kwargs)
for response in self.client._send_request_sync(
"sparql", self.flow_id, request, streaming_raw=True
):
yield response
def rows_query(
self,
query: str,

View file
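`_build_provenance_event` above decodes wire-format triple terms, where a term with `"t": "i"` carries an IRI in `"i"` and a literal carries its value in `"v"`. The term-decoding step in isolation (field names taken from the code above):

```python
def decode_term(term):
    # IRI terms carry "i"; literal terms carry "v"
    if not term:
        return ""
    if term.get("t") == "i":
        return term.get("i", "")
    return term.get("v", "")

def decode_triple(t):
    # Subjects and predicates are always IRIs; only objects vary
    s = t.get("s", {}).get("i", "") if t.get("s") else ""
    p = t.get("p", {}).get("i", "") if t.get("p") else ""
    return (s, p, decode_term(t.get("o", {})))

wire = {
    "s": {"t": "i", "i": "urn:x:1"},
    "p": {"t": "i", "i": "urn:p:label"},
    "o": {"v": "hello"},
}
assert decode_triple(wire) == ("urn:x:1", "urn:p:label", "hello")
```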

@ -150,8 +150,10 @@ class AgentThought(StreamingChunk):
content: Agent's thought text
end_of_message: True if this completes the current thought
chunk_type: Always "thought"
message_id: Provenance URI of the entity being built
"""
chunk_type: str = "thought"
message_id: str = ""
@dataclasses.dataclass
class AgentObservation(StreamingChunk):
@ -165,8 +167,10 @@ class AgentObservation(StreamingChunk):
content: Observation text describing tool results
end_of_message: True if this completes the current observation
chunk_type: Always "observation"
message_id: Provenance URI of the entity being built
"""
chunk_type: str = "observation"
message_id: str = ""
@dataclasses.dataclass
class AgentAnswer(StreamingChunk):
@ -184,6 +188,7 @@ class AgentAnswer(StreamingChunk):
"""
chunk_type: str = "final-answer"
end_of_dialog: bool = False
message_id: str = ""
@dataclasses.dataclass
class RAGChunk(StreamingChunk):
@ -208,25 +213,47 @@ class ProvenanceEvent:
"""
Provenance event for explainability.
Emitted during GraphRAG queries when explainable mode is enabled.
Emitted during retrieval queries when explainable mode is enabled.
Each event represents a provenance node created during query processing.
Attributes:
explain_id: URI of the provenance node (e.g., urn:trustgraph:question:abc123)
explain_graph: Named graph where provenance triples are stored (e.g., urn:graph:retrieval)
event_type: Type of provenance event (question, exploration, focus, synthesis)
event_type: Type of provenance event (question, exploration, focus, synthesis, etc.)
entity: Parsed ExplainEntity from inline triples (if available)
triples: Raw triples from the response (wire format dicts)
"""
explain_id: str
explain_graph: str = ""
event_type: str = "" # Derived from explain_id
entity: object = None # ExplainEntity (parsed from triples)
triples: list = dataclasses.field(default_factory=list) # Raw wire-format triple dicts
def __post_init__(self):
# Extract event type from explain_id
if "question" in self.explain_id:
self.event_type = "question"
elif "grounding" in self.explain_id:
self.event_type = "grounding"
elif "exploration" in self.explain_id:
self.event_type = "exploration"
elif "focus" in self.explain_id:
self.event_type = "focus"
elif "synthesis" in self.explain_id:
self.event_type = "synthesis"
elif "iteration" in self.explain_id:
self.event_type = "iteration"
elif "observation" in self.explain_id:
self.event_type = "observation"
elif "conclusion" in self.explain_id:
self.event_type = "conclusion"
elif "decomposition" in self.explain_id:
self.event_type = "decomposition"
elif "finding" in self.explain_id:
self.event_type = "finding"
elif "plan" in self.explain_id:
self.event_type = "plan"
elif "step-result" in self.explain_id:
self.event_type = "step-result"
elif "session" in self.explain_id:
self.event_type = "session"

View file
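The `__post_init__` chain above is first-match substring dispatch over the `explain_id`, so the order of checks matters. The same logic can be written table-driven; a sketch with the patterns in the order the code checks them:

```python
EVENT_PATTERNS = [
    "question", "grounding", "exploration", "focus", "synthesis",
    "iteration", "observation", "conclusion", "decomposition",
    "finding", "plan", "step-result", "session",
]

def event_type(explain_id: str) -> str:
    # First matching substring wins, mirroring the elif chain
    for pattern in EVENT_PATTERNS:
        if pattern in explain_id:
            return pattern
    return ""

assert event_type("urn:trustgraph:question:abc123") == "question"
assert event_type("urn:trustgraph:step-result:7") == "step-result"
assert event_type("urn:other") == ""
```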

@ -1,5 +1,5 @@
from . pubsub import PulsarClient, get_pubsub
from . pubsub import get_pubsub, add_pubsub_args
from . async_processor import AsyncProcessor
from . consumer import Consumer
from . producer import Producer
@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult, LlmChunk
from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec

View file

@ -57,8 +57,7 @@ class AgentClient(RequestResponse):
await self.request(
AgentRequest(
question = question,
plan = plan,
state = state,
state = state or "",
history = history,
),
recipient=recipient,

View file

@ -90,9 +90,6 @@ class AgentService(FlowProcessor):
type = "agent-error",
message = str(e),
),
thought = None,
observation = None,
answer = None,
end_of_message = True,
end_of_dialog = True,
),

View file

@ -1,24 +1,29 @@
# Base class for processors. Implements:
# - Pulsar client, subscribe and consume basic
# - Pub/sub client, subscribe and consume basic
# - the async startup logic
# - Config notify handling with subscribe-then-fetch pattern
# - Initialising metrics
import asyncio
import argparse
import _pulsar
import time
import uuid
import logging
import os
from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. schema import ConfigPush, ConfigRequest, ConfigResponse
from .. schema import config_push_queue, config_request_queue
from .. schema import config_response_queue
from .. log_level import LogLevel
from . pubsub import PulsarClient, get_pubsub
from . pubsub import get_pubsub, add_pubsub_args
from . producer import Producer
from . consumer import Consumer
from . metrics import ProcessorMetrics, ConsumerMetrics
from . subscriber import Subscriber
from . request_response_spec import RequestResponse
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from . metrics import SubscriberMetrics
from . logging import add_logging_args, setup_logging
default_config_queue = config_push_queue
@ -58,9 +63,13 @@ class AsyncProcessor:
"config_push_queue", default_config_queue
)
# This records registered configuration handlers
# This records registered configuration handlers; each entry is:
# { "handler": async_fn, "types": set_or_none }
self.config_handlers = []
# Track the current config version for dedup
self.config_version = 0
# Create a random ID for this subscription to the configuration
# service
config_subscriber_id = str(uuid.uuid4())
@ -69,33 +78,104 @@ class AsyncProcessor:
processor = self.id, flow = None, name = "config",
)
# Subscribe to config queue
# Subscribe to config notify queue
self.config_sub_task = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub_backend, # Changed from client to backend
backend = self.pubsub_backend,
subscriber = config_subscriber_id,
flow = None,
topic = self.config_push_queue,
schema = ConfigPush,
handler = self.on_config_change,
handler = self.on_config_notify,
metrics = config_consumer_metrics,
# This causes new subscriptions to view the entire history of
# configuration
start_of_messages = True
start_of_messages = False,
)
self.running = True
# This is called to start dynamic behaviour. An over-ride point for
# extra functionality
def _create_config_client(self):
"""Create a short-lived config request/response client."""
config_rr_id = str(uuid.uuid4())
config_req_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "config-request",
)
config_resp_metrics = SubscriberMetrics(
processor = self.id, flow = None, name = "config-response",
)
return RequestResponse(
backend = self.pubsub_backend,
subscription = f"{self.id}--config--{config_rr_id}",
consumer_name = self.id,
request_topic = config_request_queue,
request_schema = ConfigRequest,
request_metrics = config_req_metrics,
response_topic = config_response_queue,
response_schema = ConfigResponse,
response_metrics = config_resp_metrics,
)
async def fetch_config(self):
"""Fetch full config from config service using a short-lived
request/response client. Returns (config, version) or raises."""
client = self._create_config_client()
try:
await client.start()
resp = await client.request(
ConfigRequest(operation="config"),
timeout=10,
)
if resp.error:
raise RuntimeError(f"Config error: {resp.error.message}")
return resp.config, resp.version
finally:
await client.stop()
# This is called to start dynamic behaviour.
# Implements the subscribe-then-fetch pattern to avoid race conditions.
async def start(self):
# 1. Start the notify consumer (begins buffering incoming notifys)
await self.config_sub_task.start()
# 2. Fetch current config via request/response
await self.fetch_and_apply_config()
# 3. Any buffered notifys with version > fetched version will be
# processed by on_config_notify, which does the version check
async def fetch_and_apply_config(self):
"""Fetch full config from config service and apply to all handlers.
Retries until successful; the config service may not be ready yet."""
while self.running:
try:
config, version = await self.fetch_config()
logger.info(f"Fetched config version {version}")
self.config_version = version
# Apply to all handlers (startup = invoke all)
for entry in self.config_handlers:
await entry["handler"](config, version)
return
except Exception as e:
logger.warning(
f"Config fetch failed: {e}, retrying in 2s...",
exc_info=True
)
await asyncio.sleep(2)
# This is called to stop all threads. An over-ride point for extra
# functionality
def stop(self):
@ -111,20 +191,66 @@ class AsyncProcessor:
def pulsar_host(self): return self._pulsar_host
# Register a new event handler for configuration change
def register_config_handler(self, handler):
self.config_handlers.append(handler)
def register_config_handler(self, handler, types=None):
self.config_handlers.append({
"handler": handler,
"types": set(types) if types else None,
})
# Called when a new configuration message push occurs
async def on_config_change(self, message, consumer, flow):
# Called when a config notify message arrives
async def on_config_notify(self, message, consumer, flow):
# Get configuration data and version number
config = message.value().config
version = message.value().version
notify_version = message.value().version
notify_types = set(message.value().types)
# Invoke message handlers
logger.info(f"Config change event: version={version}")
for ch in self.config_handlers:
await ch(config, version)
# Skip if we already have this version or newer
if notify_version <= self.config_version:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"already at v{self.config_version}"
)
return
# Check if any handler cares about the affected types
if notify_types:
any_interested = False
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types is None or notify_types & handler_types:
any_interested = True
break
if not any_interested:
logger.debug(
f"Ignoring config notify v{notify_version}, "
f"no handlers for types {notify_types}"
)
self.config_version = notify_version
return
logger.info(
f"Config notify v{notify_version} types={list(notify_types)}, "
f"fetching config..."
)
# Fetch full config using short-lived client
try:
config, version = await self.fetch_config()
self.config_version = version
# Invoke handlers that care about the affected types
for entry in self.config_handlers:
handler_types = entry["types"]
if handler_types is None:
await entry["handler"](config, version)
elif not notify_types or notify_types & handler_types:
await entry["handler"](config, version)
except Exception as e:
logger.error(
f"Failed to fetch config on notify: {e}", exc_info=True
)
# This is the 'main' body of the handler. It is a point to override
# if needed. By default does nothing. Processors are implemented
@ -182,7 +308,7 @@ class AsyncProcessor:
prog=ident,
description=doc
)
parser.add_argument(
'--id',
default=ident,
@ -223,8 +349,8 @@ class AsyncProcessor:
logger.info("Keyboard interrupt.")
return
except _pulsar.Interrupted:
logger.info("Pulsar Interrupted.")
except KeyboardInterrupt:
logger.info("Interrupted.")
return
# Exceptions from a taskgroup come in as an exception group
@ -250,15 +376,7 @@ class AsyncProcessor:
@staticmethod
def add_args(parser):
# Pub/sub backend selection
parser.add_argument(
'--pubsub-backend',
default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
choices=['pulsar', 'mqtt'],
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
)
PulsarClient.add_args(parser)
add_pubsub_args(parser)
add_logging_args(parser)
parser.add_argument(
@ -280,4 +398,3 @@ class AsyncProcessor:
default=8000,
help=f'Pulsar host (default: 8000)',
)

View file
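The notify path above applies two filters before fetching config: drop notifys at or below the current version, and drop notifys whose types no registered handler cares about (while still advancing the version so the same notify is not re-checked). That filtering logic in isolation (hypothetical handler registry; the real processor advances the version only after the fetch succeeds):

```python
class NotifyFilter:
    def __init__(self):
        self.version = 0
        # entries: {"types": set or None}; None means "all types"
        self.handlers = []

    def register(self, types=None):
        self.handlers.append({"types": set(types) if types else None})

    def should_fetch(self, notify_version, notify_types):
        # Dedup: ignore versions we already have
        if notify_version <= self.version:
            return False
        # Type filter: at least one handler must overlap
        if notify_types and not any(
            h["types"] is None or notify_types & h["types"]
            for h in self.handlers
        ):
            # Advance the version so this notify isn't re-checked
            self.version = notify_version
            return False
        # Sketch simplification: advance eagerly; the real code
        # advances after the config fetch succeeds
        self.version = notify_version
        return True

f = NotifyFilter()
f.register(types=["flow", "active-flow"])
assert f.should_fetch(1, {"flow"}) is True
assert f.should_fetch(1, {"flow"}) is False      # same version: dedup
assert f.should_fetch(2, {"mcp"}) is False       # no interested handler
assert f.should_fetch(2, {"flow"}) is False      # version already advanced
assert f.should_fetch(3, {"active-flow"}) is True
```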

@ -7,23 +7,14 @@ fetching large document content.
import asyncio
import base64
import logging
import uuid
from .flow_processor import FlowProcessor
from .parameter_spec import ParameterSpec
from .consumer import Consumer
from .producer import Producer
from .metrics import ConsumerMetrics, ProducerMetrics
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ..schema import librarian_request_queue, librarian_response_queue
from .librarian_client import LibrarianClient
# Module logger
logger = logging.getLogger(__name__)
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class ChunkingService(FlowProcessor):
"""Base service for chunking processors with parameter specification support"""
@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor):
ParameterSpec(name="chunk-overlap")
)
# Librarian client for fetching document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
# Librarian client
self.librarian = LibrarianClient(
id=id,
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.debug("ChunkingService initialized with parameter specifications")
async def start(self):
await super(ChunkingService, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
async def fetch_document_content(self, document_id, user, timeout=120):
"""
Fetch document content from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.content
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching document {document_id}")
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="chunk", title=None, timeout=120):
"""
Save a child document (chunk) to the librarian.
Args:
doc_id: ID for the new child document
parent_id: ID of the parent document
user: User ID
content: Document content (bytes or str)
document_type: Type of document ("chunk", etc.)
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
request = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving chunk: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving chunk {doc_id}")
await self.librarian.start()
async def get_document_text(self, doc):
"""
@@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor):
"""
if doc.document_id and not doc.text:
logger.info(f"Fetching document {doc.document_id} from librarian...")
content = await self.fetch_document_content(
text = await self.librarian.fetch_document_text(
document_id=doc.document_id,
user=doc.metadata.user,
)
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
text = base64.b64decode(content).decode("utf-8")
logger.info(f"Fetched {len(text)} characters from librarian")
return text
else:
@@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor):
Extract chunk parameters from flow and return effective values
Args:
msg: The message containing the document to chunk
consumer: The consumer spec
flow: The flow context
default_chunk_size: Default chunk size from processor config
default_chunk_overlap: Default chunk overlap from processor config
msg: The message being processed
consumer: The consumer instance
flow: The flow object containing parameters
default_chunk_size: Default chunk size if not configured
default_chunk_overlap: Default chunk overlap if not configured
Returns:
tuple: (chunk_size, chunk_overlap) - effective values to use
tuple: (chunk_size, chunk_overlap) effective values
"""
# Extract parameters from flow (flow-configurable parameters)
chunk_size = flow("chunk-size")
chunk_overlap = flow("chunk-overlap")
# Use provided values or fall back to defaults
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
chunk_size = default_chunk_size
chunk_overlap = default_chunk_overlap
logger.debug(f"Using chunk-size: {effective_chunk_size}")
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
try:
cs = flow.parameters.get("chunk-size")
if cs is not None:
chunk_size = int(cs)
except Exception as e:
logger.warning(f"Could not parse chunk-size parameter: {e}")
return effective_chunk_size, effective_chunk_overlap
try:
co = flow.parameters.get("chunk-overlap")
if co is not None:
chunk_overlap = int(co)
except Exception as e:
logger.warning(f"Could not parse chunk-overlap parameter: {e}")
@staticmethod
def add_args(parser):
"""Add chunking service arguments to parser"""
FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
return chunk_size, chunk_overlap
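The fallback logic in this hunk can be exercised on its own. A minimal sketch, assuming a plain dict of flow parameters (the default values here are illustrative, not the project's):

```python
import logging

logger = logging.getLogger(__name__)

def effective_chunk_params(parameters, default_chunk_size=1000,
                           default_chunk_overlap=100):
    # Flow parameters override processor defaults; unparseable values
    # fall back to the defaults with a warning, mirroring the hunk above.
    chunk_size = default_chunk_size
    chunk_overlap = default_chunk_overlap
    try:
        cs = parameters.get("chunk-size")
        if cs is not None:
            chunk_size = int(cs)
    except Exception as e:
        logger.warning(f"Could not parse chunk-size parameter: {e}")
    try:
        co = parameters.get("chunk-overlap")
        if co is not None:
            chunk_overlap = int(co)
    except Exception as e:
        logger.warning(f"Could not parse chunk-overlap parameter: {e}")
    return chunk_size, chunk_overlap

print(effective_chunk_params({"chunk-size": "2000"}))  # (2000, 100)
```

Note that a bad value degrades gracefully rather than failing the message: `effective_chunk_params({"chunk-size": "oops"})` logs a warning and returns the defaults.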


@@ -12,6 +12,7 @@
import asyncio
import time
import logging
from concurrent.futures import ThreadPoolExecutor
from .. exceptions import TooManyRequests
@@ -32,6 +33,7 @@ class Consumer:
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
reconnect_time = 5,
concurrency = 1, # Number of concurrent requests to handle
consumer_type = 'shared',
):
self.taskgroup = taskgroup
@@ -42,6 +44,8 @@ class Consumer:
self.schema = schema
self.handler = handler
self.consumer_type = consumer_type
self.rate_limit_retry_time = rate_limit_retry_time
self.rate_limit_timeout = rate_limit_timeout
@@ -93,33 +97,11 @@ class Consumer:
if self.metrics:
self.metrics.state("stopped")
try:
logger.info(f"Subscribing to topic: {self.topic}")
# Determine initial position
if self.start_of_messages:
initial_pos = 'earliest'
else:
initial_pos = 'latest'
# Create consumer via backend
self.consumer = await asyncio.to_thread(
self.backend.create_consumer,
topic = self.topic,
subscription = self.subscriber,
schema = self.schema,
initial_position = initial_pos,
consumer_type = 'shared',
)
except Exception as e:
logger.error(f"Consumer subscription exception: {e}", exc_info=True)
await asyncio.sleep(self.reconnect_time)
continue
logger.info(f"Successfully subscribed to topic: {self.topic}")
# Determine initial position
if self.start_of_messages:
initial_pos = 'earliest'
else:
initial_pos = 'latest'
if self.metrics:
self.metrics.state("running")
@@ -128,14 +110,38 @@ class Consumer:
logger.info(f"Starting {self.concurrency} receiver threads")
async with asyncio.TaskGroup() as tg:
tasks = []
for i in range(0, self.concurrency):
tasks.append(
tg.create_task(self.consume_from_queue())
# Create one backend consumer per concurrent task.
# Each gets its own connection and dedicated thread —
# required for backends like RabbitMQ where connections
# are not thread-safe (pika BlockingConnection must be
# used from a single thread).
consumers = []
executors = []
for i in range(self.concurrency):
try:
logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
executor = ThreadPoolExecutor(max_workers=1)
loop = asyncio.get_event_loop()
c = await loop.run_in_executor(
executor,
lambda: self.backend.create_consumer(
topic = self.topic,
subscription = self.subscriber,
schema = self.schema,
initial_position = initial_pos,
consumer_type = self.consumer_type,
),
)
consumers.append(c)
executors.append(executor)
logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
except Exception as e:
logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
raise
async with asyncio.TaskGroup() as tg:
for c, ex in zip(consumers, executors):
tg.create_task(self.consume_from_queue(c, ex))
if self.metrics:
self.metrics.state("stopped")
@@ -143,24 +149,38 @@ class Consumer:
except Exception as e:
logger.error(f"Consumer loop exception: {e}", exc_info=True)
self.consumer.unsubscribe()
self.consumer.close()
self.consumer = None
for c in consumers:
try:
c.unsubscribe()
c.close()
except Exception:
pass
for ex in executors:
ex.shutdown(wait=False)
consumers = []
executors = []
await asyncio.sleep(self.reconnect_time)
continue
if self.consumer:
self.consumer.unsubscribe()
self.consumer.close()
finally:
for c in consumers:
try:
c.unsubscribe()
c.close()
except Exception:
pass
for ex in executors:
ex.shutdown(wait=False)
async def consume_from_queue(self):
async def consume_from_queue(self, consumer, executor=None):
loop = asyncio.get_event_loop()
while self.running:
try:
msg = await asyncio.to_thread(
self.consumer.receive,
timeout_millis=100
msg = await loop.run_in_executor(
executor,
lambda: consumer.receive(timeout_millis=100),
)
except Exception as e:
# Handle timeout from any backend
@@ -168,10 +188,11 @@ class Consumer:
continue
raise e
await self.handle_one_from_queue(msg)
await self.handle_one_from_queue(msg, consumer, executor)
async def handle_one_from_queue(self, msg):
async def handle_one_from_queue(self, msg, consumer, executor=None):
loop = asyncio.get_event_loop()
expiry = time.time() + self.rate_limit_timeout
# This loop is for retry on rate-limit / resource limits
@@ -182,8 +203,11 @@ class Consumer:
logger.warning("Gave up waiting for rate-limit retry")
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
# be retried. Ack on the consumer's dedicated thread
# (pika is not thread-safe).
await loop.run_in_executor(
executor, lambda: consumer.negative_acknowledge(msg)
)
if self.metrics:
self.metrics.process("error")
@@ -205,8 +229,11 @@ class Consumer:
logger.debug("Message processed successfully")
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
# Acknowledge on the consumer's dedicated thread
# (pika is not thread-safe)
await loop.run_in_executor(
executor, lambda: consumer.acknowledge(msg)
)
if self.metrics:
self.metrics.process("success")
@@ -232,8 +259,10 @@ class Consumer:
logger.error(f"Message processing exception: {e}", exc_info=True)
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
# be retried. Ack on the consumer's dedicated thread.
await loop.run_in_executor(
executor, lambda: consumer.negative_acknowledge(msg)
)
if self.metrics:
self.metrics.process("error")
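The dedicated-thread pattern this diff introduces is easy to verify in isolation. A sketch with a stand-in client (the `FakeBrokerConsumer` name is hypothetical) that records which OS thread each call runs on, confirming that a `ThreadPoolExecutor(max_workers=1)` funnels every operation through one thread:

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

class FakeBrokerConsumer:
    # Stand-in for a non-thread-safe client such as a consumer backed
    # by a pika BlockingConnection.
    def __init__(self):
        self.threads = set()

    def receive(self):
        self.threads.add(threading.get_ident())
        return "msg"

    def acknowledge(self, msg):
        self.threads.add(threading.get_ident())

async def main():
    consumer = FakeBrokerConsumer()
    executor = ThreadPoolExecutor(max_workers=1)  # one dedicated thread
    loop = asyncio.get_running_loop()
    # Both blocking calls are dispatched to the same worker thread,
    # as Consumer does for receive/acknowledge/negative_acknowledge.
    msg = await loop.run_in_executor(executor, consumer.receive)
    await loop.run_in_executor(executor, lambda: consumer.acknowledge(msg))
    executor.shutdown()
    return len(consumer.threads)

print(asyncio.run(main()))  # 1: both calls ran on the same thread
```

By contrast, `asyncio.to_thread` (the code being removed) draws from the event loop's shared pool, so consecutive calls may land on different threads.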


@@ -6,8 +6,6 @@
import json
import logging
from pulsar.schema import JsonSchema
from .. schema import Error
from .. schema import config_request_queue, config_response_queue
from .. schema import config_push_queue
@@ -28,7 +26,9 @@ class FlowProcessor(AsyncProcessor):
super(FlowProcessor, self).__init__(**params)
# Register configuration handler
self.register_config_handler(self.on_configure_flows)
self.register_config_handler(
self.on_configure_flows, types=["active-flow"]
)
# Initialise flow information state
self.flows = {}


@@ -5,6 +5,7 @@ from .. schema import GraphRagQuery, GraphRagResponse
class GraphRagClient(RequestResponse):
async def rag(self, query, user="trustgraph", collection="default",
chunk_callback=None, explain_callback=None,
parent_uri="",
timeout=600):
"""
Execute a graph RAG query with optional streaming callbacks.
@@ -50,6 +51,7 @@ class GraphRagClient(RequestResponse):
query = query,
user = user,
collection = collection,
parent_uri = parent_uri,
),
timeout=timeout,
recipient=recipient,


@@ -0,0 +1,246 @@
"""
Shared librarian client for services that need to communicate
with the librarian via pub/sub.
Provides request-response and streaming operations over the message
broker, with proper support for large documents via stream-document.
Usage:
self.librarian = LibrarianClient(
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
)
await self.librarian.start()
content = await self.librarian.fetch_document_content(doc_id, user)
"""
import asyncio
import base64
import logging
import uuid
from .consumer import Consumer
from .producer import Producer
from .metrics import ConsumerMetrics, ProducerMetrics
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ..schema import librarian_request_queue, librarian_response_queue
logger = logging.getLogger(__name__)
class LibrarianClient:
"""Client for librarian request-response over the message broker."""
def __init__(self, id, backend, taskgroup, **params):
librarian_request_q = params.get(
"librarian_request_queue", librarian_request_queue,
)
librarian_response_q = params.get(
"librarian_response_queue", librarian_response_queue,
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request",
)
self._producer = Producer(
backend=backend,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response",
)
self._consumer = Consumer(
taskgroup=taskgroup,
backend=backend,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self._on_response,
metrics=librarian_response_metrics,
consumer_type='exclusive',
)
# Single-response requests: request_id -> asyncio.Future
self._pending = {}
# Streaming requests: request_id -> asyncio.Queue
self._streams = {}
async def start(self):
"""Start the librarian producer and consumer."""
await self._producer.start()
await self._consumer.start()
async def _on_response(self, msg, consumer, flow):
"""Route librarian responses to the right waiter."""
response = msg.value()
request_id = msg.properties().get("id")
if not request_id:
return
if request_id in self._pending:
future = self._pending.pop(request_id)
future.set_result(response)
elif request_id in self._streams:
await self._streams[request_id].put(response)
async def request(self, request, timeout=120):
"""Send a request to the librarian and wait for a single response."""
request_id = str(uuid.uuid4())
future = asyncio.get_event_loop().create_future()
self._pending[request_id] = future
try:
await self._producer.send(
request, properties={"id": request_id},
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: "
f"{response.error.message}"
)
return response
except asyncio.TimeoutError:
self._pending.pop(request_id, None)
raise RuntimeError("Timeout waiting for librarian response")
async def stream(self, request, timeout=120):
"""Send a request and collect streamed response chunks."""
request_id = str(uuid.uuid4())
q = asyncio.Queue()
self._streams[request_id] = q
try:
await self._producer.send(
request, properties={"id": request_id},
)
chunks = []
while True:
response = await asyncio.wait_for(q.get(), timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: "
f"{response.error.message}"
)
chunks.append(response)
if response.is_final:
break
return chunks
except asyncio.TimeoutError:
self._streams.pop(request_id, None)
raise RuntimeError("Timeout waiting for librarian stream")
finally:
self._streams.pop(request_id, None)
async def fetch_document_content(self, document_id, user, timeout=120):
"""Fetch document content using streaming.
Returns base64-encoded content. Caller is responsible for decoding.
"""
req = LibrarianRequest(
operation="stream-document",
document_id=document_id,
user=user,
)
chunks = await self.stream(req, timeout=timeout)
# Decode each chunk's base64 to raw bytes, concatenate,
# re-encode for the caller.
raw = b""
for chunk in chunks:
if chunk.content:
if isinstance(chunk.content, bytes):
raw += base64.b64decode(chunk.content)
else:
raw += base64.b64decode(
chunk.content.encode("utf-8")
)
return base64.b64encode(raw)
async def fetch_document_text(self, document_id, user, timeout=120):
"""Fetch document content and decode as UTF-8 text."""
content = await self.fetch_document_content(
document_id, user, timeout=timeout,
)
return base64.b64decode(content).decode("utf-8")
async def fetch_document_metadata(self, document_id, user, timeout=120):
"""Fetch document metadata from the librarian."""
req = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
)
response = await self.request(req, timeout=timeout)
return response.document_metadata
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="chunk", title=None,
kind="text/plain", timeout=120):
"""Save a child document to the librarian."""
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind=kind,
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
req = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
await self.request(req, timeout=timeout)
return doc_id
async def save_document(self, doc_id, user, content, title=None,
document_type="answer", kind="text/plain",
timeout=120):
"""Save a document to the librarian."""
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind=kind,
title=title or doc_id,
document_type=document_type,
)
req = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
user=user,
)
await self.request(req, timeout=timeout)
return doc_id
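The single-response path above boils down to correlating each response with its request id through an `asyncio.Future`. A minimal sketch of that pattern, stripped of the producer/consumer plumbing (the `Correlator` class is illustrative, not part of the codebase):

```python
import asyncio
import uuid

class Correlator:
    # request_id -> Future, as in LibrarianClient._pending
    def __init__(self):
        self.pending = {}

    def new_request(self):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        return request_id, future

    def on_response(self, request_id, value):
        # Called from the consumer handler: resolve the waiter, if any.
        # pop() makes delivery one-shot and tolerates late duplicates.
        future = self.pending.pop(request_id, None)
        if future is not None:
            future.set_result(value)

async def main():
    c = Correlator()
    request_id, future = c.new_request()
    # In the real client this arrives on the response topic, carrying
    # the request id in its message properties.
    c.on_response(request_id, "document content")
    return await asyncio.wait_for(future, timeout=1)

print(asyncio.run(main()))  # document content
```

The streaming path swaps the `Future` for an `asyncio.Queue` so multiple responses can be delivered until `is_final` is seen.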


@@ -1,21 +1,16 @@
import json
import asyncio
import logging
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
logger = logging.getLogger(__name__)
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
logger.info(f"DEBUG prompt_client: prompt called, id={id}, streaming={streaming}, chunk_callback={chunk_callback is not None}")
if not streaming:
logger.info("DEBUG prompt_client: Non-streaming path")
# Non-streaming path
resp = await self.request(
PromptRequest(
id = id,
@@ -36,39 +31,30 @@ class PromptClient(RequestResponse):
return json.loads(resp.object)
else:
logger.info("DEBUG prompt_client: Streaming path")
# Streaming path - just forward chunks, don't accumulate
last_text = ""
last_object = None
async def forward_chunks(resp):
nonlocal last_text, last_object
logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
if resp.error:
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
raise RuntimeError(resp.error.message)
end_stream = getattr(resp, 'end_of_stream', False)
# Always call callback if there's text OR if it's the final message
if resp.text is not None:
last_text = resp.text
# Call chunk callback if provided with both chunk and end_of_stream flag
if chunk_callback:
logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream)
else:
chunk_callback(resp.text, end_stream)
elif resp.object:
logger.info(f"DEBUG prompt_client: Got object response")
last_object = resp.object
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
return end_stream
logger.info("DEBUG prompt_client: Creating PromptRequest")
req = PromptRequest(
id = id,
terms = {
@@ -77,19 +63,16 @@ class PromptClient(RequestResponse):
},
streaming = True
)
logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
await self.request(
req,
recipient=forward_chunks,
timeout=timeout
)
logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}")
if last_text:
logger.info("DEBUG prompt_client: Returning last_text")
return last_text
logger.info("DEBUG prompt_client: Returning parsed last_object")
return json.loads(last_object) if last_object else None
async def extract_definitions(self, text, timeout=600):


@@ -1,110 +1,121 @@
import os
import pulsar
import _pulsar
import uuid
from pulsar.schema import JsonSchema
import logging
from .. log_level import LogLevel
from .pulsar_backend import PulsarBackend
logger = logging.getLogger(__name__)
# Default connection settings from environment
DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None)
DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq')
DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672'))
DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest')
DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest')
DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/')
def get_pubsub(**config):
"""
Factory function to create a pub/sub backend based on configuration.
Args:
config: Configuration dictionary from command-line args
Must include 'pubsub_backend' key
config: Configuration dictionary from command-line args.
Key 'pubsub_backend' selects the backend (default: 'pulsar').
Returns:
Backend instance (PulsarBackend, MQTTBackend, etc.)
Example:
backend = get_pubsub(
pubsub_backend='pulsar',
pulsar_host='pulsar://localhost:6650'
)
Backend instance implementing the PubSubBackend protocol.
"""
backend_type = config.get('pubsub_backend', 'pulsar')
if backend_type == 'pulsar':
from .pulsar_backend import PulsarBackend
return PulsarBackend(
host=config.get('pulsar_host', PulsarClient.default_pulsar_host),
api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key),
host=config.get('pulsar_host', DEFAULT_PULSAR_HOST),
api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY),
listener=config.get('pulsar_listener'),
)
elif backend_type == 'mqtt':
# TODO: Implement MQTT backend
raise NotImplementedError("MQTT backend not yet implemented")
elif backend_type == 'rabbitmq':
from .rabbitmq_backend import RabbitMQBackend
return RabbitMQBackend(
host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST),
port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT),
username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME),
password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD),
vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST),
)
else:
raise ValueError(f"Unknown pub/sub backend: {backend_type}")
class PulsarClient:
STANDALONE_PULSAR_HOST = 'pulsar://localhost:6650'
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
def __init__(self, **params):
def add_pubsub_args(parser, standalone=False):
"""Add pub/sub CLI arguments to an argument parser.
self.client = None
Args:
parser: argparse.ArgumentParser
standalone: If True, default host is localhost (for CLI tools
that run outside containers)
"""
pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
pulsar_listener = 'localhost' if standalone else None
rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
pulsar_listener = params.get("pulsar_listener", None)
pulsar_api_key = params.get(
"pulsar_api_key",
self.default_pulsar_api_key
)
# Hard-code Pulsar logging to ERROR level to minimize noise
parser.add_argument(
'--pubsub-backend',
default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
)
self.pulsar_host = pulsar_host
self.pulsar_api_key = pulsar_api_key
# Pulsar options
parser.add_argument(
'-p', '--pulsar-host',
default=pulsar_host,
help=f'Pulsar host (default: {pulsar_host})',
)
if pulsar_api_key:
auth = pulsar.AuthenticationToken(pulsar_api_key)
self.client = pulsar.Client(
pulsar_host,
authentication=auth,
logger=pulsar.ConsoleLogger(_pulsar.LoggerLevel.Error)
)
else:
self.client = pulsar.Client(
pulsar_host,
listener_name=pulsar_listener,
logger=pulsar.ConsoleLogger(_pulsar.LoggerLevel.Error)
)
parser.add_argument(
'--pulsar-api-key',
default=DEFAULT_PULSAR_API_KEY,
help='Pulsar API key',
)
self.pulsar_listener = pulsar_listener
parser.add_argument(
'--pulsar-listener',
default=pulsar_listener,
help=f'Pulsar listener (default: {pulsar_listener or "none"})',
)
def close(self):
self.client.close()
# RabbitMQ options
parser.add_argument(
'--rabbitmq-host',
default=rabbitmq_host,
help=f'RabbitMQ host (default: {rabbitmq_host})',
)
def __del__(self):
parser.add_argument(
'--rabbitmq-port',
type=int,
default=DEFAULT_RABBITMQ_PORT,
help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})',
)
if hasattr(self, "client"):
if self.client:
self.client.close()
parser.add_argument(
'--rabbitmq-username',
default=DEFAULT_RABBITMQ_USERNAME,
help='RabbitMQ username',
)
@staticmethod
def add_args(parser):
parser.add_argument(
'--rabbitmq-password',
default=DEFAULT_RABBITMQ_PASSWORD,
help='RabbitMQ password',
)
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'--pulsar-api-key',
default=__class__.default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'--pulsar-listener',
help=f'Pulsar listener (default: none)',
)
parser.add_argument(
'--rabbitmq-vhost',
default=DEFAULT_RABBITMQ_VHOST,
help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})',
)


@@ -9,122 +9,14 @@ import pulsar
import _pulsar
import json
import logging
import base64
import types
from dataclasses import asdict, is_dataclass
from typing import Any, get_type_hints
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
from .serialization import dataclass_to_dict, dict_to_dataclass
logger = logging.getLogger(__name__)
def dataclass_to_dict(obj: Any) -> dict:
"""
Recursively convert a dataclass to a dictionary, handling None values and bytes.
None values are excluded from the dictionary (not serialized).
Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
Handles nested dataclasses, lists, and dictionaries recursively.
"""
if obj is None:
return None
# Handle bytes - decode to UTF-8 for JSON serialization
if isinstance(obj, bytes):
return obj.decode('utf-8')
# Handle dataclass - convert to dict then recursively process all values
if is_dataclass(obj):
result = {}
for key, value in asdict(obj).items():
result[key] = dataclass_to_dict(value) if value is not None else None
return result
# Handle list - recursively process all items
if isinstance(obj, list):
return [dataclass_to_dict(item) for item in obj]
# Handle dict - recursively process all values
if isinstance(obj, dict):
return {k: dataclass_to_dict(v) for k, v in obj.items()}
# Return primitive types as-is
return obj
def dict_to_dataclass(data: dict, cls: type) -> Any:
"""
Convert a dictionary back to a dataclass instance.
Handles nested dataclasses and missing fields.
Uses get_type_hints() to resolve forward references (string annotations).
"""
if data is None:
return None
if not is_dataclass(cls):
return data
# Get field types from the dataclass, resolving forward references
# get_type_hints() evaluates string annotations like "Triple | None"
try:
field_types = get_type_hints(cls)
except Exception:
# Fallback if get_type_hints fails (shouldn't happen normally)
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
kwargs = {}
for key, value in data.items():
if key in field_types:
field_type = field_types[key]
# Handle modern union types (X | Y)
if isinstance(field_type, types.UnionType):
# Check if it's Optional (X | None)
if type(None) in field_type.__args__:
# Get the non-None type
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Check if this is a generic type (list, dict, etc.)
elif hasattr(field_type, '__origin__'):
# Handle list[T]
if field_type.__origin__ == list:
item_type = field_type.__args__[0] if field_type.__args__ else None
if item_type and is_dataclass(item_type) and isinstance(value, list):
kwargs[key] = [
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
for item in value
]
else:
kwargs[key] = value
# Handle old-style Optional[T] (which is Union[T, None])
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
# Get the non-None type from Union
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Handle direct dataclass fields
elif is_dataclass(field_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, field_type)
# Handle bytes fields (UTF-8 encoded strings from JSON)
elif field_type == bytes and isinstance(value, str):
kwargs[key] = value.encode('utf-8')
else:
kwargs[key] = value
return cls(**kwargs)
class PulsarMessage:
"""Wrapper for Pulsar messages to match Message protocol."""
@@ -181,8 +73,11 @@ class PulsarBackendConsumer:
self._schema_cls = schema_cls
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message."""
pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
"""Receive a message. Raises TimeoutError if no message available."""
try:
pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
except _pulsar.Timeout:
raise TimeoutError("No message received within timeout")
return PulsarMessage(pulsar_msg, self._schema_cls)
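Normalizing each backend's timeout exception to the built-in `TimeoutError` is what lets the consumer loop stay backend-agnostic. A sketch with a fake backend exception (`FakeTimeout`/`FakeConsumer` are hypothetical stand-ins for `_pulsar.Timeout` and the real consumer):

```python
class FakeTimeout(Exception):
    # Stand-in for a backend-specific timeout such as _pulsar.Timeout
    pass

class FakeConsumer:
    def _receive(self, timeout_millis):
        raise FakeTimeout()

    def receive(self, timeout_millis=2000):
        # Map the backend exception to the built-in TimeoutError,
        # as PulsarBackendConsumer.receive does above.
        try:
            return self._receive(timeout_millis)
        except FakeTimeout:
            raise TimeoutError("No message received within timeout")

def poll_once(consumer):
    # Backend-agnostic loop body: a timeout is normal, not an error
    try:
        return consumer.receive(timeout_millis=100)
    except TimeoutError:
        return None

print(poll_once(FakeConsumer()))  # None
```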
def acknowledge(self, message: Message) -> None:
@@ -237,38 +132,44 @@ class PulsarBackend:
self.client = pulsar.Client(**client_args)
logger.info(f"Pulsar client connected to {host}")
def map_topic(self, generic_topic: str) -> str:
def map_topic(self, queue_id: str) -> str:
"""
Map generic topic format to Pulsar URI.
Map queue identifier to Pulsar URI.
Format: qos/tenant/namespace/queue
Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue
Format: class:topicspace:topic
Example: flow:tg:text-completion-request -> persistent://tg/flow/text-completion-request
Args:
generic_topic: Generic topic string or already-formatted Pulsar URI
queue_id: Queue identifier string or already-formatted Pulsar URI
Returns:
Pulsar topic URI
"""
# If already a Pulsar URI, return as-is
if '://' in generic_topic:
return generic_topic
if '://' in queue_id:
return queue_id
parts = generic_topic.split('/', 3)
if len(parts) != 4:
raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue")
parts = queue_id.split(':', 2)
if len(parts) != 3:
raise ValueError(
f"Invalid queue format: {queue_id}, "
f"expected class:topicspace:topic"
)
qos, tenant, namespace, queue = parts
cls, topicspace, topic = parts
# Map QoS to persistence
if qos == 'q0':
persistence = 'non-persistent'
elif qos in ['q1', 'q2']:
# Map class to Pulsar persistence and namespace
if cls in ('flow', 'state'):
persistence = 'persistent'
elif cls in ('request', 'response'):
persistence = 'non-persistent'
else:
raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2")
raise ValueError(
f"Invalid queue class: {cls}, "
f"expected flow, request, response, or state"
)
return f"{persistence}://{tenant}/{namespace}/{queue}"
return f"{persistence}://{topicspace}/{cls}/{topic}"
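The new mapping can be sketched as a standalone function for illustration (a minimal sketch mirroring the `class:topicspace:topic` scheme introduced above; not the actual backend method):

```python
def map_topic(queue_id: str) -> str:
    """Map a class:topicspace:topic queue id to a Pulsar topic URI."""
    # Already a Pulsar URI? Pass it through unchanged.
    if '://' in queue_id:
        return queue_id
    parts = queue_id.split(':', 2)
    if len(parts) != 3:
        raise ValueError(f"Invalid queue format: {queue_id}, "
                         f"expected class:topicspace:topic")
    cls, topicspace, topic = parts
    # flow/state survive broker restarts; request/response are ephemeral
    if cls in ('flow', 'state'):
        persistence = 'persistent'
    elif cls in ('request', 'response'):
        persistence = 'non-persistent'
    else:
        raise ValueError(f"Invalid queue class: {cls}, "
                         f"expected flow, request, response, or state")
    return f"{persistence}://{topicspace}/{cls}/{topic}"
```

For example, `map_topic("flow:tg:text-completion-request")` yields `persistent://tg/flow/text-completion-request`, while a `request` class queue maps to a `non-persistent://` URI.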
def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
"""


@@ -0,0 +1,391 @@
"""
RabbitMQ backend implementation for pub/sub abstraction.
Uses a single topic exchange per topicspace. The logical queue name
becomes the routing key. Consumer behavior is determined by the
subscription name:
- Same subscription + same topic = shared queue (competing consumers)
- Different subscriptions = separate queues (broadcast / fan-out)
This mirrors Pulsar's subscription model using idiomatic RabbitMQ.
Architecture:
Producer --> [tg exchange] --routing key--> [named queue] --> Consumer
--routing key--> [named queue] --> Consumer
--routing key--> [exclusive q] --> Subscriber
Uses basic_consume (push) instead of basic_get (polling) for
efficient message delivery.
"""
import json
import time
import logging
import queue
import threading
import pika
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
from .serialization import dataclass_to_dict, dict_to_dataclass
logger = logging.getLogger(__name__)
class RabbitMQMessage:
"""Wrapper for RabbitMQ messages to match Message protocol."""
def __init__(self, method, properties, body, schema_cls):
self._method = method
self._properties = properties
self._body = body
self._schema_cls = schema_cls
self._value = None
def value(self) -> Any:
"""Deserialize and return the message value as a dataclass."""
if self._value is None:
data_dict = json.loads(self._body.decode('utf-8'))
self._value = dict_to_dataclass(data_dict, self._schema_cls)
return self._value
def properties(self) -> dict:
"""Return message properties from AMQP headers."""
headers = self._properties.headers or {}
return dict(headers)
class RabbitMQBackendProducer:
"""Publishes messages to a topic exchange with a routing key.
Uses thread-local connections so each thread gets its own
connection/channel. This avoids wire corruption from concurrent
threads writing to the same socket (pika is not thread-safe).
"""
def __init__(self, connection_params, exchange_name, routing_key,
durable):
self._connection_params = connection_params
self._exchange_name = exchange_name
self._routing_key = routing_key
self._durable = durable
self._local = threading.local()
def _get_channel(self):
"""Get or create a thread-local connection and channel."""
conn = getattr(self._local, 'connection', None)
chan = getattr(self._local, 'channel', None)
if conn is None or not conn.is_open or chan is None or not chan.is_open:
# Close stale connection if any
if conn is not None:
try:
conn.close()
except Exception:
pass
conn = pika.BlockingConnection(self._connection_params)
chan = conn.channel()
chan.exchange_declare(
exchange=self._exchange_name,
exchange_type='topic',
durable=True,
)
self._local.connection = conn
self._local.channel = chan
return chan
def send(self, message: Any, properties: dict = {}) -> None:
data_dict = dataclass_to_dict(message)
json_data = json.dumps(data_dict)
amqp_properties = pika.BasicProperties(
delivery_mode=2 if self._durable else 1,
content_type='application/json',
headers=properties if properties else None,
)
for attempt in range(2):
try:
channel = self._get_channel()
channel.basic_publish(
exchange=self._exchange_name,
routing_key=self._routing_key,
body=json_data.encode('utf-8'),
properties=amqp_properties,
)
return
except Exception as e:
logger.warning(
f"RabbitMQ send failed (attempt {attempt + 1}): {e}"
)
# Force reconnect on next attempt
self._local.connection = None
self._local.channel = None
if attempt == 1:
raise
def flush(self) -> None:
pass
def close(self) -> None:
"""Close the thread-local connection if any."""
conn = getattr(self._local, 'connection', None)
if conn is not None:
try:
conn.close()
except Exception:
pass
self._local.connection = None
self._local.channel = None
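The thread-local caching the producer relies on can be demonstrated in isolation (a sketch under the assumption of a generic resource factory; `ThreadLocalResource` is illustrative, not part of the codebase):

```python
import threading

class ThreadLocalResource:
    """Lazily caches one resource per thread, like the producer's connection."""
    def __init__(self, factory):
        self._factory = factory
        self._local = threading.local()

    def get(self):
        res = getattr(self._local, 'resource', None)
        if res is None:
            res = self._factory()
            self._local.resource = res
        return res

pool = ThreadLocalResource(factory=object)
objs = []

def worker():
    first = pool.get()
    # Repeated calls on the same thread reuse the cached resource
    assert pool.get() is first
    objs.append(first)

threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# Each thread received its own distinct resource
assert objs[0] is not objs[1]
```

Because each thread only ever touches its own connection and channel, no two threads can interleave writes on one socket, which is the wire-corruption hazard the docstring describes.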
class RabbitMQBackendConsumer:
"""Consumes from a queue bound to a topic exchange.
Uses basic_consume (push model) with messages delivered to an
internal thread-safe queue. process_data_events() drives both
message delivery and heartbeat processing.
"""
def __init__(self, connection_params, exchange_name, routing_key,
queue_name, schema_cls, durable, exclusive=False,
auto_delete=False):
self._connection_params = connection_params
self._exchange_name = exchange_name
self._routing_key = routing_key
self._queue_name = queue_name
self._schema_cls = schema_cls
self._durable = durable
self._exclusive = exclusive
self._auto_delete = auto_delete
self._connection = None
self._channel = None
self._consumer_tag = None
self._incoming = queue.Queue()
def _connect(self):
self._connection = pika.BlockingConnection(self._connection_params)
self._channel = self._connection.channel()
# Declare the topic exchange
self._channel.exchange_declare(
exchange=self._exchange_name,
exchange_type='topic',
durable=True,
)
# Declare the queue — anonymous if exclusive
result = self._channel.queue_declare(
queue=self._queue_name,
durable=self._durable,
exclusive=self._exclusive,
auto_delete=self._auto_delete,
)
# Capture actual name (important for anonymous queues where name='')
self._queue_name = result.method.queue
self._channel.queue_bind(
queue=self._queue_name,
exchange=self._exchange_name,
routing_key=self._routing_key,
)
self._channel.basic_qos(prefetch_count=1)
# Register push-based consumer
self._consumer_tag = self._channel.basic_consume(
queue=self._queue_name,
on_message_callback=self._on_message,
auto_ack=False,
)
def _on_message(self, channel, method, properties, body):
"""Callback invoked by pika when a message arrives."""
self._incoming.put((method, properties, body))
def _is_alive(self):
return (
self._connection is not None
and self._connection.is_open
and self._channel is not None
and self._channel.is_open
)
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message. Raises TimeoutError if none available."""
if not self._is_alive():
self._connect()
timeout_seconds = timeout_millis / 1000.0
deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline:
# Check if a message was already delivered
try:
method, properties, body = self._incoming.get_nowait()
return RabbitMQMessage(
method, properties, body, self._schema_cls,
)
except queue.Empty:
pass
# Drive pika's I/O — delivers messages and processes heartbeats
remaining = deadline - time.monotonic()
if remaining > 0:
self._connection.process_data_events(
time_limit=min(0.1, remaining),
)
raise TimeoutError("No message received within timeout")
def acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method:
self._channel.basic_ack(
delivery_tag=message._method.delivery_tag,
)
def negative_acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method:
self._channel.basic_nack(
delivery_tag=message._method.delivery_tag,
requeue=True,
)
def unsubscribe(self) -> None:
if self._consumer_tag and self._channel and self._channel.is_open:
try:
self._channel.basic_cancel(self._consumer_tag)
except Exception:
pass
self._consumer_tag = None
def close(self) -> None:
self.unsubscribe()
try:
if self._channel and self._channel.is_open:
self._channel.close()
except Exception:
pass
try:
if self._connection and self._connection.is_open:
self._connection.close()
except Exception:
pass
self._channel = None
self._connection = None
class RabbitMQBackend:
"""RabbitMQ pub/sub backend using a topic exchange per topicspace."""
def __init__(self, host='localhost', port=5672, username='guest',
password='guest', vhost='/'):
self._connection_params = pika.ConnectionParameters(
host=host,
port=port,
virtual_host=vhost,
credentials=pika.PlainCredentials(username, password),
heartbeat=0,
)
logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}")
def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]:
"""
Parse queue identifier into exchange, routing key, and durability.
Format: class:topicspace:topic
Returns: (exchange_name, routing_key, class, durable)
"""
if ':' not in queue_id:
return 'tg', queue_id, 'flow', False
parts = queue_id.split(':', 2)
if len(parts) != 3:
raise ValueError(
f"Invalid queue format: {queue_id}, "
f"expected class:topicspace:topic"
)
cls, topicspace, topic = parts
if cls in ('flow', 'state'):
durable = True
elif cls in ('request', 'response'):
durable = False
else:
raise ValueError(
f"Invalid queue class: {cls}, "
f"expected flow, request, response, or state"
)
# Exchange per topicspace, routing key includes class
exchange_name = topicspace
routing_key = f"{cls}.{topic}"
return exchange_name, routing_key, cls, durable
# Keep map_queue_name for backward compatibility with tests
def map_queue_name(self, queue_id: str) -> tuple[str, bool]:
exchange, routing_key, cls, durable = self._parse_queue_id(queue_id)
return f"{exchange}.{routing_key}", durable
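The parsing logic can be sketched as a free function for illustration (mirroring `_parse_queue_id` above; a sketch, not the backend method itself):

```python
def parse_queue_id(queue_id: str) -> tuple[str, str, str, bool]:
    """Split class:topicspace:topic into (exchange, routing_key, cls, durable)."""
    parts = queue_id.split(':', 2)
    if len(parts) != 3:
        raise ValueError(f"Invalid queue format: {queue_id}, "
                         f"expected class:topicspace:topic")
    cls, topicspace, topic = parts
    # flow/state queues survive broker restarts; request/response do not
    if cls in ('flow', 'state'):
        durable = True
    elif cls in ('request', 'response'):
        durable = False
    else:
        raise ValueError(f"Invalid queue class: {cls}, "
                         f"expected flow, request, response, or state")
    return topicspace, f"{cls}.{topic}", cls, durable
```

So `flow:tg:my-queue` publishes to the `tg` exchange with routing key `flow.my-queue` and durable queues, while `response:tg:reply` uses the same exchange with routing key `response.reply` and ephemeral queues.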
def create_producer(self, topic: str, schema: type,
**options) -> BackendProducer:
exchange, routing_key, cls, durable = self._parse_queue_id(topic)
logger.debug(
f"Creating producer: exchange={exchange}, "
f"routing_key={routing_key}"
)
return RabbitMQBackendProducer(
self._connection_params, exchange, routing_key, durable,
)
def create_consumer(self, topic: str, subscription: str, schema: type,
initial_position: str = 'latest',
consumer_type: str = 'shared',
**options) -> BackendConsumer:
"""Create a consumer with a queue bound to the topic exchange.
consumer_type='shared': Named durable queue. Multiple consumers
with the same subscription compete (round-robin).
consumer_type='exclusive': Anonymous ephemeral queue. Each
consumer gets its own copy of every message (broadcast).
"""
exchange, routing_key, cls, durable = self._parse_queue_id(topic)
if consumer_type == 'exclusive' and cls == 'state':
# State broadcast: named durable queue per subscriber.
# Retains messages so late-starting processors see current state.
queue_name = f"{exchange}.{routing_key}.{subscription}"
queue_durable = True
exclusive = False
auto_delete = False
elif consumer_type == 'exclusive':
# Broadcast: anonymous queue, auto-deleted on disconnect
queue_name = ''
queue_durable = False
exclusive = True
auto_delete = True
else:
# Shared: named queue, competing consumers
queue_name = f"{exchange}.{routing_key}.{subscription}"
queue_durable = durable
exclusive = False
auto_delete = False
logger.debug(
f"Creating consumer: exchange={exchange}, "
f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, "
f"type={consumer_type}"
)
return RabbitMQBackendConsumer(
self._connection_params, exchange, routing_key,
queue_name, schema, queue_durable, exclusive, auto_delete,
)
def close(self) -> None:
pass


@@ -0,0 +1,115 @@
"""
JSON serialization helpers for dataclass dict conversion.
Used by pub/sub backends that use JSON as their wire format.
"""
import types
from dataclasses import asdict, is_dataclass
from typing import Any, get_type_hints
def dataclass_to_dict(obj: Any) -> dict:
"""
Recursively convert a dataclass to a dictionary, handling None values and bytes.
None values are excluded from the dictionary (not serialized).
Bytes values are decoded as UTF-8 strings for JSON serialization.
Handles nested dataclasses, lists, and dictionaries recursively.
"""
if obj is None:
return None
# Handle bytes - decode to UTF-8 for JSON serialization
if isinstance(obj, bytes):
return obj.decode('utf-8')
# Handle dataclass - convert to dict then recursively process all values
if is_dataclass(obj):
result = {}
for key, value in asdict(obj).items():
result[key] = dataclass_to_dict(value) if value is not None else None
return result
# Handle list - recursively process all items
if isinstance(obj, list):
return [dataclass_to_dict(item) for item in obj]
# Handle dict - recursively process all values
if isinstance(obj, dict):
return {k: dataclass_to_dict(v) for k, v in obj.items()}
# Return primitive types as-is
return obj
def dict_to_dataclass(data: dict, cls: type) -> Any:
"""
Convert a dictionary back to a dataclass instance.
Handles nested dataclasses and missing fields.
Uses get_type_hints() to resolve forward references (string annotations).
"""
if data is None:
return None
if not is_dataclass(cls):
return data
# Get field types from the dataclass, resolving forward references
# get_type_hints() evaluates string annotations like "Triple | None"
try:
field_types = get_type_hints(cls)
except Exception:
# Fallback if get_type_hints fails (shouldn't happen normally)
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
kwargs = {}
for key, value in data.items():
if key in field_types:
field_type = field_types[key]
# Handle modern union types (X | Y)
if isinstance(field_type, types.UnionType):
# Check if it's Optional (X | None)
if type(None) in field_type.__args__:
# Get the non-None type
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Check if this is a generic type (list, dict, etc.)
elif hasattr(field_type, '__origin__'):
# Handle list[T]
if field_type.__origin__ == list:
item_type = field_type.__args__[0] if field_type.__args__ else None
if item_type and is_dataclass(item_type) and isinstance(value, list):
kwargs[key] = [
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
for item in value
]
else:
kwargs[key] = value
# Handle old-style Optional[T] (which is Union[T, None])
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
# Get the non-None type from Union
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Handle direct dataclass fields
elif is_dataclass(field_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, field_type)
# Handle bytes fields (UTF-8 encoded strings from JSON)
elif field_type == bytes and isinstance(value, str):
kwargs[key] = value.encode('utf-8')
else:
kwargs[key] = value
return cls(**kwargs)
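A simplified round-trip sketch of what these two helpers implement — bytes fields travel as UTF-8 strings and nested dataclasses as dicts (the `to_dict`/`from_dict` functions here are deliberately cut-down stand-ins, and the `Triple`/`Explain` dataclasses are illustrative, not from the codebase):

```python
from dataclasses import dataclass, asdict, fields, is_dataclass

def to_dict(obj):
    """Simplified serializer: bytes become UTF-8 strings, dataclasses dicts."""
    if isinstance(obj, bytes):
        return obj.decode('utf-8')
    if is_dataclass(obj):
        obj = asdict(obj)
    if isinstance(obj, dict):
        return {k: to_dict(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_dict(v) for v in obj]
    return obj

def from_dict(data, cls):
    """Simplified deserializer: re-encode bytes, rebuild nested dataclasses."""
    kwargs = {}
    for f in fields(cls):
        if f.name not in data:
            continue
        v = data[f.name]
        if f.type is bytes and isinstance(v, str):
            kwargs[f.name] = v.encode('utf-8')
        elif is_dataclass(f.type) and isinstance(v, dict):
            kwargs[f.name] = from_dict(v, f.type)
        else:
            kwargs[f.name] = v
    return cls(**kwargs)

@dataclass
class Triple:
    s: str
    p: str
    o: str

@dataclass
class Explain:
    graph: str
    payload: bytes
    triple: Triple

orig = Explain(graph="g1", payload=b"hi", triple=Triple("a", "b", "c"))
wire = to_dict(orig)              # payload is now the string 'hi'
back = from_dict(wire, Explain)   # payload re-encoded, Triple rebuilt
assert back == orig
```

The real `dict_to_dataclass` additionally resolves forward references via `get_type_hints()` and unwraps `X | None` and `Optional[T]` unions, which this sketch omits.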


@@ -7,6 +7,7 @@ import asyncio
import time
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
# Module logger
logger = logging.getLogger(__name__)
@@ -38,6 +39,7 @@ class Subscriber:
self.pending_acks = {} # Track messages awaiting delivery
self.consumer = None
self.executor = None
def __del__(self):
@@ -45,15 +47,6 @@ class Subscriber:
async def start(self):
# Create consumer via backend
self.consumer = await asyncio.to_thread(
self.backend.create_consumer,
topic=self.topic,
subscription=self.subscription,
schema=self.schema,
consumer_type='shared',
)
self.task = asyncio.create_task(self.run())
async def stop(self):
@@ -80,6 +73,21 @@ class Subscriber:
try:
# Create consumer and dedicated thread if needed
# (first run or after failure)
if self.consumer is None:
self.executor = ThreadPoolExecutor(max_workers=1)
loop = asyncio.get_event_loop()
self.consumer = await loop.run_in_executor(
self.executor,
lambda: self.backend.create_consumer(
topic=self.topic,
subscription=self.subscription,
schema=self.schema,
consumer_type='exclusive',
),
)
if self.metrics:
self.metrics.state("running")
@@ -128,9 +136,12 @@ class Subscriber:
# Process messages only if not draining
if not self.draining:
try:
msg = await asyncio.to_thread(
self.consumer.receive,
timeout_millis=250
loop = asyncio.get_event_loop()
msg = await loop.run_in_executor(
self.executor,
lambda: self.consumer.receive(
timeout_millis=250
),
)
except Exception as e:
# Handle timeout from any backend
@@ -172,15 +183,18 @@ class Subscriber:
except Exception:
pass # Already closed or error
self.consumer = None
if self.executor:
self.executor.shutdown(wait=False)
self.executor = None
if self.metrics:
self.metrics.state("stopped")
if not self.running and not self.draining:
return
# If handler drops out, sleep a retry
# Sleep before retry
await asyncio.sleep(1)
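The dedicated single-thread executor pattern used above — pinning every blocking client call to one worker thread so a non-thread-safe connection is never touched from two threads — can be demonstrated in isolation (a sketch under stated assumptions; no pika involved here):

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

async def main():
    # max_workers=1 guarantees every dispatched call runs on the same
    # thread, which is what pika's BlockingConnection requires.
    executor = ThreadPoolExecutor(max_workers=1)
    loop = asyncio.get_running_loop()
    ids = []
    for _ in range(3):
        # Each call records the ident of the thread it actually ran on
        ids.append(await loop.run_in_executor(executor, threading.get_ident))
    executor.shutdown(wait=True)
    return ids

ids = asyncio.run(main())
assert len(set(ids)) == 1               # all calls on one dedicated thread
assert ids[0] != threading.get_ident()  # and not on the event-loop thread
```

Shutting the executor down alongside the consumer in the failure path, as the diff does, ensures a fresh thread is paired with the recreated consumer on the next loop iteration.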
async def subscribe(self, id):


@@ -1,5 +1,4 @@
import _pulsar
from .. schema import AgentRequest, AgentResponse
from .. schema import agent_request_queue
@@ -7,15 +6,11 @@ from .. schema import agent_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class AgentClient(BaseClient):
def __init__(
self, log_level=ERROR,
self,
subscriber=None,
input_queue=None,
output_queue=None,
@@ -27,7 +22,6 @@ class AgentClient(BaseClient):
if output_queue is None: output_queue = agent_response_queue
super(AgentClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
