trustgraph/trustgraph-base/trustgraph/clients/graph_rag_client.py
cybermaggedon 4b5bfacab1
Forward missing explain_triples through RAG clients and agent tool callback (#768)
fix: forward explain_triples through RAG clients and agent tool callback
- RAG clients and the KnowledgeQueryImpl tool callback were
  dropping explain_triples from explain events, losing provenance
  data (including focus-edge selections) when graph-rag was invoked
  via the agent.

Tests for provenance and explainability (56 new):
- Client-level forwarding of explain_triples
- Graph-RAG structural chain
  (question → grounding → exploration → focus → synthesis)
- Graph-RAG integration with mocked subsidiary clients
- Document-RAG integration
  (question → grounding → exploration → synthesis)
- Agent orchestrator: all three patterns (react, plan-then-execute,
  supervisor)
2026-04-08 11:41:17 +01:00

84 lines
2.5 KiB
Python

from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue, graph_rag_response_queue
from . base import BaseClient
# Ugly
class GraphRagClient(BaseClient):
    """Synchronous client for graph RAG queries.

    Wraps ``BaseClient`` with the ``GraphRagQuery`` / ``GraphRagResponse``
    schemas and the graph RAG request/response queues, and adds a
    ``request`` helper that handles streamed text chunks and explain
    notifications.
    """

    def __init__(
        self,
        subscriber=None,
        input_queue=None,
        output_queue=None,
        pulsar_host="pulsar://pulsar:6650",
        pulsar_api_key=None,
    ):
        """Create the client.

        Args:
            subscriber: Passed through to ``BaseClient`` unchanged.
            input_queue: Request queue; ``graph_rag_request_queue`` is
                used when None.
            output_queue: Response queue; ``graph_rag_response_queue`` is
                used when None.
            pulsar_host: Pulsar broker URL.
            pulsar_api_key: Optional Pulsar API key.
        """
        if input_queue is None:
            input_queue = graph_rag_request_queue

        if output_queue is None:
            output_queue = graph_rag_response_queue

        super().__init__(
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=GraphRagQuery,
            output_schema=GraphRagResponse,
        )

    def request(
        self, query, user="trustgraph", collection="default",
        chunk_callback=None,
        explain_callback=None,
        timeout=500
    ):
        """
        Request a graph RAG query with optional streaming callbacks.

        Args:
            query: The question to ask
            user: User identifier
            collection: Collection identifier
            chunk_callback: Optional callback(text, end_of_stream). It is
                invoked for every non-empty text chunk. It is also invoked
                with text "" for an empty chunk that has end_of_stream set,
                so the end-of-stream signal is never lost.
            explain_callback: Optional callback(explain_id, explain_graph,
                explain_triples). It is invoked for explain notifications
                that carry an explain_id.
            timeout: Request timeout in seconds

        Returns:
            Complete response text (all non-empty chunks concatenated)
        """
        accumulated_response = []

        def inspect(x):
            # Return True to stop receiving and False to keep waiting.

            # Explain notifications are forwarded and never end the session.
            if x.message_type == 'explain':
                if explain_callback and x.explain_id:
                    explain_callback(
                        x.explain_id, x.explain_graph, x.explain_triples
                    )
                return False  # Continue receiving

            # Handle text chunks
            if x.message_type == 'chunk':
                if x.response:
                    accumulated_response.append(x.response)
                # Also forward an empty terminating chunk; otherwise the
                # callback would never observe end_of_stream=True when the
                # stream ends with a text-less chunk.
                if chunk_callback and (x.response or x.end_of_stream):
                    chunk_callback(x.response or "", x.end_of_stream)

            # Complete when session ends
            return bool(x.end_of_session)

        self.call(
            user=user, collection=collection, query=query,
            inspect=inspect, timeout=timeout
        )

        return "".join(accumulated_response)