Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding, consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts,
  kg-edge-scoring,
  kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring
  provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's
  pattern:
  Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first
  entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag,
  tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
This commit is contained in:
Cyber MacGeddon 2026-03-14 11:54:10 +00:00
parent 29b4300808
commit 20bb645b9a
25 changed files with 1537 additions and 1008 deletions

View file

@@ -86,13 +86,18 @@ class TestGraphRagIntegration:
"""Mock prompt client that generates realistic responses for two-step process""" """Mock prompt client that generates realistic responses for two-step process"""
client = AsyncMock() client = AsyncMock()
# Mock responses for the two-step process: # Mock responses for the multi-step process:
# 1. kg-edge-selection returns JSONL with edge IDs # 1. extract-concepts extracts key concepts from the query
# 2. kg-synthesis returns the final answer # 2. kg-edge-scoring scores edges for relevance
# 3. kg-edge-reasoning provides reasoning for selected edges
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection": if prompt_name == "extract-concepts":
# Return empty selection (no edges selected) - valid JSONL return "" # Falls back to raw query
return "" elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
elif prompt_name == "kg-synthesis": elif prompt_name == "kg-synthesis":
return ( return (
"Machine learning is a subset of artificial intelligence that enables computers " "Machine learning is a subset of artificial intelligence that enables computers "
@@ -160,16 +165,16 @@ class TestGraphRagIntegration:
# 3. Should query triples to build knowledge subgraph # 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0 assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt twice (edge selection + synthesis) # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 2 assert mock_prompt_client.prompt.call_count == 4
# Verify final response # Verify final response
assert response is not None assert response is not None
assert isinstance(response, str) assert isinstance(response, str)
assert "machine learning" in response.lower() assert "machine learning" in response.lower()
# Verify provenance was emitted in real-time (4 events: question, exploration, focus, synthesis) # Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 4 assert len(provenance_events) == 5
for triples, prov_id in provenance_events: for triples, prov_id in provenance_events:
assert isinstance(triples, list) assert isinstance(triples, list)
assert prov_id.startswith("urn:trustgraph:") assert prov_id.startswith("urn:trustgraph:")
@@ -243,10 +248,10 @@ class TestGraphRagIntegration:
) )
# Assert # Assert
# Should still call prompt client (twice: edge selection + synthesis) # Should still call prompt client
assert response is not None assert response is not None
# Provenance should still be emitted (4 events) # Provenance should still be emitted (5 events)
assert len(provenance_events) == 4 assert len(provenance_events) == 5
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client): async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):

View file

@@ -60,8 +60,12 @@ class TestGraphRagStreaming:
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "kg-edge-selection": if prompt_id == "extract-concepts":
# Edge selection returns JSONL with IDs - simulate selecting first edge return "" # Falls back to raw query
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
elif prompt_id == "kg-synthesis": elif prompt_id == "kg-synthesis":
if streaming and chunk_callback: if streaming and chunk_callback:
@@ -132,8 +136,8 @@ class TestGraphRagStreaming:
# Verify content is reasonable # Verify content is reasonable
assert "machine" in response.lower() or "learning" in response.lower() assert "machine" in response.lower() or "learning" in response.lower()
# Verify provenance was emitted in real-time (4 events) # Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 4 assert len(provenance_events) == 5
for triples, prov_id in provenance_events: for triples, prov_id in provenance_events:
assert prov_id.startswith("urn:trustgraph:") assert prov_id.startswith("urn:trustgraph:")

View file

@@ -15,10 +15,11 @@ from trustgraph.provenance.agent import (
from trustgraph.provenance.namespaces import ( from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL, RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME, PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_AGENT_QUESTION, TG_AGENT_QUESTION,
) )
@@ -110,84 +111,107 @@ class TestAgentSessionTriples:
class TestAgentIterationTriples: class TestAgentIterationTriples:
ITER_URI = "urn:trustgraph:agent:test-session/i1" ITER_URI = "urn:trustgraph:agent:test-session/i1"
PARENT_URI = "urn:trustgraph:agent:test-session" SESSION_URI = "urn:trustgraph:agent:test-session"
PREV_URI = "urn:trustgraph:agent:test-session/i0"
def test_iteration_types(self): def test_iteration_types(self):
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
thought="thinking", action="search", observation="found it", action="search",
) )
assert has_type(triples, self.ITER_URI, PROV_ENTITY) assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS) assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
def test_iteration_derived_from_parent(self): def test_first_iteration_generated_by_question(self):
"""First iteration uses wasGeneratedBy to link to question activity."""
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
# Should NOT have wasDerivedFrom
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is None
def test_subsequent_iteration_derived_from_previous(self):
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
triples = agent_iteration_triples(
self.ITER_URI, previous_uri=self.PREV_URI,
action="search", action="search",
) )
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None assert derived is not None
assert derived.o.iri == self.PARENT_URI assert derived.o.iri == self.PREV_URI
# Should NOT have wasGeneratedBy
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is None
def test_iteration_label_includes_action(self): def test_iteration_label_includes_action(self):
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query", action="graph-rag-query",
) )
label = find_triple(triples, RDFS_LABEL, self.ITER_URI) label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
assert label is not None assert label is not None
assert "graph-rag-query" in label.o.value assert "graph-rag-query" in label.o.value
def test_iteration_thought_inline(self): def test_iteration_thought_sub_entity(self):
"""Thought is a sub-entity with Reflection and Thought types."""
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
thought_doc = "urn:doc:thought-1"
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
thought="I need to search for info",
action="search", action="search",
thought_uri=thought_uri,
thought_document_id=thought_doc,
) )
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI) # Iteration links to thought sub-entity
assert thought is not None thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought.o.value == "I need to search for info" assert thought_link is not None
assert thought_link.o.iri == thought_uri
# Thought has correct types
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
# Thought was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Thought has document reference
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
assert doc is not None
assert doc.o.iri == thought_doc
def test_iteration_thought_document_preferred(self): def test_iteration_observation_sub_entity(self):
"""When thought_document_id is provided, inline thought is not stored.""" """Observation is a sub-entity with Reflection and Observation types."""
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
obs_doc = "urn:doc:obs-1"
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
thought="inline thought",
action="search", action="search",
thought_document_id="urn:doc:thought-1", observation_uri=obs_uri,
observation_document_id=obs_doc,
) )
thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI) # Iteration links to observation sub-entity
assert thought_doc is not None obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert thought_doc.o.iri == "urn:doc:thought-1" assert obs_link is not None
thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI) assert obs_link.o.iri == obs_uri
assert thought_inline is None # Observation has correct types
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
def test_iteration_observation_inline(self): assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
triples = agent_iteration_triples( # Observation was generated by iteration
self.ITER_URI, self.PARENT_URI, gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
action="search", assert gen is not None
observation="Found 3 results", assert gen.o.iri == self.ITER_URI
) # Observation has document reference
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI) doc = find_triple(triples, TG_DOCUMENT, obs_uri)
assert obs is not None assert doc is not None
assert obs.o.value == "Found 3 results" assert doc.o.iri == obs_doc
def test_iteration_observation_document_preferred(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="inline obs",
observation_document_id="urn:doc:obs-1",
)
obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
assert obs_doc is not None
assert obs_doc.o.iri == "urn:doc:obs-1"
obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_inline is None
def test_iteration_action_recorded(self): def test_iteration_action_recorded(self):
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query", action="graph-rag-query",
) )
action = find_triple(triples, TG_ACTION, self.ITER_URI) action = find_triple(triples, TG_ACTION, self.ITER_URI)
@@ -197,7 +221,7 @@ class TestAgentIterationTriples:
def test_iteration_arguments_json_encoded(self): def test_iteration_arguments_json_encoded(self):
args = {"query": "test query", "limit": 10} args = {"query": "test query", "limit": 10}
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
action="search", action="search",
arguments=args, arguments=args,
) )
@@ -208,7 +232,7 @@ class TestAgentIterationTriples:
def test_iteration_default_arguments_empty_dict(self): def test_iteration_default_arguments_empty_dict(self):
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
action="search", action="search",
) )
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI) arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
@@ -219,7 +243,7 @@ class TestAgentIterationTriples:
def test_iteration_no_thought_or_observation(self): def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples.""" """Minimal iteration with just action — no thought or observation triples."""
triples = agent_iteration_triples( triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI, self.ITER_URI, question_uri=self.SESSION_URI,
action="noop", action="noop",
) )
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI) thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
@@ -228,19 +252,19 @@ class TestAgentIterationTriples:
assert obs is None assert obs is None
def test_iteration_chaining(self): def test_iteration_chaining(self):
"""Second iteration derives from first iteration, not session.""" """First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
iter1_uri = "urn:trustgraph:agent:sess/i1" iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2" iter2_uri = "urn:trustgraph:agent:sess/i2"
triples1 = agent_iteration_triples( triples1 = agent_iteration_triples(
iter1_uri, self.PARENT_URI, action="step1", iter1_uri, question_uri=self.SESSION_URI, action="step1",
) )
triples2 = agent_iteration_triples( triples2 = agent_iteration_triples(
iter2_uri, iter1_uri, action="step2", iter2_uri, previous_uri=iter1_uri, action="step2",
) )
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri) gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
assert derived1.o.iri == self.PARENT_URI assert gen1.o.iri == self.SESSION_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri) derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri assert derived2.o.iri == iter1_uri
@@ -253,42 +277,50 @@ class TestAgentIterationTriples:
class TestAgentFinalTriples: class TestAgentFinalTriples:
FINAL_URI = "urn:trustgraph:agent:test-session/final" FINAL_URI = "urn:trustgraph:agent:test-session/final"
PARENT_URI = "urn:trustgraph:agent:test-session/i3" PREV_URI = "urn:trustgraph:agent:test-session/i3"
SESSION_URI = "urn:trustgraph:agent:test-session"
def test_final_types(self): def test_final_types(self):
triples = agent_final_triples( triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42" self.FINAL_URI, previous_uri=self.PREV_URI,
) )
assert has_type(triples, self.FINAL_URI, PROV_ENTITY) assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION) assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
def test_final_derived_from_parent(self): def test_final_derived_from_previous(self):
"""Conclusion with iterations uses wasDerivedFrom."""
triples = agent_final_triples( triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42" self.FINAL_URI, previous_uri=self.PREV_URI,
) )
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None assert derived is not None
assert derived.o.iri == self.PARENT_URI assert derived.o.iri == self.PREV_URI
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is None
def test_final_generated_by_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasGeneratedBy."""
triples = agent_final_triples(
self.FINAL_URI, question_uri=self.SESSION_URI,
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is None
def test_final_label(self): def test_final_label(self):
triples = agent_final_triples( triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42" self.FINAL_URI, previous_uri=self.PREV_URI,
) )
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI) label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
assert label is not None assert label is not None
assert label.o.value == "Conclusion" assert label.o.value == "Conclusion"
def test_final_inline_answer(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is not None
assert answer.o.value == "The answer is 42"
def test_final_document_reference(self): def test_final_document_reference(self):
triples = agent_final_triples( triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, self.FINAL_URI, previous_uri=self.PREV_URI,
document_id="urn:trustgraph:agent:sess/answer", document_id="urn:trustgraph:agent:sess/answer",
) )
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
@@ -296,29 +328,9 @@ class TestAgentFinalTriples:
assert doc.o.type == IRI assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:agent:sess/answer" assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
def test_final_document_takes_precedence(self): def test_final_no_document(self):
triples = agent_final_triples( triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, self.FINAL_URI, previous_uri=self.PREV_URI,
answer="inline",
document_id="urn:doc:123",
) )
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is not None
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is None
def test_final_no_answer_or_document(self):
triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert answer is None
assert doc is None assert doc is None
def test_final_derives_from_session_when_no_iterations(self):
"""When agent answers immediately, final derives from session."""
session_uri = "urn:trustgraph:agent:test-session"
triples = agent_final_triples(
self.FINAL_URI, session_uri, answer="direct answer"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived.o.iri == session_uri

View file

@@ -10,9 +10,11 @@ from trustgraph.api.explainability import (
EdgeSelection, EdgeSelection,
ExplainEntity, ExplainEntity,
Question, Question,
Grounding,
Exploration, Exploration,
Focus, Focus,
Synthesis, Synthesis,
Reflection,
Analysis, Analysis,
Conclusion, Conclusion,
parse_edge_selection_triples, parse_edge_selection_triples,
@@ -20,11 +22,11 @@ from trustgraph.api.explainability import (
wire_triples_to_tuples, wire_triples_to_tuples,
ExplainabilityClient, ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT, TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_ANALYSIS, TG_CONCLUSION,
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
RDF_TYPE, RDFS_LABEL, RDF_TYPE, RDFS_LABEL,
@@ -71,6 +73,18 @@ class TestExplainEntityFromTriples:
assert isinstance(entity, Question) assert isinstance(entity, Question)
assert entity.question_type == "agent" assert entity.question_type == "agent"
def test_grounding(self):
triples = [
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
("urn:gnd:1", TG_CONCEPT, "machine learning"),
("urn:gnd:1", TG_CONCEPT, "neural networks"),
]
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
assert isinstance(entity, Grounding)
assert len(entity.concepts) == 2
assert "machine learning" in entity.concepts
assert "neural networks" in entity.concepts
def test_exploration(self): def test_exploration(self):
triples = [ triples = [
("urn:exp:1", RDF_TYPE, TG_EXPLORATION), ("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
@@ -89,6 +103,17 @@ class TestExplainEntityFromTriples:
assert isinstance(entity, Exploration) assert isinstance(entity, Exploration)
assert entity.chunk_count == 5 assert entity.chunk_count == 5
def test_exploration_with_entities(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
("urn:exp:3", TG_EDGE_COUNT, "10"),
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
]
entity = ExplainEntity.from_triples("urn:exp:3", triples)
assert isinstance(entity, Exploration)
assert len(entity.entities) == 2
def test_exploration_invalid_count(self): def test_exploration_invalid_count(self):
triples = [ triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION), ("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
@@ -110,69 +135,76 @@ class TestExplainEntityFromTriples:
assert "urn:edge:1" in entity.selected_edge_uris assert "urn:edge:1" in entity.selected_edge_uris
assert "urn:edge:2" in entity.selected_edge_uris assert "urn:edge:2" in entity.selected_edge_uris
def test_synthesis_with_content(self): def test_synthesis_with_document(self):
triples = [ triples = [
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS), ("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:1", TG_CONTENT, "The answer is 42"), ("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
] ]
entity = ExplainEntity.from_triples("urn:syn:1", triples) entity = ExplainEntity.from_triples("urn:syn:1", triples)
assert isinstance(entity, Synthesis) assert isinstance(entity, Synthesis)
assert entity.content == "The answer is 42" assert entity.document_uri == "urn:doc:answer-1"
assert entity.document_uri == ""
def test_synthesis_with_document(self): def test_synthesis_no_document(self):
triples = [ triples = [
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS), ("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
] ]
entity = ExplainEntity.from_triples("urn:syn:2", triples) entity = ExplainEntity.from_triples("urn:syn:2", triples)
assert isinstance(entity, Synthesis) assert isinstance(entity, Synthesis)
assert entity.document_uri == "urn:doc:answer-1" assert entity.document_uri == ""
def test_reflection_thought(self):
triples = [
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
]
entity = ExplainEntity.from_triples("urn:ref:1", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "thought"
assert entity.document_uri == "urn:doc:thought-1"
def test_reflection_observation(self):
triples = [
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ref:2", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "observation"
assert entity.document_uri == "urn:doc:obs-1"
def test_analysis(self): def test_analysis(self):
triples = [ triples = [
("urn:ana:1", RDF_TYPE, TG_ANALYSIS), ("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
("urn:ana:1", TG_THOUGHT, "I should search"),
("urn:ana:1", TG_ACTION, "graph-rag-query"), ("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'), ("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_OBSERVATION, "Found results"), ("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
] ]
entity = ExplainEntity.from_triples("urn:ana:1", triples) entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis) assert isinstance(entity, Analysis)
assert entity.thought == "I should search"
assert entity.action == "graph-rag-query" assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}' assert entity.arguments == '{"query": "test"}'
assert entity.observation == "Found results" assert entity.thought_uri == "urn:ref:thought-1"
assert entity.observation_uri == "urn:ref:obs-1"
def test_analysis_with_document_refs(self):
triples = [
("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
("urn:ana:2", TG_ACTION, "search"),
("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:2", triples)
assert isinstance(entity, Analysis)
assert entity.thought_document_uri == "urn:doc:thought-1"
assert entity.observation_document_uri == "urn:doc:obs-1"
def test_conclusion_with_answer(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_ANSWER, "The final answer"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.answer == "The final answer"
def test_conclusion_with_document(self): def test_conclusion_with_document(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final"
def test_conclusion_no_document(self):
triples = [ triples = [
("urn:conc:2", RDF_TYPE, TG_CONCLUSION), ("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
] ]
entity = ExplainEntity.from_triples("urn:conc:2", triples) entity = ExplainEntity.from_triples("urn:conc:2", triples)
assert isinstance(entity, Conclusion) assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final" assert entity.document_uri == ""
def test_unknown_type(self): def test_unknown_type(self):
triples = [ triples = [
@@ -457,25 +489,7 @@ class TestExplainabilityClientResolveLabel:
class TestExplainabilityClientContentFetching: class TestExplainabilityClientContentFetching:
def test_fetch_synthesis_inline_content(self): def test_fetch_document_content_from_librarian(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == "inline answer"
def test_fetch_synthesis_truncated_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
long_content = "x" * 20000
synthesis = Synthesis(uri="urn:syn:1", content=long_content)
result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_synthesis_from_librarian(self):
mock_flow = MagicMock() mock_flow = MagicMock()
mock_api = MagicMock() mock_api = MagicMock()
mock_library = MagicMock() mock_library = MagicMock()
@@ -483,66 +497,32 @@ class TestExplainabilityClientContentFetching:
mock_library.get_document_content.return_value = b"librarian content" mock_library.get_document_content.return_value = b"librarian content"
client = ExplainabilityClient(mock_flow, retry_delay=0.0) client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis( result = client.fetch_document_content(
uri="urn:syn:1", "urn:document:abc123", api=mock_api
document_uri="urn:document:abc123"
) )
result = client.fetch_synthesis_content(synthesis, api=mock_api)
assert result == "librarian content" assert result == "librarian content"
def test_fetch_synthesis_no_content_or_document(self): def test_fetch_document_content_truncated(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == ""
def test_fetch_conclusion_inline(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
conclusion = Conclusion(uri="urn:conc:1", answer="42")
result = client.fetch_conclusion_content(conclusion, api=None)
assert result == "42"
def test_fetch_analysis_content_from_librarian(self):
mock_flow = MagicMock() mock_flow = MagicMock()
mock_api = MagicMock() mock_api = MagicMock()
mock_library = MagicMock() mock_library = MagicMock()
mock_api.library.return_value = mock_library mock_api.library.return_value = mock_library
mock_library.get_document_content.side_effect = [ mock_library.get_document_content.return_value = b"x" * 20000
b"thought content",
b"observation content",
]
client = ExplainabilityClient(mock_flow, retry_delay=0.0) client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis( result = client.fetch_document_content(
uri="urn:ana:1", "urn:doc:1", api=mock_api, max_content=100
action="search",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
) )
client.fetch_analysis_content(analysis, api=mock_api) assert len(result) < 20000
assert analysis.thought == "thought content" assert result.endswith("... [truncated]")
assert analysis.observation == "observation content"
def test_fetch_analysis_skips_when_inline_exists(self): def test_fetch_document_content_empty_uri(self):
mock_flow = MagicMock() mock_flow = MagicMock()
mock_api = MagicMock() mock_api = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0) client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis( result = client.fetch_document_content("", api=mock_api)
uri="urn:ana:1", assert result == ""
action="search",
thought="already have thought",
observation="already have observation",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
)
client.fetch_analysis_content(analysis, api=mock_api)
# Should not call librarian since inline content exists
mock_api.library.assert_not_called()
class TestExplainabilityClientDetectSessionType: class TestExplainabilityClientDetectSessionType:

View file

@ -13,6 +13,7 @@ from trustgraph.provenance.triples import (
derived_entity_triples, derived_entity_triples,
subgraph_provenance_triples, subgraph_provenance_triples,
question_triples, question_triples,
grounding_triples,
exploration_triples, exploration_triples,
focus_triples, focus_triples,
synthesis_triples, synthesis_triples,
@ -32,10 +33,12 @@ from trustgraph.provenance.namespaces import (
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS, TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_CONTENT, TG_DOCUMENT, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK, TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
GRAPH_SOURCE, GRAPH_RETRIEVAL, GRAPH_SOURCE, GRAPH_RETRIEVAL,
) )
@ -530,36 +533,77 @@ class TestQuestionTriples:
assert len(triples) == 6 assert len(triples) == 6
class TestExplorationTriples: class TestGroundingTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session" GND_URI = "urn:trustgraph:prov:grounding:test-session"
Q_URI = "urn:trustgraph:question:test-session" Q_URI = "urn:trustgraph:question:test-session"
def test_exploration_types(self): def test_grounding_types(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15) triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
assert has_type(triples, self.EXP_URI, PROV_ENTITY) assert has_type(triples, self.GND_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION) assert has_type(triples, self.GND_URI, TG_GROUNDING)
def test_exploration_generated_by_question(self): def test_grounding_generated_by_question(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15) triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI) gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
assert gen is not None assert gen is not None
assert gen.o.iri == self.Q_URI assert gen.o.iri == self.Q_URI
def test_grounding_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 3
values = {t.o.value for t in concepts}
assert values == {"AI", "ML", "robots"}
def test_grounding_empty_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 0
def test_grounding_label(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
assert label is not None
assert label.o.value == "Grounding"
class TestExplorationTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
GND_URI = "urn:trustgraph:prov:grounding:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_exploration_derived_from_grounding(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived is not None
assert derived.o.iri == self.GND_URI
def test_exploration_edge_count(self): def test_exploration_edge_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15) triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI) ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None assert ec is not None
assert ec.o.value == "15" assert ec.o.value == "15"
def test_exploration_zero_edges(self): def test_exploration_zero_edges(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 0) triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI) ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None assert ec is not None
assert ec.o.value == "0" assert ec.o.value == "0"
def test_exploration_with_entities(self):
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
assert len(ent_triples) == 2
def test_exploration_triple_count(self): def test_exploration_triple_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 10) triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
assert len(triples) == 5 assert len(triples) == 5
@ -652,6 +696,7 @@ class TestSynthesisTriples:
triples = synthesis_triples(self.SYN_URI, self.FOC_URI) triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY) assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS) assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_synthesis_derived_from_focus(self): def test_synthesis_derived_from_focus(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI) triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
@ -659,12 +704,6 @@ class TestSynthesisTriples:
assert derived is not None assert derived is not None
assert derived.o.iri == self.FOC_URI assert derived.o.iri == self.FOC_URI
def test_synthesis_with_inline_content(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is not None
assert content.o.value == "The answer is 42"
def test_synthesis_with_document_reference(self): def test_synthesis_with_document_reference(self):
triples = synthesis_triples( triples = synthesis_triples(
self.SYN_URI, self.FOC_URI, self.SYN_URI, self.FOC_URI,
@ -675,23 +714,9 @@ class TestSynthesisTriples:
assert doc.o.type == IRI assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:question:abc/answer" assert doc.o.iri == "urn:trustgraph:question:abc/answer"
def test_synthesis_document_takes_precedence(self): def test_synthesis_no_document(self):
"""When both document_id and answer_text are provided, document_id wins."""
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
answer_text="inline",
document_id="urn:doc:123",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is not None
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None
def test_synthesis_no_content_or_document(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI) triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert content is None
assert doc is None assert doc is None
@ -723,31 +748,31 @@ class TestDocRagQuestionTriples:
class TestDocRagExplorationTriples: class TestDocRagExplorationTriples:
EXP_URI = "urn:trustgraph:docrag:test/exploration" EXP_URI = "urn:trustgraph:docrag:test/exploration"
Q_URI = "urn:trustgraph:docrag:test" GND_URI = "urn:trustgraph:docrag:test/grounding"
def test_docrag_exploration_types(self): def test_docrag_exploration_types(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5) triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
assert has_type(triples, self.EXP_URI, PROV_ENTITY) assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION) assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_docrag_exploration_generated_by(self): def test_docrag_exploration_derived_from_grounding(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5) triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI) derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert gen.o.iri == self.Q_URI assert derived.o.iri == self.GND_URI
def test_docrag_exploration_chunk_count(self): def test_docrag_exploration_chunk_count(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7) triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI) cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
assert cc.o.value == "7" assert cc.o.value == "7"
def test_docrag_exploration_without_chunk_ids(self): def test_docrag_exploration_without_chunk_ids(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3) triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
chunks = find_triples(triples, TG_SELECTED_CHUNK) chunks = find_triples(triples, TG_SELECTED_CHUNK)
assert len(chunks) == 0 assert len(chunks) == 0
def test_docrag_exploration_with_chunk_ids(self): def test_docrag_exploration_with_chunk_ids(self):
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"] chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids) triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI) chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
assert len(chunks) == 3 assert len(chunks) == 3
chunk_uris = {t.o.iri for t in chunks} chunk_uris = {t.o.iri for t in chunks}
@ -770,10 +795,9 @@ class TestDocRagSynthesisTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI) derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived.o.iri == self.EXP_URI assert derived.o.iri == self.EXP_URI
def test_docrag_synthesis_with_inline(self): def test_docrag_synthesis_has_answer_type(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer") triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
content = find_triple(triples, TG_CONTENT, self.SYN_URI) assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
assert content.o.value == "answer"
def test_docrag_synthesis_with_document(self): def test_docrag_synthesis_with_document(self):
triples = docrag_synthesis_triples( triples = docrag_synthesis_triples(
@ -781,5 +805,8 @@ class TestDocRagSynthesisTriples:
) )
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc.o.iri == "urn:doc:ans" assert doc.o.iri == "urn:doc:ans"
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None def test_docrag_synthesis_no_document(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is None

View file

@ -125,19 +125,15 @@ class TestQuery:
assert query.doc_limit == 50 assert query.doc_limit == 50
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_vector_method(self): async def test_extract_concepts(self):
"""Test Query.get_vector method calls embeddings client correctly""" """Test Query.extract_concepts extracts concepts from query"""
# Create mock DocumentRag with embeddings client
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client mock_rag.prompt_client = mock_prompt_client
# Mock the embed method to return test vectors in batch format # Mock the prompt response with concept lines
# New format: [[[vectors_for_text1]]] - returns first text's vector set mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -145,20 +141,62 @@ class TestQuery:
verbose=False verbose=False
) )
# Call get_vector result = await query.extract_concepts("What is machine learning?")
test_query = "What documents are relevant?"
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly (now expects list) mock_prompt_client.prompt.assert_called_once_with(
mock_embeddings_client.embed.assert_called_once_with([test_query]) "extract-concepts",
variables={"query": "What is machine learning?"}
)
assert result == ["machine learning", "artificial intelligence", "data patterns"]
# Verify result matches expected vectors (extracted from batch) @pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
"""Test Query.extract_concepts falls back to raw query when no concepts extracted"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("What is ML?")
assert result == ["What is ML?"]
@pytest.mark.asyncio
async def test_get_vectors_method(self):
"""Test Query.get_vectors method calls embeddings client correctly"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method - returns vectors for each concept
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
concepts = ["machine learning", "data patterns"]
result = await query.get_vectors(concepts)
mock_embeddings_client.embed.assert_called_once_with(concepts)
assert result == expected_vectors assert result == expected_vectors
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_docs_method(self): async def test_get_docs_method(self):
"""Test Query.get_docs method retrieves documents correctly""" """Test Query.get_docs method retrieves documents correctly"""
# Create mock DocumentRag with clients
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
@ -170,10 +208,8 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses # Mock embeddings - one vector per concept
# New batch format: [[[vectors]]] - get_vector extracts [0] mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock document embeddings returns ChunkMatch objects # Mock document embeddings returns ChunkMatch objects
mock_match1 = MagicMock() mock_match1 = MagicMock()
@ -184,7 +220,6 @@ class TestQuery:
mock_match2.score = 0.85 mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -193,16 +228,16 @@ class TestQuery:
doc_limit=15 doc_limit=15
) )
# Call get_docs # Call get_docs with concepts list
test_query = "Find relevant documents" concepts = ["test concept"]
result = await query.get_docs(test_query) result = await query.get_docs(concepts)
# Verify embeddings client was called (now expects list) # Verify embeddings client was called with concepts
mock_embeddings_client.embed.assert_called_once_with([test_query]) mock_embeddings_client.embed.assert_called_once_with(concepts)
# Verify doc embeddings client was called correctly (with extracted vector) # Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors, vector=[0.1, 0.2, 0.3],
limit=15, limit=15,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -218,14 +253,17 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_method(self, mock_fetch_chunk): async def test_document_rag_query_method(self, mock_fetch_chunk):
"""Test DocumentRag.query method orchestrates full document RAG pipeline""" """Test DocumentRag.query method orchestrates full document RAG pipeline"""
# Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document embeddings responses # Mock concept extraction
# New batch format: [[[vectors]]] - get_vector extracts [0] mock_prompt_client.prompt.return_value = "test concept"
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
mock_match1 = MagicMock() mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3" mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9 mock_match1.score = 0.9
@ -234,11 +272,9 @@ class TestQuery:
mock_match2.score = 0.8 mock_match2.score = 0.8
expected_response = "This is the document RAG response" expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -247,7 +283,6 @@ class TestQuery:
verbose=False verbose=False
) )
# Call DocumentRag.query
result = await document_rag.query( result = await document_rag.query(
query="test query", query="test query",
user="test_user", user="test_user",
@ -255,12 +290,18 @@ class TestQuery:
doc_limit=10 doc_limit=10
) )
# Verify embeddings client was called (now expects list) # Verify concept extraction was called
mock_embeddings_client.embed.assert_called_once_with(["test query"]) mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "test query"}
)
# Verify doc embeddings client was called (with extracted vector) # Verify embeddings called with extracted concepts
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors, vector=[0.1, 0.2, 0.3],
limit=10, limit=10,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -270,23 +311,23 @@ class TestQuery:
mock_prompt_client.document_prompt.assert_called_once() mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "test query" assert call_args.kwargs["query"] == "test query"
# Documents should be fetched content, not chunk_ids
docs = call_args.kwargs["documents"] docs = call_args.kwargs["documents"]
assert "Relevant document content" in docs assert "Relevant document content" in docs
assert "Another document" in docs assert "Another document" in docs
# Verify result
assert result == expected_response assert result == expected_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag.query method with default parameters""" """Test DocumentRag.query method with default parameters"""
# Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format) # Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_match = MagicMock() mock_match = MagicMock()
mock_match.chunk_id = "doc/c5" mock_match.chunk_id = "doc/c5"
@ -294,7 +335,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match] mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response" mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -302,10 +342,9 @@ class TestQuery:
fetch_chunk=mock_fetch_chunk fetch_chunk=mock_fetch_chunk
) )
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query") result = await document_rag.query("simple query")
# Verify default parameters were used (vector extracted from batch) # Verify default parameters were used
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]], vector=[[0.1, 0.2]],
limit=20, # Default doc_limit limit=20, # Default doc_limit
@ -318,7 +357,6 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self): async def test_get_docs_with_verbose_output(self):
"""Test Query.get_docs method with verbose logging""" """Test Query.get_docs method with verbose logging"""
# Create mock DocumentRag with clients
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
@ -330,14 +368,13 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch
# Mock responses (batch format) # Mock responses - one vector per concept
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_match = MagicMock() mock_match = MagicMock()
mock_match.chunk_id = "doc/c6" mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88 mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match] mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -346,14 +383,12 @@ class TestQuery:
doc_limit=5 doc_limit=5
) )
# Call get_docs # Call get_docs with concepts
result = await query.get_docs("verbose test") result = await query.get_docs(["verbose test"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose test"]) mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
mock_doc_embeddings_client.query.assert_called_once() mock_doc_embeddings_client.query.assert_called_once()
# Verify result is tuple of (docs, chunk_ids) with fetched content
docs, chunk_ids = result docs, chunk_ids = result
assert "Verbose test doc" in docs assert "Verbose test doc" in docs
assert "doc/c6" in chunk_ids assert "doc/c6" in chunk_ids
@ -361,12 +396,14 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk): async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag.query method with verbose logging enabled""" """Test DocumentRag.query method with verbose logging enabled"""
# Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format) # Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_match = MagicMock() mock_match = MagicMock()
mock_match.chunk_id = "doc/c7" mock_match.chunk_id = "doc/c7"
@ -374,7 +411,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match] mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response" mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -383,14 +419,11 @@ class TestQuery:
verbose=True verbose=True
) )
# Call DocumentRag.query
result = await document_rag.query("verbose query test") result = await document_rag.query("verbose query test")
# Verify all clients were called (now expects list) mock_embeddings_client.embed.assert_called_once()
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
mock_doc_embeddings_client.query.assert_called_once() mock_doc_embeddings_client.query.assert_called_once()
# Verify prompt client was called with fetched content
call_args = mock_prompt_client.document_prompt.call_args call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "verbose query test" assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"] assert "Verbose doc content" in call_args.kwargs["documents"]
@ -400,23 +433,20 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_docs_with_empty_results(self): async def test_get_docs_with_empty_results(self):
"""Test Query.get_docs method when no documents are found""" """Test Query.get_docs method when no documents are found"""
# Create mock DocumentRag with clients
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk (won't be called if no chunk_ids)
async def mock_fetch(chunk_id, user): async def mock_fetch(chunk_id, user):
return f"Content for {chunk_id}" return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list (batch format) # Mock responses - empty results
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found mock_doc_embeddings_client.query.return_value = []
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -424,30 +454,27 @@ class TestQuery:
verbose=False verbose=False
) )
# Call get_docs result = await query.get_docs(["query with no results"])
result = await query.get_docs("query with no results")
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["query with no results"]) mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
mock_doc_embeddings_client.query.assert_called_once() mock_doc_embeddings_client.query.assert_called_once()
# Verify empty result is returned (tuple of empty lists)
assert result == ([], []) assert result == ([], [])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk): async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
"""Test DocumentRag.query method when no documents are retrieved""" """Test DocumentRag.query method when no documents are retrieved"""
# Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock responses - no chunk_ids found (batch format) # Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response" mock_prompt_client.document_prompt.return_value = "No documents found response"
# Initialize DocumentRag
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -456,10 +483,8 @@ class TestQuery:
verbose=False verbose=False
) )
# Call DocumentRag.query
result = await document_rag.query("query with no matching docs") result = await document_rag.query("query with no matching docs")
# Verify prompt client was called with empty document list
mock_prompt_client.document_prompt.assert_called_once_with( mock_prompt_client.document_prompt.assert_called_once_with(
query="query with no matching docs", query="query with no matching docs",
documents=[] documents=[]
@ -468,18 +493,15 @@ class TestQuery:
assert result == "No documents found response" assert result == "No documents found response"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_vector_with_verbose(self): async def test_get_vectors_with_verbose(self):
"""Test Query.get_vector method with verbose logging""" """Test Query.get_vectors method with verbose logging"""
# Create mock DocumentRag with embeddings client
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.9, 1.0, 1.1]] expected_vectors = [[0.9, 1.0, 1.1]]
mock_embeddings_client.embed.return_value = [expected_vectors] mock_embeddings_client.embed.return_value = expected_vectors
# Initialize Query with verbose=True
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -487,40 +509,40 @@ class TestQuery:
verbose=True verbose=True
) )
# Call get_vector result = await query.get_vectors(["verbose vector test"])
result = await query.get_vector("verbose vector test")
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"]) mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
# Verify result (extracted from batch)
assert result == expected_vectors assert result == expected_vectors
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_integration_flow(self, mock_fetch_chunk): async def test_document_rag_integration_flow(self, mock_fetch_chunk):
"""Test complete DocumentRag integration with realistic data flow""" """Test complete DocumentRag integration with realistic data flow"""
# Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock realistic responses (batch format)
query_text = "What is machine learning?" query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors] # Mock concept extraction
mock_matches = [] mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock() # Mock embeddings - one vector per concept
mock_match.chunk_id = chunk_id query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
mock_match.score = 0.9 mock_embeddings_client.embed.return_value = query_vectors
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches # Each concept query returns some matches
mock_matches_1 = [
MagicMock(chunk_id="doc/ml1", score=0.9),
MagicMock(chunk_id="doc/ml2", score=0.85),
]
mock_matches_2 = [
MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -529,7 +551,6 @@ class TestQuery:
verbose=False verbose=False
) )
# Execute full pipeline
result = await document_rag.query( result = await document_rag.query(
query=query_text, query=query_text,
user="research_user", user="research_user",
@ -537,26 +558,69 @@ class TestQuery:
doc_limit=25 doc_limit=25
) )
# Verify complete pipeline execution (now expects list) # Verify concept extraction
mock_embeddings_client.embed.assert_called_once_with([query_text]) mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
mock_doc_embeddings_client.query.assert_called_once_with( variables={"query": query_text}
vector=query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"
) )
# Verify embeddings called with concepts
mock_embeddings_client.embed.assert_called_once_with(
["machine learning", "artificial intelligence"]
)
# Verify two per-concept queries were made (25 // 2 = 12 per concept)
assert mock_doc_embeddings_client.query.call_count == 2
# Verify prompt client was called with fetched document content # Verify prompt client was called with fetched document content
mock_prompt_client.document_prompt.assert_called_once() mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == query_text assert call_args.kwargs["query"] == query_text
# Verify documents were fetched from chunk_ids # Verify documents were fetched and deduplicated
docs = call_args.kwargs["documents"] docs = call_args.kwargs["documents"]
assert "Machine learning is a subset of artificial intelligence..." in docs assert "Machine learning is a subset of artificial intelligence..." in docs
assert "ML algorithms learn patterns from data to make predictions..." in docs assert "ML algorithms learn patterns from data to make predictions..." in docs
assert "Common ML techniques include supervised and unsupervised learning..." in docs assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
# Verify final result
assert result == final_response assert result == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):
"""Test that get_docs deduplicates chunks across multiple concepts"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Two concepts → two vectors
mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
# Both queries return overlapping chunks
match_a = MagicMock(chunk_id="doc/c1", score=0.9)
match_b = MagicMock(chunk_id="doc/c2", score=0.8)
match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate
mock_doc_embeddings_client.query.side_effect = [
[match_a, match_b],
[match_c],
]
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
doc_limit=10
)
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
assert len(chunk_ids) == 2 # doc/c1 only counted once
assert "doc/c1" in chunk_ids
assert "doc/c2" in chunk_ids

