mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
Forward missing explain_triples through RAG clients and agent tool callback (#768)
fix: forward explain_triples through RAG clients and agent tool callback - RAG clients and the KnowledgeQueryImpl tool callback were dropping explain_triples from explain events, losing provenance data (including focus edge selections) when graph-rag is invoked via the agent. Tests for provenance and explainability (56 new): - Client-level forwarding of explain_triples - Graph-RAG structural chain (question → grounding → exploration → focus → synthesis) - Graph-RAG integration with mocked subsidiary clients - Document-RAG integration (question → grounding → exploration → synthesis) - Agent-orchestrator all 3 patterns: react, plan-then-execute, supervisor
This commit is contained in:
parent
e899370d98
commit
4b5bfacab1
9 changed files with 2178 additions and 7 deletions
|
|
@ -0,0 +1,380 @@
|
|||
"""
|
||||
Integration test: run a full DocumentRag.query() with mocked subsidiary
|
||||
clients and verify the explain_callback receives the complete provenance
|
||||
chain in the correct order with correct structure.
|
||||
|
||||
Document-RAG provenance chain (4 stages):
|
||||
question → grounding → exploration → synthesis
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT,
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMatch:
|
||||
"""Mimics the result from doc_embeddings_client.query()."""
|
||||
chunk_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
|
||||
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
|
||||
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
|
||||
CHUNK_B_CONTENT = "Refunds are processed to the original payment method."
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock clients for a document-rag query.
|
||||
|
||||
Client call sequence during query():
|
||||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. doc_embeddings_client.query(vector, ...) -> chunk matches
|
||||
4. fetch_chunk(chunk_id, user) -> chunk content
|
||||
5. prompt_client.document_prompt(query, documents) -> answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return "return policy\nrefund"
|
||||
return ""
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
# 2. Embedding vectors
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# 3. Chunk matching
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id=CHUNK_A),
|
||||
ChunkMatch(chunk_id=CHUNK_B),
|
||||
]
|
||||
|
||||
# 4. Chunk content
|
||||
async def mock_fetch(chunk_id, user):
|
||||
return {
|
||||
CHUNK_A: CHUNK_A_CONTENT,
|
||||
CHUNK_B: CHUNK_B_CONTENT,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
# 5. Synthesis
|
||||
prompt_client.document_prompt.return_value = (
|
||||
"Items can be returned within 30 days for a full refund."
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentRagQueryProvenance:
|
||||
"""
|
||||
Run a real DocumentRag.query() and verify the provenance chain emitted
|
||||
via explain_callback.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_four_events(self):
|
||||
"""query() should emit exactly 4 explain events."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert len(events) == 4, (
|
||||
f"Expected 4 explain events (question, grounding, exploration, "
|
||||
f"synthesis), got {len(events)}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_events_have_correct_types_in_order(self):
|
||||
"""
|
||||
Events should arrive as:
|
||||
question, grounding, exploration, synthesis.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_SYNTHESIS,
|
||||
]
|
||||
|
||||
for i, expected_type in enumerate(expected_types):
|
||||
uri = events[i]["explain_id"]
|
||||
triples = events[i]["triples"]
|
||||
assert has_type(triples, uri, expected_type), (
|
||||
f"Event {i} (uri={uri}) should have type {expected_type}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_derivation_chain_links_correctly(self):
|
||||
"""
|
||||
Each event's URI should link to the previous via wasDerivedFrom:
|
||||
question → (none)
|
||||
grounding → question
|
||||
exploration → grounding
|
||||
synthesis → exploration
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
# question has no parent
|
||||
assert derived_from(all_triples, uris[0]) is None
|
||||
|
||||
# grounding → question
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
|
||||
# exploration → grounding
|
||||
assert derived_from(all_triples, uris[2]) == uris[1]
|
||||
|
||||
# synthesis → exploration
|
||||
assert derived_from(all_triples, uris[3]) == uris[2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_carries_query_text(self):
|
||||
"""The question event should contain the original query string."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
t = find_triple(q_triples, TG_QUERY, q_uri)
|
||||
assert t is not None
|
||||
assert t.o.value == "What is the return policy?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_carries_concepts(self):
|
||||
"""The grounding event should list extracted concepts."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
gnd_triples = events[1]["triples"]
|
||||
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert "return policy" in concept_values
|
||||
assert "refund" in concept_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_chunk_count(self):
|
||||
"""The exploration event should report the number of chunks retrieved."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
t = find_triple(exp_triples, TG_CHUNK_COUNT, exp_uri)
|
||||
assert t is not None
|
||||
assert int(t.o.value) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_selected_chunks(self):
|
||||
"""The exploration event should list the chunk IDs that were fetched."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
chunks = find_triples(exp_triples, TG_SELECTED_CHUNK, exp_uri)
|
||||
chunk_iris = {t.o.iri for t in chunks}
|
||||
assert CHUNK_A in chunk_iris
|
||||
assert CHUNK_B in chunk_iris
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
"""The synthesis event should have tg:Synthesis and tg:Answer types."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
syn_uri = events[3]["explain_id"]
|
||||
syn_triples = events[3]["triples"]
|
||||
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
|
||||
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_answer_text(self):
|
||||
"""query() should return the synthesised answer."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
"""query() without explain_callback should return answer normally."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All emitted triples should be in the urn:graph:retrieval graph."""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
for t in event["triples"]:
|
||||
assert t.g == "urn:graph:retrieval", (
|
||||
f"Triple {t.s.iri} {t.p.iri} should be in "
|
||||
f"urn:graph:retrieval, got {t.g}"
|
||||
)
|
||||
358
tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py
Normal file
358
tests/unit/test_retrieval/test_graph_rag_explain_forwarding.py
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
"""
|
||||
Tests that explain_triples are forwarded correctly through the graph-rag
|
||||
service and client layers.
|
||||
|
||||
Covers:
|
||||
- Service: explain messages include triples from the provenance callback
|
||||
- Client: explain_callback receives explain_triples from the response
|
||||
- End-to-end: triples survive the full service → client → callback chain
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
GraphRagQuery, GraphRagResponse,
|
||||
Triple, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base.graph_rag_client import GraphRagClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_triple(s_iri, p_iri, o_value, o_type=IRI):
|
||||
"""Create a Triple with IRI subject/predicate and typed object."""
|
||||
o = (
|
||||
Term(type=IRI, iri=o_value) if o_type == IRI
|
||||
else Term(type=LITERAL, value=o_value)
|
||||
)
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s_iri),
|
||||
p=Term(type=IRI, iri=p_iri),
|
||||
o=o,
|
||||
)
|
||||
|
||||
|
||||
def sample_focus_triples():
|
||||
"""Focus-style triples with a quoted triple (edge selection)."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/Focus",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"http://www.w3.org/ns/prov#wasDerivedFrom",
|
||||
"urn:trustgraph:exploration:abc",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:focus:abc",
|
||||
"https://trustgraph.ai/ns/selectedEdge",
|
||||
"urn:trustgraph:edge-sel:abc:0",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_question_triples():
|
||||
"""Question-style triples."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/GraphRagQuestion",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc",
|
||||
"https://trustgraph.ai/ns/query",
|
||||
"What is quantum computing?",
|
||||
o_type=LITERAL,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service-level: explain messages carry triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagServiceExplainTriples:
|
||||
"""Test that the graph-rag service includes explain_triples in messages."""
|
||||
|
||||
@patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_messages_include_triples(self, mock_graph_rag_class):
|
||||
"""
|
||||
When the provenance callback is invoked with triples, the service
|
||||
should include them in the explain response message.
|
||||
"""
|
||||
from trustgraph.retrieval.graph_rag.rag import Processor
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id="test-processor",
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
)
|
||||
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_graph_rag_class.return_value = mock_rag_instance
|
||||
|
||||
question_triples = sample_question_triples()
|
||||
focus_triples = sample_focus_triples()
|
||||
|
||||
async def mock_query(**kwargs):
|
||||
explain_callback = kwargs.get('explain_callback')
|
||||
if explain_callback:
|
||||
await explain_callback(
|
||||
question_triples, "urn:trustgraph:question:abc"
|
||||
)
|
||||
await explain_callback(
|
||||
focus_triples, "urn:trustgraph:focus:abc"
|
||||
)
|
||||
return "The answer."
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = GraphRagQuery(
|
||||
query="What is quantum computing?",
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
)
|
||||
msg.properties.return_value = {"id": "test-id"}
|
||||
|
||||
consumer = MagicMock()
|
||||
flow = MagicMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_provenance = AsyncMock()
|
||||
|
||||
def flow_router(name):
|
||||
if name == "response":
|
||||
return mock_response
|
||||
if name == "explainability":
|
||||
return mock_provenance
|
||||
return AsyncMock()
|
||||
|
||||
flow.side_effect = flow_router
|
||||
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Find the explain messages
|
||||
explain_msgs = [
|
||||
call[0][0]
|
||||
for call in mock_response.send.call_args_list
|
||||
if call[0][0].message_type == "explain"
|
||||
]
|
||||
|
||||
assert len(explain_msgs) == 2
|
||||
|
||||
# First explain message should carry question triples
|
||||
assert explain_msgs[0].explain_id == "urn:trustgraph:question:abc"
|
||||
assert explain_msgs[0].explain_triples == question_triples
|
||||
|
||||
# Second explain message should carry focus triples
|
||||
assert explain_msgs[1].explain_id == "urn:trustgraph:focus:abc"
|
||||
assert explain_msgs[1].explain_triples == focus_triples
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client-level: explain_callback receives triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagClientExplainForwarding:
|
||||
"""Test that GraphRagClient.rag() forwards explain_triples to callback."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_triples(self):
|
||||
"""
|
||||
The explain_callback should receive (explain_id, explain_graph,
|
||||
explain_triples) — not just (explain_id, explain_graph).
|
||||
"""
|
||||
focus_triples = sample_focus_triples()
|
||||
|
||||
# Simulate the response sequence the client would receive
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:focus:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=focus_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="The answer.",
|
||||
end_of_stream=True,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="",
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
# Capture what the explain_callback receives
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append({
|
||||
"explain_id": explain_id,
|
||||
"explain_graph": explain_graph,
|
||||
"explain_triples": explain_triples,
|
||||
})
|
||||
|
||||
# Patch self.request to feed responses to the recipient
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
result = await client.rag(
|
||||
query="test",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert result == "The answer."
|
||||
assert len(received_calls) == 1
|
||||
assert received_calls[0]["explain_id"] == "urn:trustgraph:focus:abc"
|
||||
assert received_calls[0]["explain_graph"] == "urn:graph:retrieval"
|
||||
assert received_calls[0]["explain_triples"] == focus_triples
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_empty_triples(self):
|
||||
"""
|
||||
When an explain event has no triples, the callback should still
|
||||
receive an empty list (not None or missing).
|
||||
"""
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=[],
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append(explain_triples)
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
await client.rag(query="test", explain_callback=explain_callback)
|
||||
|
||||
assert len(received_calls) == 1
|
||||
assert received_calls[0] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_explain_events_all_forward_triples(self):
|
||||
"""
|
||||
Each explain event in a session should forward its own triples.
|
||||
"""
|
||||
q_triples = sample_question_triples()
|
||||
f_triples = sample_focus_triples()
|
||||
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=q_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:focus:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=f_triples,
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
received_calls = []
|
||||
|
||||
async def explain_callback(explain_id, explain_graph, explain_triples):
|
||||
received_calls.append({
|
||||
"explain_id": explain_id,
|
||||
"explain_triples": explain_triples,
|
||||
})
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
await client.rag(query="test", explain_callback=explain_callback)
|
||||
|
||||
assert len(received_calls) == 2
|
||||
assert received_calls[0]["explain_id"] == "urn:trustgraph:question:abc"
|
||||
assert received_calls[0]["explain_triples"] == q_triples
|
||||
assert received_calls[1]["explain_id"] == "urn:trustgraph:focus:abc"
|
||||
assert received_calls[1]["explain_triples"] == f_triples
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_does_not_error(self):
|
||||
"""
|
||||
When no explain_callback is provided, explain events should be
|
||||
silently skipped without errors.
|
||||
"""
|
||||
responses = [
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_question_triples(),
|
||||
),
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Answer.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
]
|
||||
|
||||
client = GraphRagClient.__new__(GraphRagClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for resp in responses:
|
||||
done = await recipient(resp)
|
||||
if done:
|
||||
return resp
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
result = await client.rag(query="test")
|
||||
assert result == "Answer."
|
||||
|
|
@ -0,0 +1,482 @@
|
|||
"""
|
||||
Integration test: run a full GraphRag.query() with mocked subsidiary clients
|
||||
and verify the explain_callback receives the complete provenance chain
|
||||
in the correct order with correct structure.
|
||||
|
||||
This tests the real query() method end-to-end, not just the triple builders.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingMatch:
|
||||
"""Mimics the result from graph_embeddings_client.query()."""
|
||||
entity: Term
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# A tiny knowledge graph: 2 entities, 3 edges
|
||||
ENTITY_A = "http://example.com/QuantumComputing"
|
||||
ENTITY_B = "http://example.com/Physics"
|
||||
EDGE_1 = (ENTITY_A, "http://schema.org/relatedTo", ENTITY_B)
|
||||
EDGE_2 = (ENTITY_A, "http://schema.org/name", "Quantum Computing")
|
||||
EDGE_3 = (ENTITY_B, "http://schema.org/name", "Physics")
|
||||
|
||||
|
||||
def make_schema_triple(s, p, o):
|
||||
"""Create a SchemaTriple from string values."""
|
||||
return SchemaTriple(
|
||||
s=Term(type=IRI, iri=s),
|
||||
p=Term(type=IRI, iri=p),
|
||||
o=Term(type=IRI, iri=o) if o.startswith("http") else Term(type=LITERAL, value=o),
|
||||
)
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock clients that simulate a small knowledge graph query.
|
||||
|
||||
Client call sequence during query():
|
||||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
prompt_responses["extract-concepts"] = "quantum computing\nphysics"
|
||||
|
||||
# 2. Embedding vectors (simple fake vectors)
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# 3. Entity lookup - return our two entities
|
||||
graph_embeddings_client.query.return_value = [
|
||||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_A)),
|
||||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
make_schema_triple(*EDGE_3),
|
||||
]
|
||||
triples_client.query_stream.return_value = kg_triples
|
||||
|
||||
# 5. Label resolution - return entity as its own label (simplify)
|
||||
async def mock_label_query(s=None, p=None, o=None, limit=1,
|
||||
user=None, collection=None, g=None):
|
||||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagQueryProvenance:
|
||||
"""
|
||||
Run a real GraphRag.query() and verify the provenance chain emitted
|
||||
via explain_callback.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explain_callback_receives_five_events(self):
|
||||
"""query() should emit exactly 5 explain events."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
f"Expected 5 explain events (question, grounding, exploration, "
|
||||
f"focus, synthesis), got {len(events)}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_events_have_correct_types_in_order(self):
|
||||
"""
|
||||
Events should arrive as:
|
||||
question, grounding, exploration, focus, synthesis.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
TG_GRAPH_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_FOCUS,
|
||||
TG_SYNTHESIS,
|
||||
]
|
||||
|
||||
for i, expected_type in enumerate(expected_types):
|
||||
uri = events[i]["explain_id"]
|
||||
triples = events[i]["triples"]
|
||||
assert has_type(triples, uri, expected_type), (
|
||||
f"Event {i} (uri={uri}) should have type {expected_type}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_derivation_chain_links_correctly(self):
|
||||
"""
|
||||
Each event's URI should link to the previous via wasDerivedFrom:
|
||||
grounding → question → (none)
|
||||
exploration → grounding
|
||||
focus → exploration
|
||||
synthesis → focus
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
all_triples = []
|
||||
for e in events:
|
||||
all_triples.extend(e["triples"])
|
||||
|
||||
# question has no parent
|
||||
assert derived_from(all_triples, uris[0]) is None
|
||||
|
||||
# grounding → question
|
||||
assert derived_from(all_triples, uris[1]) == uris[0]
|
||||
|
||||
# exploration → grounding
|
||||
assert derived_from(all_triples, uris[2]) == uris[1]
|
||||
|
||||
# focus → exploration
|
||||
assert derived_from(all_triples, uris[3]) == uris[2]
|
||||
|
||||
# synthesis → focus
|
||||
assert derived_from(all_triples, uris[4]) == uris[3]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_event_carries_query_text(self):
|
||||
"""The question event should contain the original query string."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
t = find_triple(q_triples, TG_QUERY, q_uri)
|
||||
assert t is not None
|
||||
assert t.o.value == "What is quantum computing?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grounding_carries_concepts(self):
|
||||
"""The grounding event should list extracted concepts."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
gnd_triples = events[1]["triples"]
|
||||
concepts = find_triples(gnd_triples, TG_CONCEPT, gnd_uri)
|
||||
concept_values = {t.o.value for t in concepts}
|
||||
assert "quantum computing" in concept_values
|
||||
assert "physics" in concept_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exploration_has_edge_count(self):
|
||||
"""The exploration event should report how many edges were found."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
exp_triples = events[2]["triples"]
|
||||
t = find_triple(exp_triples, TG_EDGE_COUNT, exp_uri)
|
||||
assert t is not None
|
||||
# Should be non-zero (we provided 3 edges, label edges filtered)
|
||||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
foc_triples = events[3]["triples"]
|
||||
|
||||
# Should have selected edges
|
||||
selected = find_triples(foc_triples, TG_SELECTED_EDGE, foc_uri)
|
||||
assert len(selected) > 0, "Focus should have at least one selected edge"
|
||||
|
||||
# Each edge selection should have a quoted triple
|
||||
edge_t = find_triples(foc_triples, TG_EDGE)
|
||||
assert len(edge_t) > 0, "Focus should have tg:edge with quoted triples"
|
||||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
"""The synthesis event should have tg:Answer type."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
syn_triples = events[4]["triples"]
|
||||
assert has_type(syn_triples, syn_uri, TG_SYNTHESIS)
|
||||
assert has_type(syn_triples, syn_uri, TG_ANSWER_TYPE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_answer_text(self):
|
||||
"""query() should still return the synthesised answer."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
"""When parent_uri is provided, question should derive from it."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
parent = "urn:trustgraph:agent:iteration:xyz"
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
q_triples = events[0]["triples"]
|
||||
assert derived_from(q_triples, q_uri) == parent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
"""query() without explain_callback should return answer normally."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
"""All emitted triples should be in the urn:graph:retrieval graph."""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
for t in event["triples"]:
|
||||
assert t.g == "urn:graph:retrieval", (
|
||||
f"Triple {t.s.iri} {t.p.iri} should be in "
|
||||
f"urn:graph:retrieval, got {t.g}"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue