mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
fix: forward explain_triples through RAG clients and agent tool callback - RAG clients and the KnowledgeQueryImpl tool callback were dropping explain_triples from explain events, losing provenance data (including focus edge selections) when graph-rag is invoked via the agent. Tests for provenance and explainability (56 new): - Client-level forwarding of explain_triples - Graph-RAG structural chain (question → grounding → exploration → focus → synthesis) - Graph-RAG integration with mocked subsidiary clients - Document-RAG integration (question → grounding → exploration → synthesis) - Agent-orchestrator all 3 patterns: react, plan-then-execute, supervisor
84 lines
2.5 KiB
Python
84 lines
2.5 KiB
Python
|
|
|
|
from .. schema import GraphRagQuery, GraphRagResponse
|
|
from .. schema import graph_rag_request_queue, graph_rag_response_queue
|
|
from . base import BaseClient
|
|
|
|
# Ugly
|
|
|
|
class GraphRagClient(BaseClient):
    """Client for the graph RAG service.

    Sends ``GraphRagQuery`` messages and receives ``GraphRagResponse``
    messages through the ``BaseClient`` machinery, using the graph RAG
    request/response queues unless others are supplied.
    """

    def __init__(
            self,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
    ):
        """Create the client.

        Args:
            subscriber: Passed through unchanged to ``BaseClient``.
            input_queue: Request queue; defaults to
                ``graph_rag_request_queue`` when omitted.
            output_queue: Response queue; defaults to
                ``graph_rag_response_queue`` when omitted.
            pulsar_host: Pulsar broker URL.
            pulsar_api_key: Optional Pulsar API key.
        """

        # Identity test (``is None``) rather than ``== None``: only an
        # omitted argument should fall back to the default queue.
        if input_queue is None:
            input_queue = graph_rag_request_queue

        if output_queue is None:
            output_queue = graph_rag_response_queue

        super().__init__(
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=GraphRagQuery,
            output_schema=GraphRagResponse,
        )

    def request(
            self, query, user="trustgraph", collection="default",
            chunk_callback=None,
            explain_callback=None,
            timeout=500
    ):
        """
        Request a graph RAG query with optional streaming callbacks.

        Args:
            query: The question to ask
            user: User identifier
            collection: Collection identifier
            chunk_callback: Optional callback(text, end_of_stream), invoked
                for each chunk message whose response text is non-empty
            explain_callback: Optional callback(explain_id, explain_graph,
                explain_triples), invoked for each explain message that
                carries an explain_id
            timeout: Request timeout in seconds

        Returns:
            Complete response text: the concatenation of every non-empty
            chunk response received before the session ended.
        """

        accumulated_response = []

        def inspect(x):
            # Called for each response message: return True when the
            # exchange is complete, False to continue receiving.

            # Explain notifications carry provenance; forward all of it,
            # including explain_triples, and keep receiving.
            # NOTE(review): this branch returns False unconditionally, so an
            # explain message never ends the session even if it sets
            # end_of_session — confirm the service never does that.
            if x.message_type == 'explain':
                if explain_callback and x.explain_id:
                    explain_callback(
                        x.explain_id, x.explain_graph, x.explain_triples
                    )
                return False

            # Text chunks are accumulated for the return value and streamed
            # to the caller as they arrive.
            # NOTE(review): chunks with empty response text are skipped, so
            # an empty final chunk's end_of_stream flag never reaches
            # chunk_callback — confirm the service never sends one.
            if x.message_type == 'chunk':
                if x.response:
                    accumulated_response.append(x.response)
                    if chunk_callback:
                        chunk_callback(x.response, x.end_of_stream)

            # Any non-explain message ends the exchange once the session
            # is over.
            return bool(x.end_of_session)

        self.call(
            user=user, collection=collection, query=query,
            inspect=inspect, timeout=timeout
        )

        return "".join(accumulated_response)
|
|
|