View file

@ -19,7 +19,7 @@ class TestGraphRag:
mock_embeddings_client = MagicMock() mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock() mock_triples_client = MagicMock()
# Initialize GraphRag # Initialize GraphRag
graph_rag = GraphRag( graph_rag = GraphRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -27,7 +27,7 @@ class TestGraphRag:
graph_embeddings_client=mock_graph_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client triples_client=mock_triples_client
) )
# Verify initialization # Verify initialization
assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client assert graph_rag.embeddings_client == mock_embeddings_client
@ -45,7 +45,7 @@ class TestGraphRag:
mock_embeddings_client = MagicMock() mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock() mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True # Initialize GraphRag with verbose=True
graph_rag = GraphRag( graph_rag = GraphRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -54,7 +54,7 @@ class TestGraphRag:
triples_client=mock_triples_client, triples_client=mock_triples_client,
verbose=True verbose=True
) )
# Verify initialization # Verify initialization
assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client assert graph_rag.embeddings_client == mock_embeddings_client
@ -73,7 +73,7 @@ class TestQuery:
"""Test Query initialization with default parameters""" """Test Query initialization with default parameters"""
# Create mock GraphRag # Create mock GraphRag
mock_rag = MagicMock() mock_rag = MagicMock()
# Initialize Query with defaults # Initialize Query with defaults
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
@ -81,7 +81,7 @@ class TestQuery:
collection="test_collection", collection="test_collection",
verbose=False verbose=False
) )
# Verify initialization # Verify initialization
assert query.rag == mock_rag assert query.rag == mock_rag
assert query.user == "test_user" assert query.user == "test_user"
@ -96,7 +96,7 @@ class TestQuery:
"""Test Query initialization with custom parameters""" """Test Query initialization with custom parameters"""
# Create mock GraphRag # Create mock GraphRag
mock_rag = MagicMock() mock_rag = MagicMock()
# Initialize Query with custom parameters # Initialize Query with custom parameters
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
@ -108,7 +108,7 @@ class TestQuery:
max_subgraph_size=2000, max_subgraph_size=2000,
max_path_length=3 max_path_length=3
) )
# Verify initialization # Verify initialization
assert query.rag == mock_rag assert query.rag == mock_rag
assert query.user == "custom_user" assert query.user == "custom_user"
@ -120,18 +120,16 @@ class TestQuery:
assert query.max_path_length == 3 assert query.max_path_length == 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_vector_method(self): async def test_get_vectors_method(self):
"""Test Query.get_vector method calls embeddings client correctly""" """Test Query.get_vectors method calls embeddings client correctly"""
# Create mock GraphRag with embeddings client
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method to return test vectors (batch format)
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query # Mock embed to return vectors for a list of concepts
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -139,29 +137,22 @@ class TestQuery:
verbose=False verbose=False
) )
# Call get_vector concepts = ["machine learning", "neural networks"]
test_query = "What is the capital of France?" result = await query.get_vectors(concepts)
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly (now expects list) mock_embeddings_client.embed.assert_called_once_with(concepts)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
assert result == expected_vectors assert result == expected_vectors
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_vector_method_with_verbose(self): async def test_get_vectors_method_with_verbose(self):
"""Test Query.get_vector method with verbose output""" """Test Query.get_vectors method with verbose output"""
# Create mock GraphRag with embeddings client
mock_rag = MagicMock() mock_rag = MagicMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query with verbose=True expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -169,48 +160,87 @@ class TestQuery:
verbose=True verbose=True
) )
# Call get_vector result = await query.get_vectors(["test concept"])
test_query = "Test query for embeddings"
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly (now expects list) mock_embeddings_client.embed.assert_called_once_with(["test concept"])
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
assert result == expected_vectors assert result == expected_vectors
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_entities_method(self): async def test_extract_concepts(self):
"""Test Query.get_entities method retrieves entities correctly""" """Test Query.extract_concepts parses LLM response into concept list"""
# Create mock GraphRag with clients
mock_rag = MagicMock() mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("What is machine learning?")
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "What is machine learning?"}
)
assert result == ["machine learning", "neural networks"]
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
"""Test extract_concepts falls back to raw query when LLM returns empty"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = ""
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("test query")
assert result == ["test query"]
@pytest.mark.asyncio
async def test_get_entities_method(self):
"""Test Query.get_entities extracts concepts, embeds, and retrieves entities"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.graph_embeddings_client = mock_graph_embeddings_client mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# Mock the embedding and entity query responses (batch format)
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock EntityMatch objects with entity as Term-like object # extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
# Mock entity matches
mock_entity1 = MagicMock() mock_entity1 = MagicMock()
mock_entity1.type = "i" # IRI type mock_entity1.type = "i"
mock_entity1.iri = "entity1" mock_entity1.iri = "entity1"
mock_match1 = MagicMock() mock_match1 = MagicMock()
mock_match1.entity = mock_entity1 mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock() mock_entity2 = MagicMock()
mock_entity2.type = "i" # IRI type mock_entity2.type = "i"
mock_entity2.iri = "entity2" mock_entity2.iri = "entity2"
mock_match2 = MagicMock() mock_match2 = MagicMock()
mock_match2.entity = mock_entity2 mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -219,35 +249,23 @@ class TestQuery:
entity_limit=25 entity_limit=25
) )
# Call get_entities entities, concepts = await query.get_entities("Find related entities")
test_query = "Find related entities"
result = await query.get_entities(test_query)
# Verify embeddings client was called (now expects list) # Verify embeddings client was called with the fallback concept
mock_embeddings_client.embed.assert_called_once_with([test_query]) mock_embeddings_client.embed.assert_called_once_with(["Find related entities"])
# Verify graph embeddings client was called correctly (with extracted vector) # Verify result
mock_graph_embeddings_client.query.assert_called_once_with( assert entities == ["entity1", "entity2"]
vector=test_vectors, assert concepts == ["Find related entities"]
limit=25,
user="test_user",
collection="test_collection"
)
# Verify result is list of entity strings
assert result == ["entity1", "entity2"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_maybe_label_with_cached_label(self): async def test_maybe_label_with_cached_label(self):
"""Test Query.maybe_label method with cached label""" """Test Query.maybe_label method with cached label"""
# Create mock GraphRag with label cache
mock_rag = MagicMock() mock_rag = MagicMock()
# Create mock LRUCacheWithTTL
mock_cache = MagicMock() mock_cache = MagicMock()
mock_cache.get.return_value = "Entity One Label" mock_cache.get.return_value = "Entity One Label"
mock_rag.label_cache = mock_cache mock_rag.label_cache = mock_cache
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -255,32 +273,25 @@ class TestQuery:
verbose=False verbose=False
) )
# Call maybe_label with cached entity
result = await query.maybe_label("entity1") result = await query.maybe_label("entity1")
# Verify cached label is returned
assert result == "Entity One Label" assert result == "Entity One Label"
# Verify cache was checked with proper key format (user:collection:entity)
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1") mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self): async def test_maybe_label_with_label_lookup(self):
"""Test Query.maybe_label method with database label lookup""" """Test Query.maybe_label method with database label lookup"""
# Create mock GraphRag with triples client
mock_rag = MagicMock() mock_rag = MagicMock()
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock() mock_cache = MagicMock()
mock_cache.get.return_value = None mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock() mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client mock_rag.triples_client = mock_triples_client
# Mock triple result with label
mock_triple = MagicMock() mock_triple = MagicMock()
mock_triple.o = "Human Readable Label" mock_triple.o = "Human Readable Label"
mock_triples_client.query.return_value = [mock_triple] mock_triples_client.query.return_value = [mock_triple]
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -288,20 +299,18 @@ class TestQuery:
verbose=False verbose=False
) )
# Call maybe_label
result = await query.maybe_label("http://example.com/entity") result = await query.maybe_label("http://example.com/entity")
# Verify triples client was called correctly
mock_triples_client.query.assert_called_once_with( mock_triples_client.query.assert_called_once_with(
s="http://example.com/entity", s="http://example.com/entity",
p="http://www.w3.org/2000/01/rdf-schema#label", p="http://www.w3.org/2000/01/rdf-schema#label",
o=None, o=None,
limit=1, limit=1,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection",
g=""
) )
# Verify result and cache update with proper key
assert result == "Human Readable Label" assert result == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity" cache_key = "test_user:test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label") mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@ -309,40 +318,34 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self): async def test_maybe_label_with_no_label_found(self):
"""Test Query.maybe_label method when no label is found""" """Test Query.maybe_label method when no label is found"""
# Create mock GraphRag with triples client
mock_rag = MagicMock() mock_rag = MagicMock()
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock() mock_cache = MagicMock()
mock_cache.get.return_value = None mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock() mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client mock_rag.triples_client = mock_triples_client
# Mock empty result (no label found)
mock_triples_client.query.return_value = [] mock_triples_client.query.return_value = []
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
collection="test_collection", collection="test_collection",
verbose=False verbose=False
) )
# Call maybe_label
result = await query.maybe_label("unlabeled_entity") result = await query.maybe_label("unlabeled_entity")
# Verify triples client was called
mock_triples_client.query.assert_called_once_with( mock_triples_client.query.assert_called_once_with(
s="unlabeled_entity", s="unlabeled_entity",
p="http://www.w3.org/2000/01/rdf-schema#label", p="http://www.w3.org/2000/01/rdf-schema#label",
o=None, o=None,
limit=1, limit=1,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection",
g=""
) )
# Verify result is entity itself and cache is updated
assert result == "unlabeled_entity" assert result == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity" cache_key = "test_user:test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@ -350,29 +353,25 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self): async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery""" """Test Query.follow_edges method basic triple discovery"""
# Create mock GraphRag with triples client
mock_rag = MagicMock() mock_rag = MagicMock()
mock_triples_client = AsyncMock() mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client mock_rag.triples_client = mock_triples_client
# Mock triple results for different query patterns
mock_triple1 = MagicMock() mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1" mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple2 = MagicMock() mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2" mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock() mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query_stream.side_effect = [ mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent, p=None, o=None [mock_triple1], # s=ent
[mock_triple2], # s=None, p=ent, o=None [mock_triple2], # p=ent
[mock_triple3], # s=None, p=None, o=ent [mock_triple3], # o=ent
] ]
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -380,29 +379,25 @@ class TestQuery:
verbose=False, verbose=False,
triple_limit=10 triple_limit=10
) )
# Call follow_edges
subgraph = set() subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1) await query.follow_edges("entity1", subgraph, path_length=1)
# Verify all three query patterns were called
assert mock_triples_client.query_stream.call_count == 3 assert mock_triples_client.query_stream.call_count == 3
# Verify query_stream calls
mock_triples_client.query_stream.assert_any_call( mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10, s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20 user="test_user", collection="test_collection", batch_size=20, g=""
) )
mock_triples_client.query_stream.assert_any_call( mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10, s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20 user="test_user", collection="test_collection", batch_size=20, g=""
) )
mock_triples_client.query_stream.assert_any_call( mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10, s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection", batch_size=20 user="test_user", collection="test_collection", batch_size=20, g=""
) )
# Verify subgraph contains discovered triples
expected_subgraph = { expected_subgraph = {
("entity1", "predicate1", "object1"), ("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"), ("subject2", "entity1", "object2"),
@ -413,38 +408,30 @@ class TestQuery:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self): async def test_follow_edges_with_path_length_zero(self):
"""Test Query.follow_edges method with path_length=0""" """Test Query.follow_edges method with path_length=0"""
# Create mock GraphRag
mock_rag = MagicMock() mock_rag = MagicMock()
mock_triples_client = AsyncMock() mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client mock_rag.triples_client = mock_triples_client
# Initialize Query
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
collection="test_collection", collection="test_collection",
verbose=False verbose=False
) )
# Call follow_edges with path_length=0
subgraph = set() subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0) await query.follow_edges("entity1", subgraph, path_length=0)
# Verify no queries were made
mock_triples_client.query_stream.assert_not_called() mock_triples_client.query_stream.assert_not_called()
# Verify subgraph remains empty
assert subgraph == set() assert subgraph == set()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self): async def test_follow_edges_with_max_subgraph_size_limit(self):
"""Test Query.follow_edges method respects max_subgraph_size""" """Test Query.follow_edges method respects max_subgraph_size"""
# Create mock GraphRag
mock_rag = MagicMock() mock_rag = MagicMock()
mock_triples_client = AsyncMock() mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client mock_rag.triples_client = mock_triples_client
# Initialize Query with small max_subgraph_size
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
@ -452,23 +439,17 @@ class TestQuery:
verbose=False, verbose=False,
max_subgraph_size=2 max_subgraph_size=2
) )
# Pre-populate subgraph to exceed limit
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")} subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
# Call follow_edges
await query.follow_edges("entity1", subgraph, path_length=1) await query.follow_edges("entity1", subgraph, path_length=1)
# Verify no queries were made due to size limit
mock_triples_client.query_stream.assert_not_called() mock_triples_client.query_stream.assert_not_called()
# Verify subgraph unchanged
assert len(subgraph) == 3 assert len(subgraph) == 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_subgraph_method(self): async def test_get_subgraph_method(self):
"""Test Query.get_subgraph method orchestrates entity and edge discovery""" """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
# Create mock Query that patches get_entities and follow_edges_batch
mock_rag = MagicMock() mock_rag = MagicMock()
query = Query( query = Query(
@ -479,130 +460,119 @@ class TestQuery:
max_path_length=1 max_path_length=1
) )
# Mock get_entities to return test entities # Mock get_entities to return (entities, concepts) tuple
query.get_entities = AsyncMock(return_value=["entity1", "entity2"]) query.get_entities = AsyncMock(
return_value=(["entity1", "entity2"], ["concept1"])
)
# Mock follow_edges_batch to return test triples
query.follow_edges_batch = AsyncMock(return_value={ query.follow_edges_batch = AsyncMock(return_value={
("entity1", "predicate1", "object1"), ("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2") ("entity2", "predicate2", "object2")
}) })
# Call get_subgraph subgraph, entities, concepts = await query.get_subgraph("test query")
result = await query.get_subgraph("test query")
# Verify get_entities was called
query.get_entities.assert_called_once_with("test query") query.get_entities.assert_called_once_with("test query")
# Verify follow_edges_batch was called with entities and max_path_length
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
# Verify result is list format and contains expected triples assert isinstance(subgraph, list)
assert isinstance(result, list) assert len(subgraph) == 2
assert len(result) == 2 assert ("entity1", "predicate1", "object1") in subgraph
assert ("entity1", "predicate1", "object1") in result assert ("entity2", "predicate2", "object2") in subgraph
assert ("entity2", "predicate2", "object2") in result assert entities == ["entity1", "entity2"]
assert concepts == ["concept1"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_labelgraph_method(self): async def test_get_labelgraph_method(self):
"""Test Query.get_labelgraph method converts entities to labels""" """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
# Create mock Query
mock_rag = MagicMock() mock_rag = MagicMock()
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
user="test_user", user="test_user",
collection="test_collection", collection="test_collection",
verbose=False, verbose=False,
max_subgraph_size=100 max_subgraph_size=100
) )
# Mock get_subgraph to return test triples
test_subgraph = [ test_subgraph = [
("entity1", "predicate1", "object1"), ("entity1", "predicate1", "object1"),
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
("entity3", "predicate3", "object3") ("entity3", "predicate3", "object3")
] ]
query.get_subgraph = AsyncMock(return_value=test_subgraph) test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
# Mock maybe_label to return human-readable labels query.get_subgraph = AsyncMock(
return_value=(test_subgraph, test_entities, test_concepts)
)
async def mock_maybe_label(entity): async def mock_maybe_label(entity):
label_map = { label_map = {
"entity1": "Human Entity One", "entity1": "Human Entity One",
"predicate1": "Human Predicate One", "predicate1": "Human Predicate One",
"object1": "Human Object One", "object1": "Human Object One",
"entity3": "Human Entity Three", "entity3": "Human Entity Three",
"predicate3": "Human Predicate Three", "predicate3": "Human Predicate Three",
"object3": "Human Object Three" "object3": "Human Object Three"
} }
return label_map.get(entity, entity) return label_map.get(entity, entity)
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
# Call get_labelgraph
labeled_edges, uri_map = await query.get_labelgraph("test query")
# Verify get_subgraph was called query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
query.get_subgraph.assert_called_once_with("test query") query.get_subgraph.assert_called_once_with("test query")
# Verify label triples are filtered out # Label triples filtered out
assert len(labeled_edges) == 2 # Label triple should be excluded assert len(labeled_edges) == 2
# Verify maybe_label was called for non-label triples # maybe_label called for non-label triples
expected_calls = [
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
]
assert query.maybe_label.call_count == 6 assert query.maybe_label.call_count == 6
# Verify result contains human-readable labels
expected_edges = [ expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"), ("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three") ("Human Entity Three", "Human Predicate Three", "Human Object Three")
] ]
assert labeled_edges == expected_edges assert labeled_edges == expected_edges
# Verify uri_map maps labeled edges back to original URIs
assert len(uri_map) == 2 assert len(uri_map) == 2
assert entities == test_entities
assert concepts == test_concepts
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_rag_query_method(self): async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance""" """Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
import json import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id from trustgraph.retrieval.graph_rag.graph_rag import edge_id
# Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock() mock_triples_client = AsyncMock()
# Mock prompt client responses for two-step process
expected_response = "This is the RAG response" expected_response = "This is the RAG response"
test_labelgraph = [("Subject", "Predicate", "Object")] test_labelgraph = [("Subject", "Predicate", "Object")]
# Compute the edge ID for the test edge
test_edge_id = edge_id("Subject", "Predicate", "Object") test_edge_id = edge_id("Subject", "Predicate", "Object")
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
test_uri_map = { test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
} }
test_entities = ["http://example.org/subject"]
test_concepts = ["test concept"]
# Mock edge selection response (JSONL format) # Mock prompt responses for the multi-step process
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
# Configure prompt mock to return different responses based on prompt name
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection": if prompt_name == "extract-concepts":
return edge_selection_response return "" # Falls back to raw query
elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9})
elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
elif prompt_name == "kg-synthesis": elif prompt_name == "kg-synthesis":
return expected_response return expected_response
return "" return ""
mock_prompt_client.prompt = mock_prompt mock_prompt_client.prompt = mock_prompt
# Initialize GraphRag
graph_rag = GraphRag( graph_rag = GraphRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -611,27 +581,20 @@ class TestQuery:
verbose=False verbose=False
) )
# We need to patch the Query class's get_labelgraph method # Patch Query.get_labelgraph to return test data
original_query_init = Query.__init__
original_get_labelgraph = Query.get_labelgraph original_get_labelgraph = Query.get_labelgraph
def mock_query_init(self, *args, **kwargs):
original_query_init(self, *args, **kwargs)
async def mock_get_labelgraph(self, query_text): async def mock_get_labelgraph(self, query_text):
return test_labelgraph, test_uri_map return test_labelgraph, test_uri_map, test_entities, test_concepts
Query.__init__ = mock_query_init
Query.get_labelgraph = mock_get_labelgraph Query.get_labelgraph = mock_get_labelgraph
# Collect provenance emitted via callback
provenance_events = [] provenance_events = []
async def collect_provenance(triples, prov_id): async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id)) provenance_events.append((triples, prov_id))
try: try:
# Call GraphRag.query with provenance callback
response = await graph_rag.query( response = await graph_rag.query(
query="test query", query="test query",
user="test_user", user="test_user",
@ -641,25 +604,22 @@ class TestQuery:
explain_callback=collect_provenance explain_callback=collect_provenance
) )
# Verify response text
assert response == expected_response assert response == expected_response
# Verify provenance was emitted incrementally (4 events: question, exploration, focus, synthesis) # 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 4 assert len(provenance_events) == 5
# Verify each event has triples and a URN
for triples, prov_id in provenance_events: for triples, prov_id in provenance_events:
assert isinstance(triples, list) assert isinstance(triples, list)
assert len(triples) > 0 assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:") assert prov_id.startswith("urn:trustgraph:")
# Verify order: question, exploration, focus, synthesis # Verify order
assert "question" in provenance_events[0][1] assert "question" in provenance_events[0][1]
assert "exploration" in provenance_events[1][1] assert "grounding" in provenance_events[1][1]
assert "focus" in provenance_events[2][1] assert "exploration" in provenance_events[2][1]
assert "synthesis" in provenance_events[3][1] assert "focus" in provenance_events[3][1]
assert "synthesis" in provenance_events[4][1]
finally: finally:
# Restore original methods Query.get_labelgraph = original_get_labelgraph
Query.__init__ = original_query_init
Query.get_labelgraph = original_get_labelgraph

View file

@ -75,9 +75,11 @@ from .explainability import (
ExplainabilityClient, ExplainabilityClient,
ExplainEntity, ExplainEntity,
Question, Question,
Grounding,
Exploration, Exploration,
Focus, Focus,
Synthesis, Synthesis,
Reflection,
Analysis, Analysis,
Conclusion, Conclusion,
EdgeSelection, EdgeSelection,

View file

@ -18,25 +18,28 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge" TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge" TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning" TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document" TG_DOCUMENT = TG + "document"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_CHUNK_COUNT = TG + "chunkCount" TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk" TG_SELECTED_CHUNK = TG + "selectedChunk"
TG_THOUGHT = TG + "thought" TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action" TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments" TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation" TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
# Entity types # Entity types
TG_QUESTION = TG + "Question" TG_QUESTION = TG + "Question"
TG_GROUNDING = TG + "Grounding"
TG_EXPLORATION = TG + "Exploration" TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus" TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis" TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis" TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion" TG_CONCLUSION = TG + "Conclusion"
TG_ANSWER_TYPE = TG + "Answer"
TG_REFLECTION_TYPE = TG + "Reflection"
TG_THOUGHT_TYPE = TG + "Thought"
TG_OBSERVATION_TYPE = TG + "Observation"
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
TG_AGENT_QUESTION = TG + "AgentQuestion" TG_AGENT_QUESTION = TG + "AgentQuestion"
@ -73,12 +76,16 @@ class ExplainEntity:
if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types: if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types:
return Question.from_triples(uri, triples, types) return Question.from_triples(uri, triples, types)
elif TG_GROUNDING in types:
return Grounding.from_triples(uri, triples)
elif TG_EXPLORATION in types: elif TG_EXPLORATION in types:
return Exploration.from_triples(uri, triples) return Exploration.from_triples(uri, triples)
elif TG_FOCUS in types: elif TG_FOCUS in types:
return Focus.from_triples(uri, triples) return Focus.from_triples(uri, triples)
elif TG_SYNTHESIS in types: elif TG_SYNTHESIS in types:
return Synthesis.from_triples(uri, triples) return Synthesis.from_triples(uri, triples)
elif TG_REFLECTION_TYPE in types:
return Reflection.from_triples(uri, triples)
elif TG_ANALYSIS in types: elif TG_ANALYSIS in types:
return Analysis.from_triples(uri, triples) return Analysis.from_triples(uri, triples)
elif TG_CONCLUSION in types: elif TG_CONCLUSION in types:
@ -124,16 +131,38 @@ class Question(ExplainEntity):
) )
@dataclass
class Grounding(ExplainEntity):
"""Grounding entity - concept decomposition of the query."""
concepts: List[str] = field(default_factory=list)
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Grounding":
concepts = []
for s, p, o in triples:
if p == TG_CONCEPT:
concepts.append(o)
return cls(
uri=uri,
entity_type="grounding",
concepts=concepts
)
@dataclass @dataclass
class Exploration(ExplainEntity): class Exploration(ExplainEntity):
"""Exploration entity - edges/chunks retrieved from the knowledge store.""" """Exploration entity - edges/chunks retrieved from the knowledge store."""
edge_count: int = 0 edge_count: int = 0
chunk_count: int = 0 chunk_count: int = 0
entities: List[str] = field(default_factory=list)
@classmethod @classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration": def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration":
edge_count = 0 edge_count = 0
chunk_count = 0 chunk_count = 0
entities = []
for s, p, o in triples: for s, p, o in triples:
if p == TG_EDGE_COUNT: if p == TG_EDGE_COUNT:
@ -146,12 +175,15 @@ class Exploration(ExplainEntity):
chunk_count = int(o) chunk_count = int(o)
except (ValueError, TypeError): except (ValueError, TypeError):
pass pass
elif p == TG_ENTITY:
entities.append(o)
return cls( return cls(
uri=uri, uri=uri,
entity_type="exploration", entity_type="exploration",
edge_count=edge_count, edge_count=edge_count,
chunk_count=chunk_count chunk_count=chunk_count,
entities=entities
) )
@ -180,94 +212,104 @@ class Focus(ExplainEntity):
@dataclass @dataclass
class Synthesis(ExplainEntity): class Synthesis(ExplainEntity):
"""Synthesis entity - the final answer.""" """Synthesis entity - the final answer."""
content: str = ""
document_uri: str = "" # Reference to librarian document document_uri: str = "" # Reference to librarian document
@classmethod @classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis": def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis":
content = ""
document_uri = "" document_uri = ""
for s, p, o in triples: for s, p, o in triples:
if p == TG_CONTENT: if p == TG_DOCUMENT:
content = o
elif p == TG_DOCUMENT:
document_uri = o document_uri = o
return cls( return cls(
uri=uri, uri=uri,
entity_type="synthesis", entity_type="synthesis",
content=content,
document_uri=document_uri document_uri=document_uri
) )
@dataclass
class Reflection(ExplainEntity):
"""Reflection entity - intermediate commentary (Thought or Observation)."""
document_uri: str = "" # Reference to content in librarian
reflection_type: str = "" # "thought" or "observation"
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Reflection":
document_uri = ""
reflection_type = ""
types = [o for s, p, o in triples if p == RDF_TYPE]
if TG_THOUGHT_TYPE in types:
reflection_type = "thought"
elif TG_OBSERVATION_TYPE in types:
reflection_type = "observation"
for s, p, o in triples:
if p == TG_DOCUMENT:
document_uri = o
return cls(
uri=uri,
entity_type="reflection",
document_uri=document_uri,
reflection_type=reflection_type
)
@dataclass @dataclass
class Analysis(ExplainEntity): class Analysis(ExplainEntity):
"""Analysis entity - one think/act/observe cycle (Agent only).""" """Analysis entity - one think/act/observe cycle (Agent only)."""
thought: str = ""
action: str = "" action: str = ""
arguments: str = "" # JSON string arguments: str = "" # JSON string
observation: str = "" thought_uri: str = "" # URI of thought sub-entity
thought_document_uri: str = "" # Reference to thought in librarian observation_uri: str = "" # URI of observation sub-entity
observation_document_uri: str = "" # Reference to observation in librarian
@classmethod @classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis": def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis":
thought = ""
action = "" action = ""
arguments = "" arguments = ""
observation = "" thought_uri = ""
thought_document_uri = "" observation_uri = ""
observation_document_uri = ""
for s, p, o in triples: for s, p, o in triples:
if p == TG_THOUGHT: if p == TG_ACTION:
thought = o
elif p == TG_ACTION:
action = o action = o
elif p == TG_ARGUMENTS: elif p == TG_ARGUMENTS:
arguments = o arguments = o
elif p == TG_THOUGHT:
thought_uri = o
elif p == TG_OBSERVATION: elif p == TG_OBSERVATION:
observation = o observation_uri = o
elif p == TG_THOUGHT_DOCUMENT:
thought_document_uri = o
elif p == TG_OBSERVATION_DOCUMENT:
observation_document_uri = o
return cls( return cls(
uri=uri, uri=uri,
entity_type="analysis", entity_type="analysis",
thought=thought,
action=action, action=action,
arguments=arguments, arguments=arguments,
observation=observation, thought_uri=thought_uri,
thought_document_uri=thought_document_uri, observation_uri=observation_uri
observation_document_uri=observation_document_uri
) )
@dataclass @dataclass
class Conclusion(ExplainEntity): class Conclusion(ExplainEntity):
"""Conclusion entity - final answer (Agent only).""" """Conclusion entity - final answer (Agent only)."""
answer: str = ""
document_uri: str = "" # Reference to librarian document document_uri: str = "" # Reference to librarian document
@classmethod @classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion": def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion":
answer = ""
document_uri = "" document_uri = ""
for s, p, o in triples: for s, p, o in triples:
if p == TG_ANSWER: if p == TG_DOCUMENT:
answer = o
elif p == TG_DOCUMENT:
document_uri = o document_uri = o
return cls( return cls(
uri=uri, uri=uri,
entity_type="conclusion", entity_type="conclusion",
answer=answer,
document_uri=document_uri document_uri=document_uri
) )
@ -543,42 +585,29 @@ class ExplainabilityClient:
o_label = self.resolve_label(edge.get("o", ""), user, collection) o_label = self.resolve_label(edge.get("o", ""), user, collection)
return (s_label, p_label, o_label) return (s_label, p_label, o_label)
def fetch_synthesis_content( def fetch_document_content(
self, self,
synthesis: Synthesis, document_uri: str,
api: Any, api: Any,
user: Optional[str] = None, user: Optional[str] = None,
max_content: int = 10000 max_content: int = 10000
) -> str: ) -> str:
""" """
Fetch the content for a Synthesis entity. Fetch content from the librarian by document URI.
If synthesis has inline content, returns that.
If synthesis has a document_uri, fetches from librarian with retry.
Args: Args:
synthesis: The Synthesis entity document_uri: The document URI in the librarian
api: TrustGraph Api instance for librarian access api: TrustGraph Api instance for librarian access
user: User identifier for librarian user: User identifier for librarian
max_content: Maximum content length to return max_content: Maximum content length to return
Returns: Returns:
The synthesis content as a string The document content as a string
""" """
# If inline content exists, use it if not document_uri:
if synthesis.content:
if len(synthesis.content) > max_content:
return synthesis.content[:max_content] + "... [truncated]"
return synthesis.content
# Otherwise fetch from librarian
if not synthesis.document_uri:
return "" return ""
# Extract document ID from URI (e.g., "urn:document:abc123" -> "abc123") doc_id = document_uri
doc_id = synthesis.document_uri
if doc_id.startswith("urn:document:"):
doc_id = doc_id[len("urn:document:"):]
# Retry fetching from librarian for eventual consistency # Retry fetching from librarian for eventual consistency
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
@ -603,129 +632,6 @@ class ExplainabilityClient:
return "" return ""
def fetch_conclusion_content(
self,
conclusion: Conclusion,
api: Any,
user: Optional[str] = None,
max_content: int = 10000
) -> str:
"""
Fetch the content for a Conclusion entity (Agent final answer).
If conclusion has inline answer, returns that.
If conclusion has a document_uri, fetches from librarian with retry.
Args:
conclusion: The Conclusion entity
api: TrustGraph Api instance for librarian access
user: User identifier for librarian
max_content: Maximum content length to return
Returns:
The conclusion answer as a string
"""
# If inline answer exists, use it
if conclusion.answer:
if len(conclusion.answer) > max_content:
return conclusion.answer[:max_content] + "... [truncated]"
return conclusion.answer
# Otherwise fetch from librarian
if not conclusion.document_uri:
return ""
# Use document URI directly (it's already a full URN)
doc_id = conclusion.document_uri
# Retry fetching from librarian for eventual consistency
for attempt in range(self.max_retries):
try:
library = api.library()
content_bytes = library.get_document_content(user=user, id=doc_id)
# Decode as text
try:
content = content_bytes.decode('utf-8')
if len(content) > max_content:
return content[:max_content] + "... [truncated]"
return content
except UnicodeDecodeError:
return f"[Binary: {len(content_bytes)} bytes]"
except Exception as e:
if attempt < self.max_retries - 1:
time.sleep(self.retry_delay)
continue
return f"[Error fetching content: {e}]"
return ""
def fetch_analysis_content(
self,
analysis: Analysis,
api: Any,
user: Optional[str] = None,
max_content: int = 10000
) -> None:
"""
Fetch thought and observation content for an Analysis entity.
If analysis has inline content, uses that.
If analysis has document URIs, fetches from librarian with retry.
Modifies the analysis object in place.
Args:
analysis: The Analysis entity (modified in place)
api: TrustGraph Api instance for librarian access
user: User identifier for librarian
max_content: Maximum content length to return
"""
# Fetch thought if needed
if not analysis.thought and analysis.thought_document_uri:
doc_id = analysis.thought_document_uri
for attempt in range(self.max_retries):
try:
library = api.library()
content_bytes = library.get_document_content(user=user, id=doc_id)
try:
content = content_bytes.decode('utf-8')
if len(content) > max_content:
analysis.thought = content[:max_content] + "... [truncated]"
else:
analysis.thought = content
break
except UnicodeDecodeError:
analysis.thought = f"[Binary: {len(content_bytes)} bytes]"
break
except Exception as e:
if attempt < self.max_retries - 1:
time.sleep(self.retry_delay)
continue
analysis.thought = f"[Error fetching thought: {e}]"
# Fetch observation if needed
if not analysis.observation and analysis.observation_document_uri:
doc_id = analysis.observation_document_uri
for attempt in range(self.max_retries):
try:
library = api.library()
content_bytes = library.get_document_content(user=user, id=doc_id)
try:
content = content_bytes.decode('utf-8')
if len(content) > max_content:
analysis.observation = content[:max_content] + "... [truncated]"
else:
analysis.observation = content
break
except UnicodeDecodeError:
analysis.observation = f"[Binary: {len(content_bytes)} bytes]"
break
except Exception as e:
if attempt < self.max_retries - 1:
time.sleep(self.retry_delay)
continue
analysis.observation = f"[Error fetching observation: {e}]"
def fetch_graphrag_trace( def fetch_graphrag_trace(
self, self,
@ -739,7 +645,7 @@ class ExplainabilityClient:
""" """
Fetch the complete GraphRAG trace starting from a question URI. Fetch the complete GraphRAG trace starting from a question URI.
Follows the provenance chain: Question -> Exploration -> Focus -> Synthesis Follows the provenance chain: Question -> Grounding -> Exploration -> Focus -> Synthesis
Args: Args:
question_uri: The question entity URI question_uri: The question entity URI
@ -750,13 +656,14 @@ class ExplainabilityClient:
max_content: Maximum content length for synthesis max_content: Maximum content length for synthesis
Returns: Returns:
Dict with question, exploration, focus, synthesis entities Dict with question, grounding, exploration, focus, synthesis entities
""" """
if graph is None: if graph is None:
graph = "urn:graph:retrieval" graph = "urn:graph:retrieval"
trace = { trace = {
"question": None, "question": None,
"grounding": None,
"exploration": None, "exploration": None,
"focus": None, "focus": None,
"synthesis": None, "synthesis": None,
@ -768,8 +675,8 @@ class ExplainabilityClient:
return trace return trace
trace["question"] = question trace["question"] = question
# Find exploration: ?exploration prov:wasGeneratedBy question_uri # Find grounding: ?grounding prov:wasGeneratedBy question_uri
exploration_triples = self.flow.triples_query( grounding_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY, p=PROV_WAS_GENERATED_BY,
o=question_uri, o=question_uri,
g=graph, g=graph,
@ -778,6 +685,30 @@ class ExplainabilityClient:
limit=10 limit=10
) )
if grounding_triples:
grounding_uris = [
extract_term_value(t.get("s", {}))
for t in grounding_triples
]
for gnd_uri in grounding_uris:
grounding = self.fetch_entity(gnd_uri, graph, user, collection)
if isinstance(grounding, Grounding):
trace["grounding"] = grounding
break
if not trace["grounding"]:
return trace
# Find exploration: ?exploration prov:wasDerivedFrom grounding_uri
exploration_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=trace["grounding"].uri,
g=graph,
user=user,
collection=collection,
limit=10
)
if exploration_triples: if exploration_triples:
exploration_uris = [ exploration_uris = [
extract_term_value(t.get("s", {})) extract_term_value(t.get("s", {}))
@ -834,11 +765,6 @@ class ExplainabilityClient:
for synth_uri in synthesis_uris: for synth_uri in synthesis_uris:
synthesis = self.fetch_entity(synth_uri, graph, user, collection) synthesis = self.fetch_entity(synth_uri, graph, user, collection)
if isinstance(synthesis, Synthesis): if isinstance(synthesis, Synthesis):
# Fetch content if needed
if api and not synthesis.content and synthesis.document_uri:
synthesis.content = self.fetch_synthesis_content(
synthesis, api, user, max_content
)
trace["synthesis"] = synthesis trace["synthesis"] = synthesis
break break
@ -928,11 +854,6 @@ class ExplainabilityClient:
for synth_uri in synthesis_uris: for synth_uri in synthesis_uris:
synthesis = self.fetch_entity(synth_uri, graph, user, collection) synthesis = self.fetch_entity(synth_uri, graph, user, collection)
if isinstance(synthesis, Synthesis): if isinstance(synthesis, Synthesis):
# Fetch content if needed
if api and not synthesis.content and synthesis.document_uri:
synthesis.content = self.fetch_synthesis_content(
synthesis, api, user, max_content
)
trace["synthesis"] = synthesis trace["synthesis"] = synthesis
break break
@ -978,20 +899,43 @@ class ExplainabilityClient:
return trace return trace
trace["question"] = question trace["question"] = question
# Follow the chain of wasDerivedFrom # Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after
current_uri = session_uri current_uri = session_uri
is_first = True
max_iterations = 50 # Safety limit max_iterations = 50 # Safety limit
for _ in range(max_iterations): for _ in range(max_iterations):
# Find entity derived from current # First hop uses wasGeneratedBy (entity←activity),
derived_triples = self.flow.triples_query( # subsequent hops use wasDerivedFrom (entity←entity)
p=PROV_WAS_DERIVED_FROM, if is_first:
o=current_uri, derived_triples = self.flow.triples_query(
g=graph, p=PROV_WAS_GENERATED_BY,
user=user, o=current_uri,
collection=collection, g=graph,
limit=10 user=user,
) collection=collection,
limit=10
)
# Fall back to wasDerivedFrom for backwards compatibility
if not derived_triples:
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
is_first = False
else:
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
if not derived_triples: if not derived_triples:
break break
@ -1003,19 +947,9 @@ class ExplainabilityClient:
entity = self.fetch_entity(derived_uri, graph, user, collection) entity = self.fetch_entity(derived_uri, graph, user, collection)
if isinstance(entity, Analysis): if isinstance(entity, Analysis):
# Fetch thought/observation content from librarian if needed
if api:
self.fetch_analysis_content(
entity, api, user=user, max_content=max_content
)
trace["iterations"].append(entity) trace["iterations"].append(entity)
current_uri = derived_uri current_uri = derived_uri
elif isinstance(entity, Conclusion): elif isinstance(entity, Conclusion):
# Fetch answer content from librarian if needed
if api and not entity.answer and entity.document_uri:
entity.answer = self.fetch_conclusion_content(
entity, api, user=user, max_content=max_content
)
trace["conclusion"] = entity trace["conclusion"] = entity
break break
else: else:

View file

@ -1,6 +1,6 @@
from . request_response_spec import RequestResponse, RequestResponseSpec from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL, TRIPLE
from .. knowledge import Uri, Literal from .. knowledge import Uri, Literal
@ -22,9 +22,11 @@ def to_value(x):
def from_value(x): def from_value(x):
"""Convert Uri, Literal, or string to schema Term.""" """Convert Uri, Literal, string, or Term to schema Term."""
if x is None: if x is None:
return None return None
if isinstance(x, Term):
return x
if isinstance(x, Uri): if isinstance(x, Uri):
return Term(type=IRI, iri=str(x)) return Term(type=IRI, iri=str(x))
elif isinstance(x, Literal): elif isinstance(x, Literal):
@ -41,7 +43,7 @@ def from_value(x):
class TriplesClient(RequestResponse): class TriplesClient(RequestResponse):
async def query(self, s=None, p=None, o=None, limit=20, async def query(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default", user="trustgraph", collection="default",
timeout=30): timeout=30, g=None):
resp = await self.request( resp = await self.request(
TriplesQueryRequest( TriplesQueryRequest(
@ -51,6 +53,7 @@ class TriplesClient(RequestResponse):
limit = limit, limit = limit,
user = user, user = user,
collection = collection, collection = collection,
g = g,
), ),
timeout=timeout timeout=timeout
) )
@ -68,7 +71,7 @@ class TriplesClient(RequestResponse):
async def query_stream(self, s=None, p=None, o=None, limit=20, async def query_stream(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default", user="trustgraph", collection="default",
batch_size=20, timeout=30, batch_size=20, timeout=30,
batch_callback=None): batch_callback=None, g=None):
""" """
Streaming triple query - calls callback for each batch as it arrives. Streaming triple query - calls callback for each batch as it arrives.
@ -80,6 +83,8 @@ class TriplesClient(RequestResponse):
batch_size: Triples per batch batch_size: Triples per batch
timeout: Request timeout in seconds timeout: Request timeout in seconds
batch_callback: Async callback(batch, is_final) called for each batch batch_callback: Async callback(batch, is_final) called for each batch
g: Graph filter. ""=default graph only, None=all graphs,
or a specific graph IRI.
Returns: Returns:
List[Triple]: All triples (flattened) if no callback provided List[Triple]: All triples (flattened) if no callback provided
@ -112,6 +117,7 @@ class TriplesClient(RequestResponse):
collection=collection, collection=collection,
streaming=True, streaming=True,
batch_size=batch_size, batch_size=batch_size,
g=g,
), ),
timeout=timeout, timeout=timeout,
recipient=recipient, recipient=recipient,

View file

@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
triple_limit=int(data.get("triple-limit", 30)), triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)), max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2)), max_path_length=int(data.get("max-path-length", 2)),
edge_limit=int(data.get("edge-limit", 25)),
streaming=data.get("streaming", False) streaming=data.get("streaming", False)
) )
@ -96,6 +97,7 @@ class GraphRagRequestTranslator(MessageTranslator):
"triple-limit": obj.triple_limit, "triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size, "max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length, "max-path-length": obj.max_path_length,
"edge-limit": obj.edge_limit,
"streaming": getattr(obj, "streaming", False) "streaming": getattr(obj, "streaming", False)
} }

View file

@ -42,15 +42,19 @@ from . uris import (
agent_uri, agent_uri,
# Query-time provenance URIs (GraphRAG) # Query-time provenance URIs (GraphRAG)
question_uri, question_uri,
grounding_uri,
exploration_uri, exploration_uri,
focus_uri, focus_uri,
synthesis_uri, synthesis_uri,
# Agent provenance URIs # Agent provenance URIs
agent_session_uri, agent_session_uri,
agent_iteration_uri, agent_iteration_uri,
agent_thought_uri,
agent_observation_uri,
agent_final_uri, agent_final_uri,
# Document RAG provenance URIs # Document RAG provenance URIs
docrag_question_uri, docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri, docrag_exploration_uri,
docrag_synthesis_uri, docrag_synthesis_uri,
) )
@ -74,18 +78,19 @@ from . namespaces import (
# Extraction provenance entity types # Extraction provenance entity types
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG) # Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT, TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK, TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types # Explainability entity types
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_ANALYSIS, TG_CONCLUSION,
# Unifying types
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
# Question subtypes (to distinguish retrieval mechanism) # Question subtypes (to distinguish retrieval mechanism)
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
# Agent provenance predicates # Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
# Agent document references
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
# Document reference predicate # Document reference predicate
TG_DOCUMENT, TG_DOCUMENT,
# Named graphs # Named graphs
@ -99,6 +104,7 @@ from . triples import (
subgraph_provenance_triples, subgraph_provenance_triples,
# Query-time provenance triple builders (GraphRAG) # Query-time provenance triple builders (GraphRAG)
question_triples, question_triples,
grounding_triples,
exploration_triples, exploration_triples,
focus_triples, focus_triples,
synthesis_triples, synthesis_triples,
@ -139,15 +145,19 @@ __all__ = [
"agent_uri", "agent_uri",
# Query-time provenance URIs # Query-time provenance URIs
"question_uri", "question_uri",
"grounding_uri",
"exploration_uri", "exploration_uri",
"focus_uri", "focus_uri",
"synthesis_uri", "synthesis_uri",
# Agent provenance URIs # Agent provenance URIs
"agent_session_uri", "agent_session_uri",
"agent_iteration_uri", "agent_iteration_uri",
"agent_thought_uri",
"agent_observation_uri",
"agent_final_uri", "agent_final_uri",
# Document RAG provenance URIs # Document RAG provenance URIs
"docrag_question_uri", "docrag_question_uri",
"docrag_grounding_uri",
"docrag_exploration_uri", "docrag_exploration_uri",
"docrag_synthesis_uri", "docrag_synthesis_uri",
# Namespaces # Namespaces
@ -164,18 +174,19 @@ __all__ = [
# Extraction provenance entity types # Extraction provenance entity types
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE", "TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG) # Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT", "TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Explainability entity types # Explainability entity types
"TG_QUESTION", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_ANALYSIS", "TG_CONCLUSION", "TG_ANALYSIS", "TG_CONCLUSION",
# Unifying types
"TG_ANSWER_TYPE", "TG_REFLECTION_TYPE", "TG_THOUGHT_TYPE", "TG_OBSERVATION_TYPE",
# Question subtypes # Question subtypes
"TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION",
# Agent provenance predicates # Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER", "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
# Agent document references
"TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT",
# Document reference predicate # Document reference predicate
"TG_DOCUMENT", "TG_DOCUMENT",
# Named graphs # Named graphs
@ -186,6 +197,7 @@ __all__ = [
"subgraph_provenance_triples", "subgraph_provenance_triples",
# Query-time provenance triple builders (GraphRAG) # Query-time provenance triple builders (GraphRAG)
"question_triples", "question_triples",
"grounding_triples",
"exploration_triples", "exploration_triples",
"focus_triples", "focus_triples",
"synthesis_triples", "synthesis_triples",

View file

@ -15,10 +15,11 @@ from .. schema import Triple, Term, IRI, LITERAL
from . namespaces import ( from . namespaces import (
RDF_TYPE, RDFS_LABEL, RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME, PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_AGENT_QUESTION, TG_AGENT_QUESTION,
) )
@ -73,12 +74,13 @@ def agent_session_triples(
def agent_iteration_triples( def agent_iteration_triples(
iteration_uri: str, iteration_uri: str,
parent_uri: str, question_uri: Optional[str] = None,
thought: str = "", previous_uri: Optional[str] = None,
action: str = "", action: str = "",
arguments: Dict[str, Any] = None, arguments: Dict[str, Any] = None,
observation: str = "", thought_uri: Optional[str] = None,
thought_document_id: Optional[str] = None, thought_document_id: Optional[str] = None,
observation_uri: Optional[str] = None,
observation_document_id: Optional[str] = None, observation_document_id: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
@ -86,19 +88,22 @@ def agent_iteration_triples(
Creates: Creates:
- Entity declaration with tg:Analysis type - Entity declaration with tg:Analysis type
- wasDerivedFrom link to parent (previous iteration or session) - wasGeneratedBy link to question (if first iteration)
- Thought, action, arguments, and observation data - wasDerivedFrom link to previous iteration (if not first)
- Document references for thought/observation when stored in librarian - Action and arguments metadata
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
- Observation sub-entity (tg:Reflection, tg:Observation) with librarian document
Args: Args:
iteration_uri: URI of this iteration (from agent_iteration_uri) iteration_uri: URI of this iteration (from agent_iteration_uri)
parent_uri: URI of the parent (previous iteration or session) question_uri: URI of the question activity (for first iteration)
thought: The agent's reasoning/thought (used if thought_document_id not provided) previous_uri: URI of the previous iteration (for subsequent iterations)
action: The tool/action name action: The tool/action name
arguments: Arguments passed to the tool (will be JSON-encoded) arguments: Arguments passed to the tool (will be JSON-encoded)
observation: The result/observation from the tool (used if observation_document_id not provided) thought_uri: URI for the thought sub-entity
thought_document_id: Optional document URI for thought in librarian (preferred) thought_document_id: Document URI for thought in librarian
observation_document_id: Optional document URI for observation in librarian (preferred) observation_uri: URI for the observation sub-entity
observation_document_id: Document URI for observation in librarian
Returns: Returns:
List of Triple objects List of Triple objects
@ -110,45 +115,70 @@ def agent_iteration_triples(
_triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)), _triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)),
_triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")), _triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")),
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
_triple(iteration_uri, TG_ACTION, _literal(action)), _triple(iteration_uri, TG_ACTION, _literal(action)),
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
] ]
# Thought: use document reference or inline if question_uri:
if thought_document_id: triples.append(
triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id))) _triple(iteration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri))
elif thought: )
triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought))) elif previous_uri:
triples.append(
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri))
)
# Observation: use document reference or inline # Thought sub-entity
if observation_document_id: if thought_uri:
triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id))) triples.extend([
elif observation: _triple(iteration_uri, TG_THOUGHT, _iri(thought_uri)),
triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation))) _triple(thought_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)),
_triple(thought_uri, RDF_TYPE, _iri(TG_THOUGHT_TYPE)),
_triple(thought_uri, RDFS_LABEL, _literal("Thought")),
_triple(thought_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)),
])
if thought_document_id:
triples.append(
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
)
# Observation sub-entity
if observation_uri:
triples.extend([
_triple(iteration_uri, TG_OBSERVATION, _iri(observation_uri)),
_triple(observation_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)),
_triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)),
_triple(observation_uri, RDFS_LABEL, _literal("Observation")),
_triple(observation_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)),
])
if observation_document_id:
triples.append(
_triple(observation_uri, TG_DOCUMENT, _iri(observation_document_id))
)
return triples return triples
def agent_final_triples( def agent_final_triples(
final_uri: str, final_uri: str,
parent_uri: str, question_uri: Optional[str] = None,
answer: str = "", previous_uri: Optional[str] = None,
document_id: Optional[str] = None, document_id: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for an agent final answer (Conclusion). Build triples for an agent final answer (Conclusion).
Creates: Creates:
- Entity declaration with tg:Conclusion type - Entity declaration with tg:Conclusion and tg:Answer types
- wasDerivedFrom link to parent (last iteration or session) - wasGeneratedBy link to question (if no iterations)
- Either document reference (if document_id provided) or inline answer - wasDerivedFrom link to last iteration (if iterations exist)
- Document reference to librarian
Args: Args:
final_uri: URI of the final answer (from agent_final_uri) final_uri: URI of the final answer (from agent_final_uri)
parent_uri: URI of the parent (last iteration or session if no iterations) question_uri: URI of the question activity (if no iterations)
answer: The final answer text (used if document_id not provided) previous_uri: URI of the last iteration (if iterations exist)
document_id: Optional document URI in librarian (preferred) document_id: Librarian document ID for the answer content
Returns: Returns:
List of Triple objects List of Triple objects
@ -156,15 +186,20 @@ def agent_final_triples(
triples = [ triples = [
_triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)), _triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)),
_triple(final_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
_triple(final_uri, RDFS_LABEL, _literal("Conclusion")), _triple(final_uri, RDFS_LABEL, _literal("Conclusion")),
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
] ]
if question_uri:
triples.append(
_triple(final_uri, PROV_WAS_GENERATED_BY, _iri(question_uri))
)
elif previous_uri:
triples.append(
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri))
)
if document_id: if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
elif answer:
# Fallback: store inline answer
triples.append(_triple(final_uri, TG_ANSWER, _literal(answer)))
return triples return triples

View file

@ -60,11 +60,12 @@ TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength"
# Query-time provenance predicates (GraphRAG) # Query-time provenance predicates (GraphRAG)
TG_QUERY = TG + "query" TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount" TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge" TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge" TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning" TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document" # Reference to document in librarian TG_DOCUMENT = TG + "document" # Reference to document in librarian
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
@ -79,27 +80,29 @@ TG_SUBGRAPH_TYPE = TG + "Subgraph"
# Explainability entity types (shared) # Explainability entity types (shared)
TG_QUESTION = TG + "Question" TG_QUESTION = TG + "Question"
TG_GROUNDING = TG + "Grounding"
TG_EXPLORATION = TG + "Exploration" TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus" TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis" TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis" TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion" TG_CONCLUSION = TG + "Conclusion"
# Unifying types for answer and intermediate commentary
TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion)
TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation)
TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning
TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result
# Question subtypes (to distinguish retrieval mechanism) # Question subtypes (to distinguish retrieval mechanism)
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
TG_AGENT_QUESTION = TG + "AgentQuestion" TG_AGENT_QUESTION = TG + "AgentQuestion"
# Agent provenance predicates # Agent provenance predicates
TG_THOUGHT = TG + "thought" TG_THOUGHT = TG + "thought" # Links iteration to thought sub-entity
TG_ACTION = TG + "action" TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments" TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation" TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity
TG_ANSWER = TG + "answer"
# Agent document references (for librarian storage)
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
# Named graph URIs for RDF datasets # Named graph URIs for RDF datasets
# These separate different types of data while keeping them in the same collection # These separate different types of data while keeping them in the same collection

View file

@ -20,12 +20,15 @@ from . namespaces import (
# Extraction provenance entity types # Extraction provenance entity types
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG) # Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT, TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT, TG_DOCUMENT,
# Query-time provenance predicates (DocumentRAG) # Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK, TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types # Explainability entity types
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
# Unifying types
TG_ANSWER_TYPE,
# Question subtypes # Question subtypes
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
) )
@ -347,35 +350,78 @@ def question_triples(
] ]
def grounding_triples(
grounding_uri: str,
question_uri: str,
concepts: List[str],
) -> List[Triple]:
"""
Build triples for a grounding entity (concept decomposition of query).
Creates:
- Entity declaration for grounding
- wasGeneratedBy link to question
- Concept literals for each extracted concept
Args:
grounding_uri: URI of the grounding entity (from grounding_uri)
question_uri: URI of the parent question
concepts: List of concept strings extracted from the query
Returns:
List of Triple objects
"""
triples = [
_triple(grounding_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(grounding_uri, RDF_TYPE, _iri(TG_GROUNDING)),
_triple(grounding_uri, RDFS_LABEL, _literal("Grounding")),
_triple(grounding_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
]
for concept in concepts:
triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept)))
return triples
def exploration_triples( def exploration_triples(
exploration_uri: str, exploration_uri: str,
question_uri: str, grounding_uri: str,
edge_count: int, edge_count: int,
entities: Optional[List[str]] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for an exploration entity (all edges retrieved from subgraph). Build triples for an exploration entity (all edges retrieved from subgraph).
Creates: Creates:
- Entity declaration for exploration - Entity declaration for exploration
- wasGeneratedBy link to question - wasDerivedFrom link to grounding
- Edge count metadata - Edge count metadata
- Entity IRIs for each seed entity
Args: Args:
exploration_uri: URI of the exploration entity (from exploration_uri) exploration_uri: URI of the exploration entity (from exploration_uri)
question_uri: URI of the parent question grounding_uri: URI of the parent grounding entity
edge_count: Number of edges retrieved edge_count: Number of edges retrieved
entities: Optional list of seed entity URIs
Returns: Returns:
List of Triple objects List of Triple objects
""" """
return [ triples = [
_triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)), _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)),
_triple(exploration_uri, RDFS_LABEL, _literal("Exploration")), _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")),
_triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)),
_triple(exploration_uri, TG_EDGE_COUNT, _literal(edge_count)), _triple(exploration_uri, TG_EDGE_COUNT, _literal(edge_count)),
] ]
if entities:
for entity in entities:
triples.append(_triple(exploration_uri, TG_ENTITY, _iri(entity)))
return triples
def _quoted_triple(s: str, p: str, o: str) -> Term: def _quoted_triple(s: str, p: str, o: str) -> Term:
"""Create a quoted triple term (RDF-star) from string values.""" """Create a quoted triple term (RDF-star) from string values."""
@ -454,22 +500,20 @@ def focus_triples(
def synthesis_triples( def synthesis_triples(
synthesis_uri: str, synthesis_uri: str,
focus_uri: str, focus_uri: str,
answer_text: str = "",
document_id: Optional[str] = None, document_id: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for a synthesis entity (final answer text). Build triples for a synthesis entity (final answer).
Creates: Creates:
- Entity declaration for synthesis - Entity declaration for synthesis with tg:Answer type
- wasDerivedFrom link to focus - wasDerivedFrom link to focus
- Either document reference (if document_id provided) or inline content - Document reference to librarian
Args: Args:
synthesis_uri: URI of the synthesis entity (from synthesis_uri) synthesis_uri: URI of the synthesis entity (from synthesis_uri)
focus_uri: URI of the parent focus entity focus_uri: URI of the parent focus entity
answer_text: The synthesized answer text (used if no document_id) document_id: Librarian document ID for the answer content
document_id: Optional librarian document ID (preferred over inline content)
Returns: Returns:
List of Triple objects List of Triple objects
@ -477,16 +521,13 @@ def synthesis_triples(
triples = [ triples = [
_triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)), _triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
_triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")), _triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")),
_triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(focus_uri)), _triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(focus_uri)),
] ]
if document_id: if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
elif answer_text:
# Fallback: store inline content
triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text)))
return triples return triples
@ -533,7 +574,7 @@ def docrag_question_triples(
def docrag_exploration_triples( def docrag_exploration_triples(
exploration_uri: str, exploration_uri: str,
question_uri: str, grounding_uri: str,
chunk_count: int, chunk_count: int,
chunk_ids: Optional[List[str]] = None, chunk_ids: Optional[List[str]] = None,
) -> List[Triple]: ) -> List[Triple]:
@ -542,12 +583,12 @@ def docrag_exploration_triples(
Creates: Creates:
- Entity declaration with tg:Exploration type - Entity declaration with tg:Exploration type
- wasGeneratedBy link to question - wasDerivedFrom link to grounding
- Chunk count and optional chunk references - Chunk count and optional chunk references
Args: Args:
exploration_uri: URI of the exploration entity exploration_uri: URI of the exploration entity
question_uri: URI of the parent question grounding_uri: URI of the parent grounding entity
chunk_count: Number of chunks retrieved chunk_count: Number of chunks retrieved
chunk_ids: Optional list of chunk URIs/IDs chunk_ids: Optional list of chunk URIs/IDs
@ -558,7 +599,7 @@ def docrag_exploration_triples(
_triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)), _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)),
_triple(exploration_uri, RDFS_LABEL, _literal("Exploration")), _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")),
_triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)),
_triple(exploration_uri, TG_CHUNK_COUNT, _literal(chunk_count)), _triple(exploration_uri, TG_CHUNK_COUNT, _literal(chunk_count)),
] ]
@ -573,22 +614,20 @@ def docrag_exploration_triples(
def docrag_synthesis_triples( def docrag_synthesis_triples(
synthesis_uri: str, synthesis_uri: str,
exploration_uri: str, exploration_uri: str,
answer_text: str = "",
document_id: Optional[str] = None, document_id: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for a document RAG synthesis entity (final answer). Build triples for a document RAG synthesis entity (final answer).
Creates: Creates:
- Entity declaration with tg:Synthesis type - Entity declaration with tg:Synthesis and tg:Answer types
- wasDerivedFrom link to exploration (skips focus step) - wasDerivedFrom link to exploration (skips focus step)
- Either document reference or inline content - Document reference to librarian
Args: Args:
synthesis_uri: URI of the synthesis entity synthesis_uri: URI of the synthesis entity
exploration_uri: URI of the parent exploration entity exploration_uri: URI of the parent exploration entity
answer_text: The synthesized answer text (used if no document_id) document_id: Librarian document ID for the answer content
document_id: Optional librarian document ID (preferred over inline content)
Returns: Returns:
List of Triple objects List of Triple objects
@ -596,13 +635,12 @@ def docrag_synthesis_triples(
triples = [ triples = [
_triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)), _triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
_triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")), _triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")),
_triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), _triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
] ]
if document_id: if document_id:
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
elif answer_text:
triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text)))
return triples return triples

View file

@ -68,6 +68,7 @@ def agent_uri(component_name: str) -> str:
# #
# Terminology: # Terminology:
# Question - What was asked, the anchor for everything # Question - What was asked, the anchor for everything
# Grounding - Decomposing the question into concepts
# Exploration - Casting wide, what do we know about this space # Exploration - Casting wide, what do we know about this space
# Focus - Closing down, what's actually relevant here # Focus - Closing down, what's actually relevant here
# Synthesis - Weaving the relevant pieces into an answer # Synthesis - Weaving the relevant pieces into an answer
@ -87,6 +88,19 @@ def question_uri(session_id: str = None) -> str:
return f"urn:trustgraph:question:{session_id}" return f"urn:trustgraph:question:{session_id}"
def grounding_uri(session_id: str) -> str:
"""
Generate URI for a grounding entity (concept decomposition of query).
Args:
session_id: The session UUID (same as question_uri).
Returns:
URN in format: urn:trustgraph:prov:grounding:{uuid}
"""
return f"urn:trustgraph:prov:grounding:{session_id}"
def exploration_uri(session_id: str) -> str: def exploration_uri(session_id: str) -> str:
""" """
Generate URI for an exploration entity (edges retrieved from subgraph). Generate URI for an exploration entity (edges retrieved from subgraph).
@ -173,6 +187,34 @@ def agent_iteration_uri(session_id: str, iteration_num: int) -> str:
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}" return f"urn:trustgraph:agent:{session_id}/i{iteration_num}"
def agent_thought_uri(session_id: str, iteration_num: int) -> str:
"""
Generate URI for an agent thought sub-entity.
Args:
session_id: The session UUID.
iteration_num: 1-based iteration number.
Returns:
URN in format: urn:trustgraph:agent:{uuid}/i{num}/thought
"""
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
def agent_observation_uri(session_id: str, iteration_num: int) -> str:
"""
Generate URI for an agent observation sub-entity.
Args:
session_id: The session UUID.
iteration_num: 1-based iteration number.
Returns:
URN in format: urn:trustgraph:agent:{uuid}/i{num}/observation
"""
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
def agent_final_uri(session_id: str) -> str: def agent_final_uri(session_id: str) -> str:
""" """
Generate URI for an agent final answer. Generate URI for an agent final answer.
@ -205,6 +247,19 @@ def docrag_question_uri(session_id: str = None) -> str:
return f"urn:trustgraph:docrag:{session_id}" return f"urn:trustgraph:docrag:{session_id}"
def docrag_grounding_uri(session_id: str) -> str:
"""
Generate URI for a document RAG grounding entity (concept decomposition).
Args:
session_id: The session UUID.
Returns:
URN in format: urn:trustgraph:docrag:{uuid}/grounding
"""
return f"urn:trustgraph:docrag:{session_id}/grounding"
def docrag_exploration_uri(session_id: str) -> str: def docrag_exploration_uri(session_id: str) -> str:
""" """
Generate URI for a document RAG exploration entity (chunks retrieved). Generate URI for a document RAG exploration entity (chunks retrieved).

View file

@ -25,6 +25,8 @@ from . namespaces import (
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL, TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH, TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_CONCEPT, TG_ENTITY, TG_GROUNDING,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
) )
@ -80,6 +82,11 @@ TG_CLASS_LABELS = [
_label_triple(TG_PAGE_TYPE, "Page"), _label_triple(TG_PAGE_TYPE, "Page"),
_label_triple(TG_CHUNK_TYPE, "Chunk"), _label_triple(TG_CHUNK_TYPE, "Chunk"),
_label_triple(TG_SUBGRAPH_TYPE, "Subgraph"), _label_triple(TG_SUBGRAPH_TYPE, "Subgraph"),
_label_triple(TG_GROUNDING, "Grounding"),
_label_triple(TG_ANSWER_TYPE, "Answer"),
_label_triple(TG_REFLECTION_TYPE, "Reflection"),
_label_triple(TG_THOUGHT_TYPE, "Thought"),
_label_triple(TG_OBSERVATION_TYPE, "Observation"),
] ]
# TrustGraph predicate labels # TrustGraph predicate labels
@ -100,6 +107,8 @@ TG_PREDICATE_LABELS = [
_label_triple(TG_SOURCE_TEXT, "source text"), _label_triple(TG_SOURCE_TEXT, "source text"),
_label_triple(TG_SOURCE_CHAR_OFFSET, "source character offset"), _label_triple(TG_SOURCE_CHAR_OFFSET, "source character offset"),
_label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"), _label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"),
_label_triple(TG_CONCEPT, "concept"),
_label_triple(TG_ENTITY, "entity"),
] ]

View file

@ -15,6 +15,7 @@ class GraphRagQuery:
triple_limit: int = 0 triple_limit: int = 0
max_subgraph_size: int = 0 max_subgraph_size: int = 0
max_path_length: int = 0 max_path_length: int = 0
edge_limit: int = 0
streaming: bool = False streaming: bool = False
@dataclass @dataclass

View file

@ -202,16 +202,17 @@ def question_explainable(
elif isinstance(entity, Analysis): elif isinstance(entity, Analysis):
print(f"\n [iteration] {prov_id}", file=sys.stderr) print(f"\n [iteration] {prov_id}", file=sys.stderr)
if entity.thought:
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
print(f" Thought: {thought_short}", file=sys.stderr)
if entity.action: if entity.action:
print(f" Action: {entity.action}", file=sys.stderr) print(f" Action: {entity.action}", file=sys.stderr)
if entity.thought_uri:
print(f" Thought: {entity.thought_uri}", file=sys.stderr)
if entity.observation_uri:
print(f" Observation: {entity.observation_uri}", file=sys.stderr)
elif isinstance(entity, Conclusion): elif isinstance(entity, Conclusion):
print(f"\n [conclusion] {prov_id}", file=sys.stderr) print(f"\n [conclusion] {prov_id}", file=sys.stderr)
if entity.answer: if entity.document_uri:
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr) print(f" Document: {entity.document_uri}", file=sys.stderr)
else: else:
if debug: if debug:

View file

@ -11,6 +11,7 @@ from trustgraph.api import (
RAGChunk, RAGChunk,
ProvenanceEvent, ProvenanceEvent,
Question, Question,
Grounding,
Exploration, Exploration,
Synthesis, Synthesis,
) )
@ -68,6 +69,12 @@ def question_explainable(
if entity.timestamp: if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr) print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Grounding):
print(f"\n [grounding] {prov_id}", file=sys.stderr)
if entity.concepts:
for concept in entity.concepts:
print(f" Concept: {concept}", file=sys.stderr)
elif isinstance(entity, Exploration): elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr) print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.chunk_count: if entity.chunk_count:
@ -75,8 +82,8 @@ def question_explainable(
elif isinstance(entity, Synthesis): elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr) print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content: if entity.document_uri:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr) print(f" Document: {entity.document_uri}", file=sys.stderr)
else: else:
if debug: if debug:

View file

@ -14,6 +14,7 @@ from trustgraph.api import (
RAGChunk, RAGChunk,
ProvenanceEvent, ProvenanceEvent,
Question, Question,
Grounding,
Exploration, Exploration,
Focus, Focus,
Synthesis, Synthesis,
@ -31,11 +32,13 @@ default_max_path_length = 2
# Provenance predicates # Provenance predicates
TG = "https://trustgraph.ai/ns/" TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query" TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount" TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge" TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge" TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning" TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content" TG_DOCUMENT = TG + "document"
TG_CONTAINS = TG + "contains" TG_CONTAINS = TG + "contains"
PROV = "http://www.w3.org/ns/prov#" PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime" PROV_STARTED_AT_TIME = PROV + "startedAtTime"
@ -47,6 +50,8 @@ def _get_event_type(prov_id):
"""Extract event type from provenance_id""" """Extract event type from provenance_id"""
if "question" in prov_id: if "question" in prov_id:
return "question" return "question"
elif "grounding" in prov_id:
return "grounding"
elif "exploration" in prov_id: elif "exploration" in prov_id:
return "exploration" return "exploration"
elif "focus" in prov_id: elif "focus" in prov_id:
@ -68,8 +73,16 @@ def _format_provenance_details(event_type, triples):
elif p == PROV_STARTED_AT_TIME: elif p == PROV_STARTED_AT_TIME:
lines.append(f" Time: {o}") lines.append(f" Time: {o}")
elif event_type == "grounding":
# Show extracted concepts
concepts = [o for s, p, o in triples if p == TG_CONCEPT]
if concepts:
lines.append(f" Concepts: {len(concepts)}")
for concept in concepts:
lines.append(f" - {concept}")
elif event_type == "exploration": elif event_type == "exploration":
# Show edge count # Show edge count (seed entities resolved separately with labels)
for s, p, o in triples: for s, p, o in triples:
if p == TG_EDGE_COUNT: if p == TG_EDGE_COUNT:
lines.append(f" Edges explored: {o}") lines.append(f" Edges explored: {o}")
@ -85,10 +98,10 @@ def _format_provenance_details(event_type, triples):
lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
elif event_type == "synthesis": elif event_type == "synthesis":
# Show content length (not full content - it's already streamed) # Show document reference (content already streamed)
for s, p, o in triples: for s, p, o in triples:
if p == TG_CONTENT: if p == TG_DOCUMENT:
lines.append(f" Synthesis length: {len(o)} chars") lines.append(f" Document: {o}")
return lines return lines
@ -542,6 +555,18 @@ async def _question_explainable(
for line in details: for line in details:
print(line, file=sys.stderr) print(line, file=sys.stderr)
# For exploration events, resolve entity labels
if event_type == "exploration":
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
if entity_iris:
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
for iri in entity_iris:
label = await _query_label(
ws_url, flow_id, iri, user, collection,
label_cache, debug=debug
)
print(f" - {label}", file=sys.stderr)
# For focus events, query each edge selection for details # For focus events, query each edge selection for details
if event_type == "focus": if event_type == "focus":
for s, p, o in triples: for s, p, o in triples:
@ -660,10 +685,22 @@ def _question_explainable_api(
if entity.timestamp: if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr) print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Grounding):
print(f"\n [grounding] {prov_id}", file=sys.stderr)
if entity.concepts:
print(f" Concepts: {len(entity.concepts)}", file=sys.stderr)
for concept in entity.concepts:
print(f" - {concept}", file=sys.stderr)
elif isinstance(entity, Exploration): elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr) print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.edge_count: if entity.edge_count:
print(f" Edges explored: {entity.edge_count}", file=sys.stderr) print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
if entity.entities:
print(f" Seed entities: {len(entity.entities)}", file=sys.stderr)
for ent in entity.entities:
label = explain_client.resolve_label(ent, user, collection)
print(f" - {label}", file=sys.stderr)
elif isinstance(entity, Focus): elif isinstance(entity, Focus):
print(f"\n [focus] {prov_id}", file=sys.stderr) print(f"\n [focus] {prov_id}", file=sys.stderr)
@ -691,8 +728,8 @@ def _question_explainable_api(
elif isinstance(entity, Synthesis): elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr) print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content: if entity.document_uri:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr) print(f" Document: {entity.document_uri}", file=sys.stderr)
else: else:
if debug: if debug:
@ -848,7 +885,7 @@ def main():
parser.add_argument( parser.add_argument(
'-x', '--explainable', '-x', '--explainable',
action='store_true', action='store_true',
help='Show provenance events: Question, Exploration, Focus, Synthesis (implies streaming)' help='Show provenance events: Question, Grounding, Exploration, Focus, Synthesis (implies streaming)'
) )
parser.add_argument( parser.add_argument(

View file

@ -31,6 +31,8 @@ from ... schema import librarian_request_queue, librarian_response_queue
from trustgraph.provenance import ( from trustgraph.provenance import (
agent_session_uri, agent_session_uri,
agent_iteration_uri, agent_iteration_uri,
agent_thought_uri,
agent_observation_uri,
agent_final_uri, agent_final_uri,
agent_session_triples, agent_session_triples,
agent_iteration_triples, agent_iteration_triples,
@ -624,11 +626,13 @@ class Processor(AgentService):
# Emit final answer provenance triples # Emit final answer provenance triples
final_uri = agent_final_uri(session_id) final_uri = agent_final_uri(session_id)
# Parent is last iteration, or session if no iterations # No iterations: link to question; otherwise: link to last iteration
if iteration_num > 1: if iteration_num > 1:
parent_uri = agent_iteration_uri(session_id, iteration_num - 1) final_question_uri = None
final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
else: else:
parent_uri = session_uri final_question_uri = session_uri
final_previous_uri = None
# Save answer to librarian # Save answer to librarian
answer_doc_id = None answer_doc_id = None
@ -648,8 +652,9 @@ class Processor(AgentService):
final_triples = set_graph( final_triples = set_graph(
agent_final_triples( agent_final_triples(
final_uri, parent_uri, final_uri,
answer="" if answer_doc_id else f, question_uri=final_question_uri,
previous_uri=final_previous_uri,
document_id=answer_doc_id, document_id=answer_doc_id,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
@ -707,11 +712,13 @@ class Processor(AgentService):
# Emit iteration provenance triples # Emit iteration provenance triples
iteration_uri = agent_iteration_uri(session_id, iteration_num) iteration_uri = agent_iteration_uri(session_id, iteration_num)
# Parent is previous iteration, or session if this is first iteration # First iteration links to question, subsequent to previous
if iteration_num > 1: if iteration_num > 1:
parent_uri = agent_iteration_uri(session_id, iteration_num - 1) iter_question_uri = None
iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
else: else:
parent_uri = session_uri iter_question_uri = session_uri
iter_previous_uri = None
# Save thought to librarian # Save thought to librarian
thought_doc_id = None thought_doc_id = None
@ -745,15 +752,19 @@ class Processor(AgentService):
logger.warning(f"Failed to save observation to librarian: {e}") logger.warning(f"Failed to save observation to librarian: {e}")
observation_doc_id = None observation_doc_id = None
thought_entity_uri = agent_thought_uri(session_id, iteration_num)
observation_entity_uri = agent_observation_uri(session_id, iteration_num)
iter_triples = set_graph( iter_triples = set_graph(
agent_iteration_triples( agent_iteration_triples(
iteration_uri, iteration_uri,
parent_uri, question_uri=iter_question_uri,
thought="" if thought_doc_id else act.thought, previous_uri=iter_previous_uri,
action=act.name, action=act.name,
arguments=act.arguments, arguments=act.arguments,
observation="" if observation_doc_id else act.observation, thought_uri=thought_entity_uri if thought_doc_id else None,
thought_document_id=thought_doc_id, thought_document_id=thought_doc_id,
observation_uri=observation_entity_uri if observation_doc_id else None,
observation_document_id=observation_doc_id, observation_document_id=observation_doc_id,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL

View file

@ -7,9 +7,11 @@ from datetime import datetime
# Provenance imports # Provenance imports
from trustgraph.provenance import ( from trustgraph.provenance import (
docrag_question_uri, docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri, docrag_exploration_uri,
docrag_synthesis_uri, docrag_synthesis_uri,
docrag_question_triples, docrag_question_triples,
grounding_triples,
docrag_exploration_triples, docrag_exploration_triples,
docrag_synthesis_triples, docrag_synthesis_triples,
set_graph, set_graph,
@ -33,39 +35,79 @@ class Query:
self.verbose = verbose self.verbose = verbose
self.doc_limit = doc_limit self.doc_limit = doc_limit
async def get_vector(self, query): async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
# Fallback to raw query if no concepts extracted
if not concepts:
concepts = [query]
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
return concepts
async def get_vectors(self, concepts):
"""Compute embeddings for a list of concepts."""
if self.verbose: if self.verbose:
logger.debug("Computing embeddings...") logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed([query]) qembeds = await self.rag.embeddings_client.embed(concepts)
if self.verbose: if self.verbose:
logger.debug("Embeddings computed") logger.debug("Embeddings computed")
# Return the vector set for the first (only) text return qembeds
return qembeds[0] if qembeds else []
async def get_docs(self, query): async def get_docs(self, concepts):
""" """
Get documents (chunks) matching the query. Get documents (chunks) matching the extracted concepts.
Returns: Returns:
tuple: (docs, chunk_ids) where: tuple: (docs, chunk_ids) where:
- docs: list of document content strings - docs: list of document content strings
- chunk_ids: list of chunk IDs that were successfully fetched - chunk_ids: list of chunk IDs that were successfully fetched
""" """
vectors = await self.get_vector(query) vectors = await self.get_vectors(concepts)
if self.verbose: if self.verbose:
logger.debug("Getting chunks from embeddings store...") logger.debug("Getting chunks from embeddings store...")
# Get chunk matches from embeddings store # Query chunk matches for each concept concurrently
chunk_matches = await self.rag.doc_embeddings_client.query( per_concept_limit = max(
vector=vectors, limit=self.doc_limit, 1, self.doc_limit // len(vectors)
user=self.user, collection=self.collection,
) )
async def query_concept(vec):
return await self.rag.doc_embeddings_client.query(
vector=vec, limit=per_concept_limit,
user=self.user, collection=self.collection,
)
results = await asyncio.gather(
*[query_concept(v) for v in vectors]
)
# Deduplicate chunk matches by chunk_id
seen = set()
chunk_matches = []
for matches in results:
for match in matches:
if match.chunk_id and match.chunk_id not in seen:
seen.add(match.chunk_id)
chunk_matches.append(match)
if self.verbose: if self.verbose:
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...") logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
@ -133,6 +175,7 @@ class DocumentRag:
# Generate explainability URIs upfront # Generate explainability URIs upfront
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
q_uri = docrag_question_uri(session_id) q_uri = docrag_question_uri(session_id)
gnd_uri = docrag_grounding_uri(session_id)
exp_uri = docrag_exploration_uri(session_id) exp_uri = docrag_exploration_uri(session_id)
syn_uri = docrag_synthesis_uri(session_id) syn_uri = docrag_synthesis_uri(session_id)
@ -151,12 +194,23 @@ class DocumentRag:
doc_limit=doc_limit doc_limit=doc_limit
) )
docs, chunk_ids = await q.get_docs(query) # Extract concepts from query (grounding step)
concepts = await q.extract_concepts(query)
# Emit grounding explainability after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
docs, chunk_ids = await q.get_docs(concepts)
# Emit exploration explainability after chunks retrieved # Emit exploration explainability after chunks retrieved
if explain_callback: if explain_callback:
exp_triples = set_graph( exp_triples = set_graph(
docrag_exploration_triples(exp_uri, q_uri, len(chunk_ids), chunk_ids), docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
await explain_callback(exp_triples, exp_uri) await explain_callback(exp_triples, exp_uri)
@ -196,9 +250,8 @@ class DocumentRag:
synthesis_doc_id = None synthesis_doc_id = None
answer_text = resp if resp else "" answer_text = resp if resp else ""
# Save answer to librarian if callback provided # Save answer to librarian
if save_answer_callback and answer_text: if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer" synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
try: try:
await save_answer_callback(synthesis_doc_id, answer_text) await save_answer_callback(synthesis_doc_id, answer_text)
@ -206,13 +259,11 @@ class DocumentRag:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}") logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content synthesis_doc_id = None
# Generate triples with document reference or inline content
syn_triples = set_graph( syn_triples = set_graph(
docrag_synthesis_triples( docrag_synthesis_triples(
syn_uri, exp_uri, syn_uri, exp_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id, document_id=synthesis_doc_id,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL

View file

@ -8,20 +8,23 @@ import uuid
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime from datetime import datetime
from ... schema import IRI, LITERAL from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
# Provenance imports # Provenance imports
from trustgraph.provenance import ( from trustgraph.provenance import (
question_uri, question_uri,
grounding_uri as make_grounding_uri,
exploration_uri as make_exploration_uri, exploration_uri as make_exploration_uri,
focus_uri as make_focus_uri, focus_uri as make_focus_uri,
synthesis_uri as make_synthesis_uri, synthesis_uri as make_synthesis_uri,
question_triples, question_triples,
grounding_triples,
exploration_triples, exploration_triples,
focus_triples, focus_triples,
synthesis_triples, synthesis_triples,
set_graph, set_graph,
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL, GRAPH_SOURCE,
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
) )
# Module logger # Module logger
@ -47,6 +50,8 @@ def edge_id(s, p, o):
edge_str = f"{s}|{p}|{o}" edge_str = f"{s}|{p}|{o}"
return hashlib.sha256(edge_str.encode()).hexdigest()[:8] return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
class LRUCacheWithTTL: class LRUCacheWithTTL:
"""LRU cache with TTL for label caching """LRU cache with TTL for label caching
@ -105,42 +110,88 @@ class Query:
self.max_subgraph_size = max_subgraph_size self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length self.max_path_length = max_path_length
async def get_vector(self, query): async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
# Fall back to raw query if extraction returns nothing
return concepts if concepts else [query]
async def get_vectors(self, concepts):
"""Embed multiple concepts concurrently."""
if self.verbose: if self.verbose:
logger.debug("Computing embeddings...") logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed([query]) qembeds = await self.rag.embeddings_client.embed(concepts)
if self.verbose: if self.verbose:
logger.debug("Done.") logger.debug("Done.")
# Return the vector set for the first (only) text return qembeds
return qembeds[0] if qembeds else []
async def get_entities(self, query): async def get_entities(self, query):
"""
Extract concepts from query, embed them, and retrieve matching entities.
vectors = await self.get_vector(query) Returns:
tuple: (entities, concepts) where entities is a list of entity URI
strings and concepts is the list of concept strings extracted
from the query.
"""
concepts = await self.extract_concepts(query)
vectors = await self.get_vectors(concepts)
if self.verbose: if self.verbose:
logger.debug("Getting entities...") logger.debug("Getting entities...")
entity_matches = await self.rag.graph_embeddings_client.query( # Query entity matches for each concept concurrently
vector=vectors, limit=self.entity_limit, per_concept_limit = max(
user=self.user, collection=self.collection, 1, self.entity_limit // len(vectors)
) )
entities = [ entity_tasks = [
term_to_string(e.entity) self.rag.graph_embeddings_client.query(
for e in entity_matches vector=v, limit=per_concept_limit,
user=self.user, collection=self.collection,
)
for v in vectors
] ]
results = await asyncio.gather(*entity_tasks, return_exceptions=True)
# Deduplicate while preserving order
seen = set()
entities = []
for result in results:
if isinstance(result, Exception) or not result:
continue
for e in result:
entity = term_to_string(e.entity)
if entity not in seen:
seen.add(entity)
entities.append(entity)
if self.verbose: if self.verbose:
logger.debug("Entities:") logger.debug("Entities:")
for ent in entities: for ent in entities:
logger.debug(f" {ent}") logger.debug(f" {ent}")
return entities return entities, concepts
async def maybe_label(self, e): async def maybe_label(self, e):
@ -156,6 +207,7 @@ class Query:
res = await self.rag.triples_client.query( res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1, s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
g="",
) )
if len(res) == 0: if len(res) == 0:
@ -177,19 +229,19 @@ class Query:
s=entity, p=None, o=None, s=entity, p=None, o=None,
limit=limit_per_entity, limit=limit_per_entity,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
batch_size=20, batch_size=20, g="",
), ),
self.rag.triples_client.query_stream( self.rag.triples_client.query_stream(
s=None, p=entity, o=None, s=None, p=entity, o=None,
limit=limit_per_entity, limit=limit_per_entity,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
batch_size=20, batch_size=20, g="",
), ),
self.rag.triples_client.query_stream( self.rag.triples_client.query_stream(
s=None, p=None, o=entity, s=None, p=None, o=entity,
limit=limit_per_entity, limit=limit_per_entity,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
batch_size=20, batch_size=20, g="",
) )
]) ])
@ -262,8 +314,16 @@ class Query:
subgraph.update(batch_result) subgraph.update(batch_result)
async def get_subgraph(self, query): async def get_subgraph(self, query):
"""
Get subgraph by extracting concepts, finding entities, and traversing.
entities = await self.get_entities(query) Returns:
tuple: (subgraph, entities, concepts) where subgraph is a list of
(s, p, o) tuples, entities is the seed entity list, and concepts
is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose: if self.verbose:
logger.debug("Getting subgraph...") logger.debug("Getting subgraph...")
@ -271,7 +331,7 @@ class Query:
# Use optimized batch traversal instead of sequential processing # Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length) subgraph = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph) return list(subgraph), entities, concepts
async def resolve_labels_batch(self, entities): async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel""" """Resolve labels for multiple entities in parallel"""
@ -286,11 +346,13 @@ class Query:
Get subgraph with labels resolved for display. Get subgraph with labels resolved for display.
Returns: Returns:
tuple: (labeled_edges, uri_map) where: tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples - labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
""" """
subgraph = await self.get_subgraph(query) subgraph, entities, concepts = await self.get_subgraph(query)
# Filter out label triples # Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
@ -338,8 +400,125 @@ class Query:
if self.verbose: if self.verbose:
logger.debug("Done.") logger.debug("Done.")
return labeled_edges, uri_map return labeled_edges, uri_map, entities, concepts
async def trace_source_documents(self, edge_uris):
"""
Trace selected edges back to their source documents via provenance.
Follows the chain: edge subgraph (via tg:contains) chunk
page document (via prov:wasDerivedFrom), all in urn:graph:source.
Args:
edge_uris: List of (s, p, o) URI string tuples
Returns:
List of unique document titles
"""
# Step 1: Find subgraphs containing these edges via tg:contains
subgraph_tasks = []
for s, p, o in edge_uris:
quoted = Term(
type=TRIPLE,
triple=SchemaTriple(
s=Term(type=IRI, iri=s),
p=Term(type=IRI, iri=p),
o=Term(type=IRI, iri=o),
)
)
subgraph_tasks.append(
self.rag.triples_client.query(
s=None, p=TG_CONTAINS, o=quoted, limit=1,
user=self.user, collection=self.collection,
g=GRAPH_SOURCE,
)
)
subgraph_results = await asyncio.gather(
*subgraph_tasks, return_exceptions=True
)
# Collect unique subgraph URIs
subgraph_uris = set()
for result in subgraph_results:
if isinstance(result, Exception) or not result:
continue
for triple in result:
subgraph_uris.add(str(triple.s))
if not subgraph_uris:
return []
# Step 2: Walk prov:wasDerivedFrom chain to find documents
# Each level: query ?entity prov:wasDerivedFrom ?parent
# Stop when we find entities typed tg:Document
current_uris = subgraph_uris
doc_uris = set()
for depth in range(4): # Max depth: subgraph → chunk → page → doc
if not current_uris:
break
derivation_tasks = [
self.rag.triples_client.query(
s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
user=self.user, collection=self.collection,
g=GRAPH_SOURCE,
)
for uri in current_uris
]
derivation_results = await asyncio.gather(
*derivation_tasks, return_exceptions=True
)
# URIs with no parent are root documents
next_uris = set()
for uri, result in zip(current_uris, derivation_results):
if isinstance(result, Exception) or not result:
doc_uris.add(uri)
continue
for triple in result:
next_uris.add(str(triple.o))
current_uris = next_uris - doc_uris
if not doc_uris:
return []
# Step 3: Get all document metadata properties
# Skip structural predicates that aren't useful context
SKIP_PREDICATES = {
PROV_WAS_DERIVED_FROM,
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
}
metadata_tasks = [
self.rag.triples_client.query(
s=uri, p=None, o=None, limit=50,
user=self.user, collection=self.collection,
)
for uri in doc_uris
]
metadata_results = await asyncio.gather(
*metadata_tasks, return_exceptions=True
)
doc_edges = []
for result in metadata_results:
if isinstance(result, Exception) or not result:
continue
for triple in result:
p = str(triple.p)
if p in SKIP_PREDICATES:
continue
doc_edges.append((
str(triple.s), p, str(triple.o)
))
return doc_edges
class GraphRag: class GraphRag:
""" """
CRITICAL SECURITY: CRITICAL SECURITY:
@ -371,7 +550,8 @@ class GraphRag:
async def query( async def query(
self, query, user = "trustgraph", collection = "default", self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, streaming = False, chunk_callback = None, max_path_length = 2, edge_limit = 25, streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None, explain_callback = None, save_answer_callback = None,
): ):
""" """
@ -399,6 +579,7 @@ class GraphRag:
# Generate explainability URIs upfront # Generate explainability URIs upfront
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
q_uri = question_uri(session_id) q_uri = question_uri(session_id)
gnd_uri = make_grounding_uri(session_id)
exp_uri = make_exploration_uri(session_id) exp_uri = make_exploration_uri(session_id)
foc_uri = make_focus_uri(session_id) foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id) syn_uri = make_synthesis_uri(session_id)
@ -421,12 +602,23 @@ class GraphRag:
max_path_length = max_path_length, max_path_length = max_path_length,
) )
kg, uri_map = await q.get_labelgraph(query) kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Emit grounding explain after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes # Emit exploration explain after graph retrieval completes
if explain_callback: if explain_callback:
exp_triples = set_graph( exp_triples = set_graph(
exploration_triples(exp_uri, q_uri, len(kg)), exploration_triples(
exp_uri, gnd_uri, len(kg),
entities=seed_entities,
),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
await explain_callback(exp_triples, exp_uri) await explain_callback(exp_triples, exp_uri)
@ -453,9 +645,9 @@ class GraphRag:
if self.verbose: if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges") logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1: Edge Selection - LLM selects relevant edges with reasoning # Step 1a: Edge Scoring - LLM scores edges for relevance
selection_response = await self.prompt_client.prompt( scoring_response = await self.prompt_client.prompt(
"kg-edge-selection", "kg-edge-scoring",
variables={ variables={
"query": query, "query": query,
"knowledge": edges_with_ids "knowledge": edges_with_ids
@ -463,52 +655,44 @@ class GraphRag:
) )
if self.verbose: if self.verbose:
logger.debug(f"Edge selection response: {selection_response}") logger.debug(f"Edge scoring response: {scoring_response}")
# Parse response to get selected edge IDs and reasoning # Parse scoring response to get edge IDs with scores
# Response can be a string (JSONL) or a list (JSON array) scored_edges = []
selected_ids = set()
selected_edges_with_reasoning = [] # For explain
if isinstance(selection_response, list): def parse_scored_edge(obj):
# JSON array response if isinstance(obj, dict) and "id" in obj and "score" in obj:
for obj in selection_response: try:
if isinstance(obj, dict) and "id" in obj: score = int(obj["score"])
selected_ids.add(obj["id"]) except (ValueError, TypeError):
# Capture original URI edge (not labels) and reasoning for explain score = 0
eid = obj["id"] scored_edges.append({"id": obj["id"], "score": score})
if eid in uri_map:
# Use original URIs for provenance tracing if isinstance(scoring_response, list):
uri_s, uri_p, uri_o = uri_map[eid] for obj in scoring_response:
selected_edges_with_reasoning.append({ parse_scored_edge(obj)
"edge": (uri_s, uri_p, uri_o), elif isinstance(scoring_response, str):
"reasoning": obj.get("reasoning", ""), for line in scoring_response.strip().split('\n'):
})
elif isinstance(selection_response, str):
# JSONL string response
for line in selection_response.strip().split('\n'):
line = line.strip() line = line.strip()
if not line: if not line:
continue continue
try: try:
obj = json.loads(line) parse_scored_edge(json.loads(line))
if "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse edge selection line: {line}") logger.warning(
continue f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose: if self.verbose:
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}") logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges # Filter to selected edges
selected_edges = [] selected_edges = []
@ -516,6 +700,82 @@ class GraphRag:
if eid in edge_map: if eid in edge_map:
selected_edges.append(edge_map[eid]) selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
if isinstance(reasoning_response, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_response}"
)
reasoning_response = ""
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}")
# Parse reasoning response and build explainability data
reasoning_map = {}
def parse_reasoning(obj):
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
selected_edges_with_reasoning = []
for eid in selected_ids:
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
})
if self.verbose: if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges") logger.debug(f"Filtered to {len(selected_edges)} edges")
@ -534,6 +794,18 @@ class GraphRag:
{"s": s, "p": p, "o": o} {"s": s, "p": p, "o": o}
for s, p, o in selected_edges for s, p, o in selected_edges
] ]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
})
synthesis_variables = {
"query": query,
"knowledge": selected_edge_dicts,
}
if streaming and chunk_callback: if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback # Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = [] accumulated_chunks = []
@ -544,10 +816,7 @@ class GraphRag:
await self.prompt_client.prompt( await self.prompt_client.prompt(
"kg-synthesis", "kg-synthesis",
variables={ variables=synthesis_variables,
"query": query,
"knowledge": selected_edge_dicts
},
streaming=True, streaming=True,
chunk_callback=accumulating_callback chunk_callback=accumulating_callback
) )
@ -556,10 +825,7 @@ class GraphRag:
else: else:
resp = await self.prompt_client.prompt( resp = await self.prompt_client.prompt(
"kg-synthesis", "kg-synthesis",
variables={ variables=synthesis_variables,
"query": query,
"knowledge": selected_edge_dicts
}
) )
if self.verbose: if self.verbose:
@ -570,9 +836,8 @@ class GraphRag:
synthesis_doc_id = None synthesis_doc_id = None
answer_text = resp if resp else "" answer_text = resp if resp else ""
# Save answer to librarian if callback provided # Save answer to librarian
if save_answer_callback and answer_text: if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}" synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try: try:
await save_answer_callback(synthesis_doc_id, answer_text) await save_answer_callback(synthesis_doc_id, answer_text)
@ -580,13 +845,11 @@ class GraphRag:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}") logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content synthesis_doc_id = None
# Generate triples with document reference or inline content
syn_triples = set_graph( syn_triples = set_graph(
synthesis_triples( synthesis_triples(
syn_uri, foc_uri, syn_uri, foc_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id, document_id=synthesis_doc_id,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL

View file

@ -39,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30) triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150) max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2) max_path_length = params.get("max_path_length", 2)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
@ -48,6 +49,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit, "triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size, "max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length, "max_path_length": max_path_length,
"edge_limit": edge_limit,
} }
) )
@ -55,6 +57,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length self.default_max_path_length = max_path_length
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections # CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access # Each user/collection combination MUST have isolated data access
@ -292,6 +295,11 @@ class Processor(FlowProcessor):
else: else:
max_path_length = self.default_max_path_length max_path_length = self.default_max_path_length
if v.edge_limit:
edge_limit = v.edge_limit
else:
edge_limit = self.default_edge_limit
# Callback to save answer content to librarian # Callback to save answer content to librarian
async def save_answer(doc_id, answer_text): async def save_answer(doc_id, answer_text):
await self.save_answer_content( await self.save_answer_content(
@ -322,6 +330,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length, max_path_length = max_path_length,
edge_limit = edge_limit,
streaming = True, streaming = True,
chunk_callback = send_chunk, chunk_callback = send_chunk,
explain_callback = send_explainability, explain_callback = send_explainability,
@ -335,6 +344,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length, max_path_length = max_path_length,
edge_limit = edge_limit,
explain_callback = send_explainability, explain_callback = send_explainability,
save_answer_callback = save_answer, save_answer_callback = save_answer,
) )