Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding (#697)

Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding,
consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts,
  kg-edge-scoring,
  kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring
  provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's
  pattern:
  Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first
  entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag,
  tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
This commit is contained in:
cybermaggedon 2026-03-16 12:12:13 +00:00 committed by GitHub
parent 29b4300808
commit a115ec06ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1537 additions and 1008 deletions

View file

@@ -86,13 +86,18 @@ class TestGraphRagIntegration:
"""Mock prompt client that generates realistic responses for two-step process"""
client = AsyncMock()
# Mock responses for the two-step process:
# 1. kg-edge-selection returns JSONL with edge IDs
# 2. kg-synthesis returns the final answer
# Mock responses for the multi-step process:
# 1. extract-concepts extracts key concepts from the query
# 2. kg-edge-scoring scores edges for relevance
# 3. kg-edge-reasoning provides reasoning for selected edges
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Return empty selection (no edges selected) - valid JSONL
return ""
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
@@ -160,16 +165,16 @@ class TestGraphRagIntegration:
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt twice (edge selection + synthesis)
assert mock_prompt_client.prompt.call_count == 2
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()
# Verify provenance was emitted in real-time (4 events: question, exploration, focus, synthesis)
assert len(provenance_events) == 4
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 5
for triples, prov_id in provenance_events:
assert isinstance(triples, list)
assert prov_id.startswith("urn:trustgraph:")
@@ -243,10 +248,10 @@ class TestGraphRagIntegration:
)
# Assert
# Should still call prompt client (twice: edge selection + synthesis)
# Should still call prompt client
assert response is not None
# Provenance should still be emitted (4 events)
assert len(provenance_events) == 4
# Provenance should still be emitted (5 events)
assert len(provenance_events) == 5
@pytest.mark.asyncio
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):

View file

@@ -60,8 +60,12 @@ class TestGraphRagStreaming:
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "kg-edge-selection":
# Edge selection returns JSONL with IDs - simulate selecting first edge
if prompt_id == "extract-concepts":
return "" # Falls back to raw query
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
@@ -132,8 +136,8 @@ class TestGraphRagStreaming:
# Verify content is reasonable
assert "machine" in response.lower() or "learning" in response.lower()
# Verify provenance was emitted in real-time (4 events)
assert len(provenance_events) == 4
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 5
for triples, prov_id in provenance_events:
assert prov_id.startswith("urn:trustgraph:")

View file

@@ -15,10 +15,11 @@ from trustgraph.provenance.agent import (
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_AGENT_QUESTION,
)
@@ -110,84 +111,107 @@ class TestAgentSessionTriples:
class TestAgentIterationTriples:
ITER_URI = "urn:trustgraph:agent:test-session/i1"
PARENT_URI = "urn:trustgraph:agent:test-session"
SESSION_URI = "urn:trustgraph:agent:test-session"
PREV_URI = "urn:trustgraph:agent:test-session/i0"
def test_iteration_types(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="thinking", action="search", observation="found it",
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
def test_iteration_derived_from_parent(self):
def test_first_iteration_generated_by_question(self):
"""First iteration uses wasGeneratedBy to link to question activity."""
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
# Should NOT have wasDerivedFrom
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is None
def test_subsequent_iteration_derived_from_previous(self):
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
triples = agent_iteration_triples(
self.ITER_URI, previous_uri=self.PREV_URI,
action="search",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
assert derived.o.iri == self.PREV_URI
# Should NOT have wasGeneratedBy
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is None
def test_iteration_label_includes_action(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query",
)
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
assert label is not None
assert "graph-rag-query" in label.o.value
def test_iteration_thought_inline(self):
def test_iteration_thought_sub_entity(self):
"""Thought is a sub-entity with Reflection and Thought types."""
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
thought_doc = "urn:doc:thought-1"
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="I need to search for info",
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
thought_uri=thought_uri,
thought_document_id=thought_doc,
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought is not None
assert thought.o.value == "I need to search for info"
# Iteration links to thought sub-entity
thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought_link is not None
assert thought_link.o.iri == thought_uri
# Thought has correct types
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
# Thought was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Thought has document reference
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
assert doc is not None
assert doc.o.iri == thought_doc
def test_iteration_thought_document_preferred(self):
"""When thought_document_id is provided, inline thought is not stored."""
def test_iteration_observation_sub_entity(self):
"""Observation is a sub-entity with Reflection and Observation types."""
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
obs_doc = "urn:doc:obs-1"
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="inline thought",
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
thought_document_id="urn:doc:thought-1",
observation_uri=obs_uri,
observation_document_id=obs_doc,
)
thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
assert thought_doc is not None
assert thought_doc.o.iri == "urn:doc:thought-1"
thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought_inline is None
def test_iteration_observation_inline(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="Found 3 results",
)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs is not None
assert obs.o.value == "Found 3 results"
def test_iteration_observation_document_preferred(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="inline obs",
observation_document_id="urn:doc:obs-1",
)
obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
assert obs_doc is not None
assert obs_doc.o.iri == "urn:doc:obs-1"
obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_inline is None
# Iteration links to observation sub-entity
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_link is not None
assert obs_link.o.iri == obs_uri
# Observation has correct types
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
# Observation was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Observation has document reference
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
assert doc is not None
assert doc.o.iri == obs_doc
def test_iteration_action_recorded(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query",
)
action = find_triple(triples, TG_ACTION, self.ITER_URI)
@@ -197,7 +221,7 @@ class TestAgentIterationTriples:
def test_iteration_arguments_json_encoded(self):
args = {"query": "test query", "limit": 10}
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
arguments=args,
)
@@ -208,7 +232,7 @@ class TestAgentIterationTriples:
def test_iteration_default_arguments_empty_dict(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
@@ -219,7 +243,7 @@ class TestAgentIterationTriples:
def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples."""
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="noop",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
@@ -228,19 +252,19 @@ class TestAgentIterationTriples:
assert obs is None
def test_iteration_chaining(self):
"""Second iteration derives from first iteration, not session."""
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2"
triples1 = agent_iteration_triples(
iter1_uri, self.PARENT_URI, action="step1",
iter1_uri, question_uri=self.SESSION_URI, action="step1",
)
triples2 = agent_iteration_triples(
iter2_uri, iter1_uri, action="step2",
iter2_uri, previous_uri=iter1_uri, action="step2",
)
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
assert derived1.o.iri == self.PARENT_URI
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
assert gen1.o.iri == self.SESSION_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri
@@ -253,42 +277,50 @@ class TestAgentIterationTriples:
class TestAgentFinalTriples:
FINAL_URI = "urn:trustgraph:agent:test-session/final"
PARENT_URI = "urn:trustgraph:agent:test-session/i3"
PREV_URI = "urn:trustgraph:agent:test-session/i3"
SESSION_URI = "urn:trustgraph:agent:test-session"
def test_final_types(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
self.FINAL_URI, previous_uri=self.PREV_URI,
)
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
def test_final_derived_from_parent(self):
def test_final_derived_from_previous(self):
"""Conclusion with iterations uses wasDerivedFrom."""
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
self.FINAL_URI, previous_uri=self.PREV_URI,
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
assert derived.o.iri == self.PREV_URI
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is None
def test_final_generated_by_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasGeneratedBy."""
triples = agent_final_triples(
self.FINAL_URI, question_uri=self.SESSION_URI,
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is None
def test_final_label(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
self.FINAL_URI, previous_uri=self.PREV_URI,
)
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
assert label is not None
assert label.o.value == "Conclusion"
def test_final_inline_answer(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is not None
assert answer.o.value == "The answer is 42"
def test_final_document_reference(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI,
self.FINAL_URI, previous_uri=self.PREV_URI,
document_id="urn:trustgraph:agent:sess/answer",
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
@@ -296,29 +328,9 @@ class TestAgentFinalTriples:
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
def test_final_document_takes_precedence(self):
def test_final_no_document(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI,
answer="inline",
document_id="urn:doc:123",
self.FINAL_URI, previous_uri=self.PREV_URI,
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is not None
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is None
def test_final_no_answer_or_document(self):
triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert answer is None
assert doc is None
def test_final_derives_from_session_when_no_iterations(self):
"""When agent answers immediately, final derives from session."""
session_uri = "urn:trustgraph:agent:test-session"
triples = agent_final_triples(
self.FINAL_URI, session_uri, answer="direct answer"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived.o.iri == session_uri

View file

@@ -10,9 +10,11 @@ from trustgraph.api.explainability import (
EdgeSelection,
ExplainEntity,
Question,
Grounding,
Exploration,
Focus,
Synthesis,
Reflection,
Analysis,
Conclusion,
parse_edge_selection_triples,
@@ -20,11 +22,11 @@ from trustgraph.api.explainability import (
wire_triples_to_tuples,
ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
RDF_TYPE, RDFS_LABEL,
@@ -71,6 +73,18 @@ class TestExplainEntityFromTriples:
assert isinstance(entity, Question)
assert entity.question_type == "agent"
def test_grounding(self):
triples = [
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
("urn:gnd:1", TG_CONCEPT, "machine learning"),
("urn:gnd:1", TG_CONCEPT, "neural networks"),
]
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
assert isinstance(entity, Grounding)
assert len(entity.concepts) == 2
assert "machine learning" in entity.concepts
assert "neural networks" in entity.concepts
def test_exploration(self):
triples = [
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
@@ -89,6 +103,17 @@ class TestExplainEntityFromTriples:
assert isinstance(entity, Exploration)
assert entity.chunk_count == 5
def test_exploration_with_entities(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
("urn:exp:3", TG_EDGE_COUNT, "10"),
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
]
entity = ExplainEntity.from_triples("urn:exp:3", triples)
assert isinstance(entity, Exploration)
assert len(entity.entities) == 2
def test_exploration_invalid_count(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
@@ -110,69 +135,76 @@ class TestExplainEntityFromTriples:
assert "urn:edge:1" in entity.selected_edge_uris
assert "urn:edge:2" in entity.selected_edge_uris
def test_synthesis_with_content(self):
def test_synthesis_with_document(self):
triples = [
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:1", TG_CONTENT, "The answer is 42"),
("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
]
entity = ExplainEntity.from_triples("urn:syn:1", triples)
assert isinstance(entity, Synthesis)
assert entity.content == "The answer is 42"
assert entity.document_uri == ""
assert entity.document_uri == "urn:doc:answer-1"
def test_synthesis_with_document(self):
def test_synthesis_no_document(self):
triples = [
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
]
entity = ExplainEntity.from_triples("urn:syn:2", triples)
assert isinstance(entity, Synthesis)
assert entity.document_uri == "urn:doc:answer-1"
assert entity.document_uri == ""
def test_reflection_thought(self):
triples = [
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
]
entity = ExplainEntity.from_triples("urn:ref:1", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "thought"
assert entity.document_uri == "urn:doc:thought-1"
def test_reflection_observation(self):
triples = [
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ref:2", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "observation"
assert entity.document_uri == "urn:doc:obs-1"
def test_analysis(self):
triples = [
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
("urn:ana:1", TG_THOUGHT, "I should search"),
("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_OBSERVATION, "Found results"),
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis)
assert entity.thought == "I should search"
assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}'
assert entity.observation == "Found results"
def test_analysis_with_document_refs(self):
triples = [
("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
("urn:ana:2", TG_ACTION, "search"),
("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:2", triples)
assert isinstance(entity, Analysis)
assert entity.thought_document_uri == "urn:doc:thought-1"
assert entity.observation_document_uri == "urn:doc:obs-1"
def test_conclusion_with_answer(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_ANSWER, "The final answer"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.answer == "The final answer"
assert entity.thought_uri == "urn:ref:thought-1"
assert entity.observation_uri == "urn:ref:obs-1"
def test_conclusion_with_document(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final"
def test_conclusion_no_document(self):
triples = [
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:2", triples)
assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final"
assert entity.document_uri == ""
def test_unknown_type(self):
triples = [
@@ -457,25 +489,7 @@ class TestExplainabilityClientResolveLabel:
class TestExplainabilityClientContentFetching:
def test_fetch_synthesis_inline_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == "inline answer"
def test_fetch_synthesis_truncated_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
long_content = "x" * 20000
synthesis = Synthesis(uri="urn:syn:1", content=long_content)
result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_synthesis_from_librarian(self):
def test_fetch_document_content_from_librarian(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
@@ -483,66 +497,32 @@ class TestExplainabilityClientContentFetching:
mock_library.get_document_content.return_value = b"librarian content"
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(
uri="urn:syn:1",
document_uri="urn:document:abc123"
result = client.fetch_document_content(
"urn:document:abc123", api=mock_api
)
result = client.fetch_synthesis_content(synthesis, api=mock_api)
assert result == "librarian content"
def test_fetch_synthesis_no_content_or_document(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == ""
def test_fetch_conclusion_inline(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
conclusion = Conclusion(uri="urn:conc:1", answer="42")
result = client.fetch_conclusion_content(conclusion, api=None)
assert result == "42"
def test_fetch_analysis_content_from_librarian(self):
def test_fetch_document_content_truncated(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
mock_api.library.return_value = mock_library
mock_library.get_document_content.side_effect = [
b"thought content",
b"observation content",
]
mock_library.get_document_content.return_value = b"x" * 20000
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis(
uri="urn:ana:1",
action="search",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
result = client.fetch_document_content(
"urn:doc:1", api=mock_api, max_content=100
)
client.fetch_analysis_content(analysis, api=mock_api)
assert analysis.thought == "thought content"
assert analysis.observation == "observation content"
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_analysis_skips_when_inline_exists(self):
def test_fetch_document_content_empty_uri(self):
mock_flow = MagicMock()
mock_api = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis(
uri="urn:ana:1",
action="search",
thought="already have thought",
observation="already have observation",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
)
client.fetch_analysis_content(analysis, api=mock_api)
# Should not call librarian since inline content exists
mock_api.library.assert_not_called()
result = client.fetch_document_content("", api=mock_api)
assert result == ""
class TestExplainabilityClientDetectSessionType:

View file

@@ -13,6 +13,7 @@ from trustgraph.provenance.triples import (
derived_entity_triples,
subgraph_provenance_triples,
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
@@ -32,10 +33,12 @@ from trustgraph.provenance.namespaces import (
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
GRAPH_SOURCE, GRAPH_RETRIEVAL,
)
@@ -530,36 +533,77 @@ class TestQuestionTriples:
assert len(triples) == 6
class TestExplorationTriples:
class TestGroundingTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
GND_URI = "urn:trustgraph:prov:grounding:test-session"
Q_URI = "urn:trustgraph:question:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_grounding_types(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
assert has_type(triples, self.GND_URI, PROV_ENTITY)
assert has_type(triples, self.GND_URI, TG_GROUNDING)
def test_exploration_generated_by_question(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
def test_grounding_generated_by_question(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
assert gen is not None
assert gen.o.iri == self.Q_URI
def test_grounding_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 3
values = {t.o.value for t in concepts}
assert values == {"AI", "ML", "robots"}
def test_grounding_empty_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 0
def test_grounding_label(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
assert label is not None
assert label.o.value == "Grounding"
class TestExplorationTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
GND_URI = "urn:trustgraph:prov:grounding:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_exploration_derived_from_grounding(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived is not None
assert derived.o.iri == self.GND_URI
def test_exploration_edge_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "15"
def test_exploration_zero_edges(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "0"
def test_exploration_with_entities(self):
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
assert len(ent_triples) == 2
def test_exploration_triple_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
assert len(triples) == 5
@@ -652,6 +696,7 @@ class TestSynthesisTriples:
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_synthesis_derived_from_focus(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
@@ -659,12 +704,6 @@ class TestSynthesisTriples:
assert derived is not None
assert derived.o.iri == self.FOC_URI
def test_synthesis_with_inline_content(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is not None
assert content.o.value == "The answer is 42"
def test_synthesis_with_document_reference(self):
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
@@ -675,23 +714,9 @@
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
def test_synthesis_document_takes_precedence(self):
"""When both document_id and answer_text are provided, document_id wins."""
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
answer_text="inline",
document_id="urn:doc:123",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is not None
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None
def test_synthesis_no_content_or_document(self):
def test_synthesis_no_document(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert content is None
assert doc is None
@ -723,31 +748,31 @@ class TestDocRagQuestionTriples:
class TestDocRagExplorationTriples:
EXP_URI = "urn:trustgraph:docrag:test/exploration"
Q_URI = "urn:trustgraph:docrag:test"
GND_URI = "urn:trustgraph:docrag:test/grounding"
def test_docrag_exploration_types(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_docrag_exploration_generated_by(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
assert gen.o.iri == self.Q_URI
def test_docrag_exploration_derived_from_grounding(self):
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived.o.iri == self.GND_URI
def test_docrag_exploration_chunk_count(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
assert cc.o.value == "7"
def test_docrag_exploration_without_chunk_ids(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
chunks = find_triples(triples, TG_SELECTED_CHUNK)
assert len(chunks) == 0
def test_docrag_exploration_with_chunk_ids(self):
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
assert len(chunks) == 3
chunk_uris = {t.o.iri for t in chunks}
@ -770,10 +795,9 @@ class TestDocRagSynthesisTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived.o.iri == self.EXP_URI
def test_docrag_synthesis_with_inline(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content.o.value == "answer"
def test_docrag_synthesis_has_answer_type(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_docrag_synthesis_with_document(self):
triples = docrag_synthesis_triples(
@ -781,5 +805,8 @@ class TestDocRagSynthesisTriples:
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc.o.iri == "urn:doc:ans"
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None
def test_docrag_synthesis_no_document(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is None

View file

@ -125,19 +125,15 @@ class TestQuery:
assert query.doc_limit == 50
@pytest.mark.asyncio
async def test_get_vector_method(self):
"""Test Query.get_vector method calls embeddings client correctly"""
# Create mock DocumentRag with embeddings client
async def test_extract_concepts(self):
"""Test Query.extract_concepts extracts concepts from query"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
# Mock the embed method to return test vectors in batch format
# New format: [[[vectors_for_text1]]] - returns first text's vector set
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -145,20 +141,62 @@ class TestQuery:
verbose=False
)
# Call get_vector
test_query = "What documents are relevant?"
result = await query.get_vector(test_query)
result = await query.extract_concepts("What is machine learning?")
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "What is machine learning?"}
)
assert result == ["machine learning", "artificial intelligence", "data patterns"]
# Verify result matches expected vectors (extracted from batch)
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
"""Test Query.extract_concepts falls back to raw query when no concepts extracted"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("What is ML?")
assert result == ["What is ML?"]
@pytest.mark.asyncio
async def test_get_vectors_method(self):
"""Test Query.get_vectors method calls embeddings client correctly"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method - returns vectors for each concept
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
concepts = ["machine learning", "data patterns"]
result = await query.get_vectors(concepts)
mock_embeddings_client.embed.assert_called_once_with(concepts)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_get_docs_method(self):
"""Test Query.get_docs method retrieves documents correctly"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
@ -170,10 +208,8 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock embeddings - one vector per concept
mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
# Mock document embeddings returns ChunkMatch objects
mock_match1 = MagicMock()
@ -184,7 +220,6 @@ class TestQuery:
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -193,16 +228,16 @@ class TestQuery:
doc_limit=15
)
# Call get_docs
test_query = "Find relevant documents"
result = await query.get_docs(test_query)
# Call get_docs with concepts list
concepts = ["test concept"]
result = await query.get_docs(concepts)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify embeddings client was called with concepts
mock_embeddings_client.embed.assert_called_once_with(concepts)
# Verify doc embeddings client was called correctly (with extracted vector)
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
vector=[0.1, 0.2, 0.3],
limit=15,
user="test_user",
collection="test_collection"
@ -218,14 +253,17 @@ class TestQuery:
@pytest.mark.asyncio
async def test_document_rag_query_method(self, mock_fetch_chunk):
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
@ -234,11 +272,9 @@ class TestQuery:
mock_match2.score = 0.8
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -247,7 +283,6 @@ class TestQuery:
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query(
query="test query",
user="test_user",
@ -255,12 +290,18 @@ class TestQuery:
doc_limit=10
)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify concept extraction was called
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "test query"}
)
# Verify doc embeddings client was called (with extracted vector)
# Verify embeddings called with extracted concepts
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
@ -270,23 +311,23 @@ class TestQuery:
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "test query"
# Documents should be fetched content, not chunk_ids
docs = call_args.kwargs["documents"]
assert "Relevant document content" in docs
assert "Another document" in docs
# Verify result
assert result == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag.query method with default parameters"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format)
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
@ -294,7 +335,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -302,10 +342,9 @@ class TestQuery:
fetch_chunk=mock_fetch_chunk
)
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used (vector extracted from batch)
# Verify default parameters were used
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
@ -318,7 +357,6 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
"""Test Query.get_docs method with verbose logging"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
@ -330,14 +368,13 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock responses (batch format)
# Mock responses - one vector per concept
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
user="test_user",
@ -346,14 +383,12 @@ class TestQuery:
doc_limit=5
)
# Call get_docs
result = await query.get_docs("verbose test")
# Call get_docs with concepts
result = await query.get_docs(["verbose test"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify result is tuple of (docs, chunk_ids) with fetched content
docs, chunk_ids = result
assert "Verbose test doc" in docs
assert "doc/c6" in chunk_ids
@ -361,12 +396,14 @@ class TestQuery:
@pytest.mark.asyncio
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag.query method with verbose logging enabled"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format)
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
@ -374,7 +411,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -383,14 +419,11 @@ class TestQuery:
verbose=True
)
# Call DocumentRag.query
result = await document_rag.query("verbose query test")
# Verify all clients were called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
# Verify prompt client was called with fetched content
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
@ -400,23 +433,20 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
"""Test Query.get_docs method when no documents are found"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk (won't be called if no chunk_ids)
async def mock_fetch(chunk_id, user):
return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list (batch format)
# Mock responses - empty results
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_doc_embeddings_client.query.return_value = []
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -424,30 +454,27 @@ class TestQuery:
verbose=False
)
# Call get_docs
result = await query.get_docs("query with no results")
result = await query.get_docs(["query with no results"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify empty result is returned (tuple of empty lists)
assert result == ([], [])
@pytest.mark.asyncio
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
"""Test DocumentRag.query method when no documents are retrieved"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses - no chunk_ids found (batch format)
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -456,10 +483,8 @@ class TestQuery:
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query("query with no matching docs")
# Verify prompt client was called with empty document list
mock_prompt_client.document_prompt.assert_called_once_with(
query="query with no matching docs",
documents=[]
@ -468,18 +493,15 @@ class TestQuery:
assert result == "No documents found response"
@pytest.mark.asyncio
async def test_get_vector_with_verbose(self):
"""Test Query.get_vector method with verbose logging"""
# Create mock DocumentRag with embeddings client
async def test_get_vectors_with_verbose(self):
"""Test Query.get_vectors method with verbose logging"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.9, 1.0, 1.1]]
mock_embeddings_client.embed.return_value = [expected_vectors]
mock_embeddings_client.embed.return_value = expected_vectors
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
user="test_user",
@ -487,40 +509,40 @@ class TestQuery:
verbose=True
)
# Call get_vector
result = await query.get_vector("verbose vector test")
result = await query.get_vectors(["verbose vector test"])
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
# Verify result (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
"""Test complete DocumentRag integration with realistic data flow"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock realistic responses (batch format)
query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors]
mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
mock_embeddings_client.embed.return_value = query_vectors
# Each concept query returns some matches
mock_matches_1 = [
MagicMock(chunk_id="doc/ml1", score=0.9),
MagicMock(chunk_id="doc/ml2", score=0.85),
]
mock_matches_2 = [
MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -529,7 +551,6 @@ class TestQuery:
verbose=False
)
# Execute full pipeline
result = await document_rag.query(
query=query_text,
user="research_user",
@ -537,26 +558,69 @@ class TestQuery:
doc_limit=25
)
# Verify complete pipeline execution (now expects list)
mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with(
vector=query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"
# Verify concept extraction
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": query_text}
)
# Verify embeddings called with concepts
mock_embeddings_client.embed.assert_called_once_with(
["machine learning", "artificial intelligence"]
)
# Verify two per-concept queries were made (25 // 2 = 12 per concept)
assert mock_doc_embeddings_client.query.call_count == 2
# Verify prompt client was called with fetched document content
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == query_text
# Verify documents were fetched from chunk_ids
# Verify documents were fetched and deduplicated
docs = call_args.kwargs["documents"]
assert "Machine learning is a subset of artificial intelligence..." in docs
assert "ML algorithms learn patterns from data to make predictions..." in docs
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
# Verify final result
assert result == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):
"""Test that get_docs deduplicates chunks across multiple concepts"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Two concepts → two vectors
mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
# Both queries return overlapping chunks
match_a = MagicMock(chunk_id="doc/c1", score=0.9)
match_b = MagicMock(chunk_id="doc/c2", score=0.8)
match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate
mock_doc_embeddings_client.query.side_effect = [
[match_a, match_b],
[match_c],
]
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
doc_limit=10
)
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
assert len(chunk_ids) == 2 # doc/c1 only counted once
assert "doc/c1" in chunk_ids
assert "doc/c2" in chunk_ids

View file

@ -19,7 +19,7 @@ class TestGraphRag:
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
@ -27,7 +27,7 @@ class TestGraphRag:
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
@ -45,7 +45,7 @@ class TestGraphRag:
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
@ -54,7 +54,7 @@ class TestGraphRag:
triples_client=mock_triples_client,
verbose=True
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
@ -73,7 +73,7 @@ class TestQuery:
"""Test Query initialization with default parameters"""
# Create mock GraphRag
mock_rag = MagicMock()
# Initialize Query with defaults
query = Query(
rag=mock_rag,
@ -81,7 +81,7 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
@ -96,7 +96,7 @@ class TestQuery:
"""Test Query initialization with custom parameters"""
# Create mock GraphRag
mock_rag = MagicMock()
# Initialize Query with custom parameters
query = Query(
rag=mock_rag,
@ -108,7 +108,7 @@ class TestQuery:
max_subgraph_size=2000,
max_path_length=3
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
@ -120,18 +120,16 @@ class TestQuery:
assert query.max_path_length == 3
@pytest.mark.asyncio
async def test_get_vector_method(self):
"""Test Query.get_vector method calls embeddings client correctly"""
# Create mock GraphRag with embeddings client
async def test_get_vectors_method(self):
"""Test Query.get_vectors method calls embeddings client correctly"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method to return test vectors (batch format)
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query
# Mock embed to return vectors for a list of concepts
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
@ -139,29 +137,22 @@ class TestQuery:
verbose=False
)
# Call get_vector
test_query = "What is the capital of France?"
result = await query.get_vector(test_query)
concepts = ["machine learning", "neural networks"]
result = await query.get_vectors(concepts)
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
mock_embeddings_client.embed.assert_called_once_with(concepts)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_get_vector_method_with_verbose(self):
"""Test Query.get_vector method with verbose output"""
# Create mock GraphRag with embeddings client
async def test_get_vectors_method_with_verbose(self):
"""Test Query.get_vectors method with verbose output"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query with verbose=True
expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
@ -169,48 +160,87 @@ class TestQuery:
verbose=True
)
# Call get_vector
test_query = "Test query for embeddings"
result = await query.get_vector(test_query)
result = await query.get_vectors(["test concept"])
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
assert result == expected_vectors
@pytest.mark.asyncio
async def test_get_entities_method(self):
"""Test Query.get_entities method retrieves entities correctly"""
# Create mock GraphRag with clients
async def test_extract_concepts(self):
"""Test Query.extract_concepts parses LLM response into concept list"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("What is machine learning?")
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "What is machine learning?"}
)
assert result == ["machine learning", "neural networks"]
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
"""Test extract_concepts falls back to raw query when LLM returns empty"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = ""
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("test query")
assert result == ["test query"]
@pytest.mark.asyncio
async def test_get_entities_method(self):
"""Test Query.get_entities extracts concepts, embeds, and retrieves entities"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# Mock the embedding and entity query responses (batch format)
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock EntityMatch objects with entity as Term-like object
# extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
# Mock entity matches
mock_entity1 = MagicMock()
mock_entity1.type = "i" # IRI type
mock_entity1.type = "i"
mock_entity1.iri = "entity1"
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.type = "i" # IRI type
mock_entity2.type = "i"
mock_entity2.iri = "entity2"
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -219,35 +249,23 @@ class TestQuery:
entity_limit=25
)
# Call get_entities
test_query = "Find related entities"
result = await query.get_entities(test_query)
entities, concepts = await query.get_entities("Find related entities")
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify embeddings client was called with the fallback concept
mock_embeddings_client.embed.assert_called_once_with(["Find related entities"])
# Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
limit=25,
user="test_user",
collection="test_collection"
)
# Verify result is list of entity strings
assert result == ["entity1", "entity2"]
# Verify result
assert entities == ["entity1", "entity2"]
assert concepts == ["Find related entities"]
@pytest.mark.asyncio
async def test_maybe_label_with_cached_label(self):
"""Test Query.maybe_label method with cached label"""
# Create mock GraphRag with label cache
mock_rag = MagicMock()
# Create mock LRUCacheWithTTL
mock_cache = MagicMock()
mock_cache.get.return_value = "Entity One Label"
mock_rag.label_cache = mock_cache
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -255,32 +273,25 @@ class TestQuery:
verbose=False
)
# Call maybe_label with cached entity
result = await query.maybe_label("entity1")
# Verify cached label is returned
assert result == "Entity One Label"
# Verify cache was checked with proper key format (user:collection:entity)
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
"""Test Query.maybe_label method with database label lookup"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple result with label
mock_triple = MagicMock()
mock_triple.o = "Human Readable Label"
mock_triples_client.query.return_value = [mock_triple]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -288,20 +299,18 @@ class TestQuery:
verbose=False
)
# Call maybe_label
result = await query.maybe_label("http://example.com/entity")
# Verify triples client was called correctly
mock_triples_client.query.assert_called_once_with(
s="http://example.com/entity",
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection"
collection="test_collection",
g=""
)
# Verify result and cache update with proper key
assert result == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@ -309,40 +318,34 @@ class TestQuery:
@pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self):
"""Test Query.maybe_label method when no label is found"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock empty result (no label found)
mock_triples_client.query.return_value = []
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
# Call maybe_label
result = await query.maybe_label("unlabeled_entity")
# Verify triples client was called
mock_triples_client.query.assert_called_once_with(
s="unlabeled_entity",
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection"
collection="test_collection",
g=""
)
# Verify result is entity itself and cache is updated
assert result == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@ -350,29 +353,25 @@ class TestQuery:
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple results for different query patterns
mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent, p=None, o=None
[mock_triple2], # s=None, p=ent, o=None
[mock_triple3], # s=None, p=None, o=ent
[mock_triple1], # s=ent
[mock_triple2], # p=ent
[mock_triple3], # o=ent
]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -380,29 +379,25 @@ class TestQuery:
verbose=False,
triple_limit=10
)
# Call follow_edges
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify all three query patterns were called
assert mock_triples_client.query_stream.call_count == 3
# Verify query_stream calls
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20
user="test_user", collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20
user="test_user", collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection", batch_size=20
user="test_user", collection="test_collection", batch_size=20, g=""
)
# Verify subgraph contains discovered triples
expected_subgraph = {
("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"),
@ -413,38 +408,30 @@ class TestQuery:
@pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self):
"""Test Query.follow_edges method with path_length=0"""
# Create mock GraphRag
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
# Call follow_edges with path_length=0
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
# Verify no queries were made
mock_triples_client.query_stream.assert_not_called()
# Verify subgraph remains empty
assert subgraph == set()
@pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self):
"""Test Query.follow_edges method respects max_subgraph_size"""
# Create mock GraphRag
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Initialize Query with small max_subgraph_size
query = Query(
rag=mock_rag,
user="test_user",
@ -452,23 +439,17 @@ class TestQuery:
verbose=False,
max_subgraph_size=2
)
# Pre-populate subgraph to exceed limit
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
# Call follow_edges
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify no queries were made due to size limit
mock_triples_client.query_stream.assert_not_called()
# Verify subgraph unchanged
assert len(subgraph) == 3
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
# Create mock Query that patches get_entities and follow_edges_batch
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
mock_rag = MagicMock()
query = Query(
@ -479,130 +460,119 @@ class TestQuery:
max_path_length=1
)
# Mock get_entities to return test entities
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
# Mock get_entities to return (entities, concepts) tuple
query.get_entities = AsyncMock(
return_value=(["entity1", "entity2"], ["concept1"])
)
# Mock follow_edges_batch to return test triples
query.follow_edges_batch = AsyncMock(return_value={
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
})
# Call get_subgraph
result = await query.get_subgraph("test query")
subgraph, entities, concepts = await query.get_subgraph("test query")
# Verify get_entities was called
query.get_entities.assert_called_once_with("test query")
# Verify follow_edges_batch was called with entities and max_path_length
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
# Verify result is list format and contains expected triples
assert isinstance(result, list)
assert len(result) == 2
assert ("entity1", "predicate1", "object1") in result
assert ("entity2", "predicate2", "object2") in result
assert isinstance(subgraph, list)
assert len(subgraph) == 2
assert ("entity1", "predicate1", "object1") in subgraph
assert ("entity2", "predicate2", "object2") in subgraph
assert entities == ["entity1", "entity2"]
assert concepts == ["concept1"]
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):
"""Test Query.get_labelgraph method converts entities to labels"""
# Create mock Query
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
collection="test_collection",
verbose=False,
max_subgraph_size=100
)
# Mock get_subgraph to return test triples
test_subgraph = [
("entity1", "predicate1", "object1"),
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
("entity3", "predicate3", "object3")
]
query.get_subgraph = AsyncMock(return_value=test_subgraph)
# Mock maybe_label to return human-readable labels
test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
query.get_subgraph = AsyncMock(
return_value=(test_subgraph, test_entities, test_concepts)
)
async def mock_maybe_label(entity):
    """Resolve a URI to its human-readable label; echo back unknown URIs."""
    known_labels = {
        "entity1": "Human Entity One",
        "predicate1": "Human Predicate One",
        "object1": "Human Object One",
        "entity3": "Human Entity Three",
        "predicate3": "Human Predicate Three",
        "object3": "Human Object Three",
    }
    try:
        return known_labels[entity]
    except KeyError:
        # Unlabelled entities fall back to their raw identifier.
        return entity
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
# Call get_labelgraph
labeled_edges, uri_map = await query.get_labelgraph("test query")
# Verify get_subgraph was called
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
query.get_subgraph.assert_called_once_with("test query")
# Verify label triples are filtered out
assert len(labeled_edges) == 2 # Label triple should be excluded
# Label triples filtered out
assert len(labeled_edges) == 2
# Verify maybe_label was called for non-label triples
expected_calls = [
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
]
# maybe_label called for non-label triples
assert query.maybe_label.call_count == 6
# Verify result contains human-readable labels
expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three")
]
assert labeled_edges == expected_edges
# Verify uri_map maps labeled edges back to original URIs
assert len(uri_map) == 2
assert entities == test_entities
assert concepts == test_concepts
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance"""
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock()
# Mock prompt client responses for two-step process
expected_response = "This is the RAG response"
test_labelgraph = [("Subject", "Predicate", "Object")]
# Compute the edge ID for the test edge
test_edge_id = edge_id("Subject", "Predicate", "Object")
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
}
test_entities = ["http://example.org/subject"]
test_concepts = ["test concept"]
# Mock edge selection response (JSONL format)
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
# Configure prompt mock to return different responses based on prompt name
# Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return edge_selection_response
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9})
elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
elif prompt_name == "kg-synthesis":
return expected_response
return ""
mock_prompt_client.prompt = mock_prompt
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -611,27 +581,20 @@ class TestQuery:
verbose=False
)
# We need to patch the Query class's get_labelgraph method
original_query_init = Query.__init__
# Patch Query.get_labelgraph to return test data
original_get_labelgraph = Query.get_labelgraph
def mock_query_init(self, *args, **kwargs):
original_query_init(self, *args, **kwargs)
async def mock_get_labelgraph(self, query_text):
return test_labelgraph, test_uri_map
return test_labelgraph, test_uri_map, test_entities, test_concepts
Query.__init__ = mock_query_init
Query.get_labelgraph = mock_get_labelgraph
# Collect provenance emitted via callback
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
try:
# Call GraphRag.query with provenance callback
response = await graph_rag.query(
query="test query",
user="test_user",
@ -641,25 +604,22 @@ class TestQuery:
explain_callback=collect_provenance
)
# Verify response text
assert response == expected_response
# Verify provenance was emitted incrementally (4 events: question, exploration, focus, synthesis)
assert len(provenance_events) == 4
# 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5
# Verify each event has triples and a URN
for triples, prov_id in provenance_events:
assert isinstance(triples, list)
assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:")
# Verify order: question, exploration, focus, synthesis
# Verify order
assert "question" in provenance_events[0][1]
assert "exploration" in provenance_events[1][1]
assert "focus" in provenance_events[2][1]
assert "synthesis" in provenance_events[3][1]
assert "grounding" in provenance_events[1][1]
assert "exploration" in provenance_events[2][1]
assert "focus" in provenance_events[3][1]
assert "synthesis" in provenance_events[4][1]
finally:
# Restore original methods
Query.__init__ = original_query_init
Query.get_labelgraph = original_get_labelgraph
Query.get_labelgraph = original_get_labelgraph

View file

@ -75,9 +75,11 @@ from .explainability import (
ExplainabilityClient,
ExplainEntity,
Question,
Grounding,
Exploration,
Focus,
Synthesis,
Reflection,
Analysis,
Conclusion,
EdgeSelection,

View file

@ -18,25 +18,28 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk"
TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
# Entity types
TG_QUESTION = TG + "Question"
TG_GROUNDING = TG + "Grounding"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
TG_ANSWER_TYPE = TG + "Answer"
TG_REFLECTION_TYPE = TG + "Reflection"
TG_THOUGHT_TYPE = TG + "Thought"
TG_OBSERVATION_TYPE = TG + "Observation"
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
TG_AGENT_QUESTION = TG + "AgentQuestion"
@ -73,12 +76,16 @@ class ExplainEntity:
if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types:
return Question.from_triples(uri, triples, types)
elif TG_GROUNDING in types:
return Grounding.from_triples(uri, triples)
elif TG_EXPLORATION in types:
return Exploration.from_triples(uri, triples)
elif TG_FOCUS in types:
return Focus.from_triples(uri, triples)
elif TG_SYNTHESIS in types:
return Synthesis.from_triples(uri, triples)
elif TG_REFLECTION_TYPE in types:
return Reflection.from_triples(uri, triples)
elif TG_ANALYSIS in types:
return Analysis.from_triples(uri, triples)
elif TG_CONCLUSION in types:
@ -124,16 +131,38 @@ class Question(ExplainEntity):
)
@dataclass
class Grounding(ExplainEntity):
    """Grounding entity - concept decomposition of the query."""
    concepts: List[str] = field(default_factory=list)

    @classmethod
    def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Grounding":
        """Build a Grounding by collecting every TG_CONCEPT object value."""
        extracted = [obj for _subj, pred, obj in triples if pred == TG_CONCEPT]
        return cls(
            uri=uri,
            entity_type="grounding",
            concepts=extracted,
        )
@dataclass
class Exploration(ExplainEntity):
"""Exploration entity - edges/chunks retrieved from the knowledge store."""
edge_count: int = 0
chunk_count: int = 0
entities: List[str] = field(default_factory=list)
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration":
edge_count = 0
chunk_count = 0
entities = []
for s, p, o in triples:
if p == TG_EDGE_COUNT:
@ -146,12 +175,15 @@ class Exploration(ExplainEntity):
chunk_count = int(o)
except (ValueError, TypeError):
pass
elif p == TG_ENTITY:
entities.append(o)
return cls(
uri=uri,
entity_type="exploration",
edge_count=edge_count,
chunk_count=chunk_count
chunk_count=chunk_count,
entities=entities
)
@ -180,94 +212,104 @@ class Focus(ExplainEntity):
@dataclass
class Synthesis(ExplainEntity):
"""Synthesis entity - the final answer."""
content: str = ""
document_uri: str = "" # Reference to librarian document
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis":
content = ""
document_uri = ""
for s, p, o in triples:
if p == TG_CONTENT:
content = o
elif p == TG_DOCUMENT:
if p == TG_DOCUMENT:
document_uri = o
return cls(
uri=uri,
entity_type="synthesis",
content=content,
document_uri=document_uri
)
@dataclass
class Reflection(ExplainEntity):
    """Reflection entity - intermediate commentary (Thought or Observation)."""
    document_uri: str = ""  # Reference to content in librarian
    reflection_type: str = ""  # "thought" or "observation"

    @classmethod
    def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Reflection":
        """Classify the reflection by its RDF type and pick up its document link."""
        rdf_types = {obj for _s, pred, obj in triples if pred == RDF_TYPE}
        if TG_THOUGHT_TYPE in rdf_types:
            kind = "thought"
        elif TG_OBSERVATION_TYPE in rdf_types:
            kind = "observation"
        else:
            kind = ""
        doc_ref = ""
        for _s, pred, obj in triples:
            if pred == TG_DOCUMENT:
                doc_ref = obj
        return cls(
            uri=uri,
            entity_type="reflection",
            document_uri=doc_ref,
            reflection_type=kind,
        )
@dataclass
class Analysis(ExplainEntity):
"""Analysis entity - one think/act/observe cycle (Agent only)."""
thought: str = ""
action: str = ""
arguments: str = "" # JSON string
observation: str = ""
thought_document_uri: str = "" # Reference to thought in librarian
observation_document_uri: str = "" # Reference to observation in librarian
thought_uri: str = "" # URI of thought sub-entity
observation_uri: str = "" # URI of observation sub-entity
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis":
thought = ""
action = ""
arguments = ""
observation = ""
thought_document_uri = ""
observation_document_uri = ""
thought_uri = ""
observation_uri = ""
for s, p, o in triples:
if p == TG_THOUGHT:
thought = o
elif p == TG_ACTION:
if p == TG_ACTION:
action = o
elif p == TG_ARGUMENTS:
arguments = o
elif p == TG_THOUGHT:
thought_uri = o
elif p == TG_OBSERVATION:
observation = o
elif p == TG_THOUGHT_DOCUMENT:
thought_document_uri = o
elif p == TG_OBSERVATION_DOCUMENT:
observation_document_uri = o
observation_uri = o
return cls(
uri=uri,
entity_type="analysis",
thought=thought,
action=action,
arguments=arguments,
observation=observation,
thought_document_uri=thought_document_uri,
observation_document_uri=observation_document_uri
thought_uri=thought_uri,
observation_uri=observation_uri
)
@dataclass
class Conclusion(ExplainEntity):
"""Conclusion entity - final answer (Agent only)."""
answer: str = ""
document_uri: str = "" # Reference to librarian document
@classmethod
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion":
answer = ""
document_uri = ""
for s, p, o in triples:
if p == TG_ANSWER:
answer = o
elif p == TG_DOCUMENT:
if p == TG_DOCUMENT:
document_uri = o
return cls(
uri=uri,
entity_type="conclusion",
answer=answer,
document_uri=document_uri
)
@ -543,42 +585,29 @@ class ExplainabilityClient:
o_label = self.resolve_label(edge.get("o", ""), user, collection)
return (s_label, p_label, o_label)
def fetch_synthesis_content(
def fetch_document_content(
self,
synthesis: Synthesis,
document_uri: str,
api: Any,
user: Optional[str] = None,
max_content: int = 10000
) -> str:
"""
Fetch the content for a Synthesis entity.
If synthesis has inline content, returns that.
If synthesis has a document_uri, fetches from librarian with retry.
Fetch content from the librarian by document URI.
Args:
synthesis: The Synthesis entity
document_uri: The document URI in the librarian
api: TrustGraph Api instance for librarian access
user: User identifier for librarian
max_content: Maximum content length to return
Returns:
The synthesis content as a string
The document content as a string
"""
# If inline content exists, use it
if synthesis.content:
if len(synthesis.content) > max_content:
return synthesis.content[:max_content] + "... [truncated]"
return synthesis.content
# Otherwise fetch from librarian
if not synthesis.document_uri:
if not document_uri:
return ""
# Extract document ID from URI (e.g., "urn:document:abc123" -> "abc123")
doc_id = synthesis.document_uri
if doc_id.startswith("urn:document:"):
doc_id = doc_id[len("urn:document:"):]
doc_id = document_uri
# Retry fetching from librarian for eventual consistency
for attempt in range(self.max_retries):
@ -603,129 +632,6 @@ class ExplainabilityClient:
return ""
def fetch_conclusion_content(
    self,
    conclusion: "Conclusion",
    api: Any,
    user: Optional[str] = None,
    max_content: int = 10000
) -> str:
    """
    Fetch the content for a Conclusion entity (Agent final answer).

    Prefers the inline answer if present; otherwise retrieves the
    referenced document from the librarian, retrying for eventual
    consistency.

    Args:
        conclusion: The Conclusion entity
        api: TrustGraph Api instance for librarian access
        user: User identifier for librarian
        max_content: Maximum content length to return

    Returns:
        The conclusion answer as a string
    """
    def clip(text):
        # Truncate overly long content, marking the cut point.
        if len(text) > max_content:
            return text[:max_content] + "... [truncated]"
        return text

    if conclusion.answer:
        return clip(conclusion.answer)

    if not conclusion.document_uri:
        return ""

    # The document URI is already a full URN usable as a librarian ID.
    doc_id = conclusion.document_uri

    # Retry fetching from librarian for eventual consistency.
    for attempt in range(self.max_retries):
        try:
            library = api.library()
            content_bytes = library.get_document_content(user=user, id=doc_id)
            try:
                return clip(content_bytes.decode('utf-8'))
            except UnicodeDecodeError:
                # Non-text payload: report its size rather than garbage.
                return f"[Binary: {len(content_bytes)} bytes]"
        except Exception as e:
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay)
                continue
            return f"[Error fetching content: {e}]"
    return ""
def fetch_analysis_content(
    self,
    analysis: "Analysis",
    api: Any,
    user: Optional[str] = None,
    max_content: int = 10000
) -> None:
    """
    Fetch thought and observation content for an Analysis entity.

    If the analysis already carries inline content, it is kept.
    Otherwise the content is fetched from the librarian by document
    URI, retrying for eventual consistency.

    Modifies the analysis object in place.

    Args:
        analysis: The Analysis entity (modified in place)
        api: TrustGraph Api instance for librarian access
        user: User identifier for librarian
        max_content: Maximum content length to return
    """
    # Fetch thought if inline content is absent but a document link exists.
    if not analysis.thought and analysis.thought_document_uri:
        analysis.thought = self._fetch_librarian_text(
            api, analysis.thought_document_uri, "thought", user, max_content
        )
    # Same for the observation.
    if not analysis.observation and analysis.observation_document_uri:
        analysis.observation = self._fetch_librarian_text(
            api, analysis.observation_document_uri, "observation", user, max_content
        )

def _fetch_librarian_text(
    self,
    api: Any,
    doc_id: str,
    label: str,
    user: Optional[str],
    max_content: int
) -> str:
    """
    Fetch one document's text from the librarian with retries.

    Returns the decoded (possibly truncated) text, a "[Binary: ...]"
    marker for non-UTF-8 content, or an "[Error fetching ...]" marker
    after the final failed attempt.
    """
    for attempt in range(self.max_retries):
        try:
            library = api.library()
            content_bytes = library.get_document_content(user=user, id=doc_id)
            try:
                content = content_bytes.decode('utf-8')
            except UnicodeDecodeError:
                # Non-text payload: report its size rather than garbage.
                return f"[Binary: {len(content_bytes)} bytes]"
            if len(content) > max_content:
                return content[:max_content] + "... [truncated]"
            return content
        except Exception as e:
            # Retry for eventual consistency; surface error on last attempt.
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay)
                continue
            return f"[Error fetching {label}: {e}]"
    return ""
def fetch_graphrag_trace(
self,
@ -739,7 +645,7 @@ class ExplainabilityClient:
"""
Fetch the complete GraphRAG trace starting from a question URI.
Follows the provenance chain: Question -> Exploration -> Focus -> Synthesis
Follows the provenance chain: Question -> Grounding -> Exploration -> Focus -> Synthesis
Args:
question_uri: The question entity URI
@ -750,13 +656,14 @@ class ExplainabilityClient:
max_content: Maximum content length for synthesis
Returns:
Dict with question, exploration, focus, synthesis entities
Dict with question, grounding, exploration, focus, synthesis entities
"""
if graph is None:
graph = "urn:graph:retrieval"
trace = {
"question": None,
"grounding": None,
"exploration": None,
"focus": None,
"synthesis": None,
@ -768,8 +675,8 @@ class ExplainabilityClient:
return trace
trace["question"] = question
# Find exploration: ?exploration prov:wasGeneratedBy question_uri
exploration_triples = self.flow.triples_query(
# Find grounding: ?grounding prov:wasGeneratedBy question_uri
grounding_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY,
o=question_uri,
g=graph,
@ -778,6 +685,30 @@ class ExplainabilityClient:
limit=10
)
if grounding_triples:
grounding_uris = [
extract_term_value(t.get("s", {}))
for t in grounding_triples
]
for gnd_uri in grounding_uris:
grounding = self.fetch_entity(gnd_uri, graph, user, collection)
if isinstance(grounding, Grounding):
trace["grounding"] = grounding
break
if not trace["grounding"]:
return trace
# Find exploration: ?exploration prov:wasDerivedFrom grounding_uri
exploration_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=trace["grounding"].uri,
g=graph,
user=user,
collection=collection,
limit=10
)
if exploration_triples:
exploration_uris = [
extract_term_value(t.get("s", {}))
@ -834,11 +765,6 @@ class ExplainabilityClient:
for synth_uri in synthesis_uris:
synthesis = self.fetch_entity(synth_uri, graph, user, collection)
if isinstance(synthesis, Synthesis):
# Fetch content if needed
if api and not synthesis.content and synthesis.document_uri:
synthesis.content = self.fetch_synthesis_content(
synthesis, api, user, max_content
)
trace["synthesis"] = synthesis
break
@ -928,11 +854,6 @@ class ExplainabilityClient:
for synth_uri in synthesis_uris:
synthesis = self.fetch_entity(synth_uri, graph, user, collection)
if isinstance(synthesis, Synthesis):
# Fetch content if needed
if api and not synthesis.content and synthesis.document_uri:
synthesis.content = self.fetch_synthesis_content(
synthesis, api, user, max_content
)
trace["synthesis"] = synthesis
break
@ -978,20 +899,43 @@ class ExplainabilityClient:
return trace
trace["question"] = question
# Follow the chain of wasDerivedFrom
# Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after
current_uri = session_uri
is_first = True
max_iterations = 50 # Safety limit
for _ in range(max_iterations):
# Find entity derived from current
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
# First hop uses wasGeneratedBy (entity←activity),
# subsequent hops use wasDerivedFrom (entity←entity)
if is_first:
derived_triples = self.flow.triples_query(
p=PROV_WAS_GENERATED_BY,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
# Fall back to wasDerivedFrom for backwards compatibility
if not derived_triples:
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
is_first = False
else:
derived_triples = self.flow.triples_query(
p=PROV_WAS_DERIVED_FROM,
o=current_uri,
g=graph,
user=user,
collection=collection,
limit=10
)
if not derived_triples:
break
@ -1003,19 +947,9 @@ class ExplainabilityClient:
entity = self.fetch_entity(derived_uri, graph, user, collection)
if isinstance(entity, Analysis):
# Fetch thought/observation content from librarian if needed
if api:
self.fetch_analysis_content(
entity, api, user=user, max_content=max_content
)
trace["iterations"].append(entity)
current_uri = derived_uri
elif isinstance(entity, Conclusion):
# Fetch answer content from librarian if needed
if api and not entity.answer and entity.document_uri:
entity.answer = self.fetch_conclusion_content(
entity, api, user=user, max_content=max_content
)
trace["conclusion"] = entity
break
else:

View file

@ -1,6 +1,6 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL, TRIPLE
from .. knowledge import Uri, Literal
@ -22,9 +22,11 @@ def to_value(x):
def from_value(x):
"""Convert Uri, Literal, or string to schema Term."""
"""Convert Uri, Literal, string, or Term to schema Term."""
if x is None:
return None
if isinstance(x, Term):
return x
if isinstance(x, Uri):
return Term(type=IRI, iri=str(x))
elif isinstance(x, Literal):
@ -41,7 +43,7 @@ def from_value(x):
class TriplesClient(RequestResponse):
async def query(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default",
timeout=30):
timeout=30, g=None):
resp = await self.request(
TriplesQueryRequest(
@ -51,6 +53,7 @@ class TriplesClient(RequestResponse):
limit = limit,
user = user,
collection = collection,
g = g,
),
timeout=timeout
)
@ -68,7 +71,7 @@ class TriplesClient(RequestResponse):
async def query_stream(self, s=None, p=None, o=None, limit=20,
user="trustgraph", collection="default",
batch_size=20, timeout=30,
batch_callback=None):
batch_callback=None, g=None):
"""
Streaming triple query - calls callback for each batch as it arrives.
@ -80,6 +83,8 @@ class TriplesClient(RequestResponse):
batch_size: Triples per batch
timeout: Request timeout in seconds
batch_callback: Async callback(batch, is_final) called for each batch
g: Graph filter. ""=default graph only, None=all graphs,
or a specific graph IRI.
Returns:
List[Triple]: All triples (flattened) if no callback provided
@ -112,6 +117,7 @@ class TriplesClient(RequestResponse):
collection=collection,
streaming=True,
batch_size=batch_size,
g=g,
),
timeout=timeout,
recipient=recipient,

View file

@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2)),
edge_limit=int(data.get("edge-limit", 25)),
streaming=data.get("streaming", False)
)
@ -96,6 +97,7 @@ class GraphRagRequestTranslator(MessageTranslator):
"triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length,
"edge-limit": obj.edge_limit,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -42,15 +42,19 @@ from . uris import (
agent_uri,
# Query-time provenance URIs (GraphRAG)
question_uri,
grounding_uri,
exploration_uri,
focus_uri,
synthesis_uri,
# Agent provenance URIs
agent_session_uri,
agent_iteration_uri,
agent_thought_uri,
agent_observation_uri,
agent_final_uri,
# Document RAG provenance URIs
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_synthesis_uri,
)
@ -74,18 +78,19 @@ from . namespaces import (
# Extraction provenance entity types
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
# Unifying types
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
# Question subtypes (to distinguish retrieval mechanism)
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
# Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
# Agent document references
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
# Document reference predicate
TG_DOCUMENT,
# Named graphs
@ -99,6 +104,7 @@ from . triples import (
subgraph_provenance_triples,
# Query-time provenance triple builders (GraphRAG)
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
@ -139,15 +145,19 @@ __all__ = [
"agent_uri",
# Query-time provenance URIs
"question_uri",
"grounding_uri",
"exploration_uri",
"focus_uri",
"synthesis_uri",
# Agent provenance URIs
"agent_session_uri",
"agent_iteration_uri",
"agent_thought_uri",
"agent_observation_uri",
"agent_final_uri",
# Document RAG provenance URIs
"docrag_question_uri",
"docrag_grounding_uri",
"docrag_exploration_uri",
"docrag_synthesis_uri",
# Namespaces
@ -164,18 +174,19 @@ __all__ = [
# Extraction provenance entity types
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT",
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
# Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Explainability entity types
"TG_QUESTION", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_ANALYSIS", "TG_CONCLUSION",
# Unifying types
"TG_ANSWER_TYPE", "TG_REFLECTION_TYPE", "TG_THOUGHT_TYPE", "TG_OBSERVATION_TYPE",
# Question subtypes
"TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION",
# Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER",
# Agent document references
"TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT",
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
# Document reference predicate
"TG_DOCUMENT",
# Named graphs
@ -186,6 +197,7 @@ __all__ = [
"subgraph_provenance_triples",
# Query-time provenance triple builders (GraphRAG)
"question_triples",
"grounding_triples",
"exploration_triples",
"focus_triples",
"synthesis_triples",

View file

@ -15,10 +15,11 @@ from .. schema import Triple, Term, IRI, LITERAL
from . namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_AGENT_QUESTION,
)
@ -73,12 +74,13 @@ def agent_session_triples(
def agent_iteration_triples(
iteration_uri: str,
parent_uri: str,
thought: str = "",
question_uri: Optional[str] = None,
previous_uri: Optional[str] = None,
action: str = "",
arguments: Dict[str, Any] = None,
observation: str = "",
thought_uri: Optional[str] = None,
thought_document_id: Optional[str] = None,
observation_uri: Optional[str] = None,
observation_document_id: Optional[str] = None,
) -> List[Triple]:
"""
@ -86,19 +88,22 @@ def agent_iteration_triples(
Creates:
- Entity declaration with tg:Analysis type
- wasDerivedFrom link to parent (previous iteration or session)
- Thought, action, arguments, and observation data
- Document references for thought/observation when stored in librarian
- wasGeneratedBy link to question (if first iteration)
- wasDerivedFrom link to previous iteration (if not first)
- Action and arguments metadata
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
- Observation sub-entity (tg:Reflection, tg:Observation) with librarian document
Args:
iteration_uri: URI of this iteration (from agent_iteration_uri)
parent_uri: URI of the parent (previous iteration or session)
thought: The agent's reasoning/thought (used if thought_document_id not provided)
question_uri: URI of the question activity (for first iteration)
previous_uri: URI of the previous iteration (for subsequent iterations)
action: The tool/action name
arguments: Arguments passed to the tool (will be JSON-encoded)
observation: The result/observation from the tool (used if observation_document_id not provided)
thought_document_id: Optional document URI for thought in librarian (preferred)
observation_document_id: Optional document URI for observation in librarian (preferred)
thought_uri: URI for the thought sub-entity
thought_document_id: Document URI for thought in librarian
observation_uri: URI for the observation sub-entity
observation_document_id: Document URI for observation in librarian
Returns:
List of Triple objects
@ -110,45 +115,70 @@ def agent_iteration_triples(
_triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)),
_triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")),
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
_triple(iteration_uri, TG_ACTION, _literal(action)),
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
]
# Thought: use document reference or inline
if thought_document_id:
triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id)))
elif thought:
triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought)))
if question_uri:
triples.append(
_triple(iteration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri))
)
elif previous_uri:
triples.append(
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri))
)
# Observation: use document reference or inline
if observation_document_id:
triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id)))
elif observation:
triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation)))
# Thought sub-entity
if thought_uri:
triples.extend([
_triple(iteration_uri, TG_THOUGHT, _iri(thought_uri)),
_triple(thought_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)),
_triple(thought_uri, RDF_TYPE, _iri(TG_THOUGHT_TYPE)),
_triple(thought_uri, RDFS_LABEL, _literal("Thought")),
_triple(thought_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)),
])
if thought_document_id:
triples.append(
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
)
# Observation sub-entity
if observation_uri:
triples.extend([
_triple(iteration_uri, TG_OBSERVATION, _iri(observation_uri)),
_triple(observation_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)),
_triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)),
_triple(observation_uri, RDFS_LABEL, _literal("Observation")),
_triple(observation_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)),
])
if observation_document_id:
triples.append(
_triple(observation_uri, TG_DOCUMENT, _iri(observation_document_id))
)
return triples
def agent_final_triples(
final_uri: str,
parent_uri: str,
answer: str = "",
question_uri: Optional[str] = None,
previous_uri: Optional[str] = None,
document_id: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an agent final answer (Conclusion).
Creates:
- Entity declaration with tg:Conclusion type
- wasDerivedFrom link to parent (last iteration or session)
- Either document reference (if document_id provided) or inline answer
- Entity declaration with tg:Conclusion and tg:Answer types
- wasGeneratedBy link to question (if no iterations)
- wasDerivedFrom link to last iteration (if iterations exist)
- Document reference to librarian
Args:
final_uri: URI of the final answer (from agent_final_uri)
parent_uri: URI of the parent (last iteration or session if no iterations)
answer: The final answer text (used if document_id not provided)
document_id: Optional document URI in librarian (preferred)
question_uri: URI of the question activity (if no iterations)
previous_uri: URI of the last iteration (if iterations exist)
document_id: Librarian document ID for the answer content
Returns:
List of Triple objects
@ -156,15 +186,20 @@ def agent_final_triples(
triples = [
_triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)),
_triple(final_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
_triple(final_uri, RDFS_LABEL, _literal("Conclusion")),
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
]
if question_uri:
triples.append(
_triple(final_uri, PROV_WAS_GENERATED_BY, _iri(question_uri))
)
elif previous_uri:
triples.append(
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri))
)
if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
elif answer:
# Fallback: store inline answer
triples.append(_triple(final_uri, TG_ANSWER, _literal(answer)))
return triples

View file

@ -60,11 +60,12 @@ TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength"
# Query-time provenance predicates (GraphRAG)
TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document" # Reference to document in librarian
# Query-time provenance predicates (DocumentRAG)
@ -79,27 +80,29 @@ TG_SUBGRAPH_TYPE = TG + "Subgraph"
# Explainability entity types (shared)
TG_QUESTION = TG + "Question"
TG_GROUNDING = TG + "Grounding"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
# Unifying types for answer and intermediate commentary
TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion)
TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation)
TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning
TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result
# Question subtypes (to distinguish retrieval mechanism)
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
TG_AGENT_QUESTION = TG + "AgentQuestion"
# Agent provenance predicates
TG_THOUGHT = TG + "thought"
TG_THOUGHT = TG + "thought" # Links iteration to thought sub-entity
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
# Agent document references (for librarian storage)
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity
# Named graph URIs for RDF datasets
# These separate different types of data while keeping them in the same collection

View file

@ -20,12 +20,15 @@ from . namespaces import (
# Extraction provenance entity types
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
# Unifying types
TG_ANSWER_TYPE,
# Question subtypes
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
)
@ -347,35 +350,78 @@ def question_triples(
]
def grounding_triples(
    grounding_uri: str,
    question_uri: str,
    concepts: List[str],
) -> List[Triple]:
    """
    Build triples for a grounding entity (concept decomposition of query).

    Creates:
    - Entity declaration for grounding (prov:Entity, tg:Grounding)
    - wasGeneratedBy link to the parent question
    - One tg:concept literal per extracted concept

    Args:
        grounding_uri: URI of the grounding entity (from grounding_uri)
        question_uri: URI of the parent question
        concepts: List of concept strings extracted from the query

    Returns:
        List of Triple objects
    """
    # Fixed entity-declaration triples first, then one triple per concept.
    declarations = [
        _triple(grounding_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(grounding_uri, RDF_TYPE, _iri(TG_GROUNDING)),
        _triple(grounding_uri, RDFS_LABEL, _literal("Grounding")),
        _triple(grounding_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
    ]
    concept_triples = [
        _triple(grounding_uri, TG_CONCEPT, _literal(concept))
        for concept in concepts
    ]
    return declarations + concept_triples
def exploration_triples(
    exploration_uri: str,
    question_uri: str,
    grounding_uri: str,
    edge_count: int,
    entities: Optional[List[str]] = None,
) -> List[Triple]:
    """
    Build triples for an exploration entity (all edges retrieved from subgraph).

    Creates:
    - Entity declaration for exploration (prov:Entity, tg:Exploration)
    - wasGeneratedBy link to the parent question
    - wasDerivedFrom link to the grounding entity
    - Edge count metadata
    - One tg:entity IRI per seed entity (when provided)

    Args:
        exploration_uri: URI of the exploration entity (from exploration_uri)
        question_uri: URI of the parent question
        grounding_uri: URI of the parent grounding entity
        edge_count: Number of edges retrieved
        entities: Optional list of seed entity URIs

    Returns:
        List of Triple objects
    """
    # NOTE(review): the diff text carried a leftover pre-change `return [`
    # line above `triples = [`; only the post-change list construction is kept.
    triples = [
        _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)),
        _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")),
        _triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
        _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)),
        _triple(exploration_uri, TG_EDGE_COUNT, _literal(edge_count)),
    ]
    if entities:
        for entity in entities:
            triples.append(_triple(exploration_uri, TG_ENTITY, _iri(entity)))
    return triples
def _quoted_triple(s: str, p: str, o: str) -> Term:
"""Create a quoted triple term (RDF-star) from string values."""
@ -454,22 +500,20 @@ def focus_triples(
def synthesis_triples(
synthesis_uri: str,
focus_uri: str,
answer_text: str = "",
document_id: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a synthesis entity (final answer text).
Build triples for a synthesis entity (final answer).
Creates:
- Entity declaration for synthesis
- Entity declaration for synthesis with tg:Answer type
- wasDerivedFrom link to focus
- Either document reference (if document_id provided) or inline content
- Document reference to librarian
Args:
synthesis_uri: URI of the synthesis entity (from synthesis_uri)
focus_uri: URI of the parent focus entity
answer_text: The synthesized answer text (used if no document_id)
document_id: Optional librarian document ID (preferred over inline content)
document_id: Librarian document ID for the answer content
Returns:
List of Triple objects
@ -477,16 +521,13 @@ def synthesis_triples(
triples = [
_triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
_triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")),
_triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(focus_uri)),
]
if document_id:
# Store reference to document in librarian (as IRI)
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
elif answer_text:
# Fallback: store inline content
triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text)))
return triples
@ -533,7 +574,7 @@ def docrag_question_triples(
def docrag_exploration_triples(
exploration_uri: str,
question_uri: str,
grounding_uri: str,
chunk_count: int,
chunk_ids: Optional[List[str]] = None,
) -> List[Triple]:
@ -542,12 +583,12 @@ def docrag_exploration_triples(
Creates:
- Entity declaration with tg:Exploration type
- wasGeneratedBy link to question
- wasDerivedFrom link to grounding
- Chunk count and optional chunk references
Args:
exploration_uri: URI of the exploration entity
question_uri: URI of the parent question
grounding_uri: URI of the parent grounding entity
chunk_count: Number of chunks retrieved
chunk_ids: Optional list of chunk URIs/IDs
@ -558,7 +599,7 @@ def docrag_exploration_triples(
_triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)),
_triple(exploration_uri, RDFS_LABEL, _literal("Exploration")),
_triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
_triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)),
_triple(exploration_uri, TG_CHUNK_COUNT, _literal(chunk_count)),
]
@ -573,22 +614,20 @@ def docrag_exploration_triples(
def docrag_synthesis_triples(
synthesis_uri: str,
exploration_uri: str,
answer_text: str = "",
document_id: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a document RAG synthesis entity (final answer).
Creates:
- Entity declaration with tg:Synthesis type
- Entity declaration with tg:Synthesis and tg:Answer types
- wasDerivedFrom link to exploration (skips focus step)
- Either document reference or inline content
- Document reference to librarian
Args:
synthesis_uri: URI of the synthesis entity
exploration_uri: URI of the parent exploration entity
answer_text: The synthesized answer text (used if no document_id)
document_id: Optional librarian document ID (preferred over inline content)
document_id: Librarian document ID for the answer content
Returns:
List of Triple objects
@ -596,13 +635,12 @@ def docrag_synthesis_triples(
triples = [
_triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)),
_triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
_triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")),
_triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
if document_id:
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
elif answer_text:
triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text)))
return triples

View file

@ -68,6 +68,7 @@ def agent_uri(component_name: str) -> str:
#
# Terminology:
# Question - What was asked, the anchor for everything
# Grounding - Decomposing the question into concepts
# Exploration - Casting wide, what do we know about this space
# Focus - Closing down, what's actually relevant here
# Synthesis - Weaving the relevant pieces into an answer
@ -87,6 +88,19 @@ def question_uri(session_id: str = None) -> str:
return f"urn:trustgraph:question:{session_id}"
def grounding_uri(session_id: str) -> str:
    """
    Generate URI for a grounding entity (concept decomposition of query).

    Args:
        session_id: The session UUID (same as question_uri).

    Returns:
        URN in format: urn:trustgraph:prov:grounding:{uuid}
    """
    return "urn:trustgraph:prov:grounding:" + session_id
def exploration_uri(session_id: str) -> str:
"""
Generate URI for an exploration entity (edges retrieved from subgraph).
@ -173,6 +187,34 @@ def agent_iteration_uri(session_id: str, iteration_num: int) -> str:
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}"
def agent_thought_uri(session_id: str, iteration_num: int) -> str:
    """
    Generate URI for an agent thought sub-entity.

    Args:
        session_id: The session UUID.
        iteration_num: 1-based iteration number.

    Returns:
        URN in format: urn:trustgraph:agent:{uuid}/i{num}/thought
    """
    return "urn:trustgraph:agent:{0}/i{1}/thought".format(
        session_id, iteration_num
    )
def agent_observation_uri(session_id: str, iteration_num: int) -> str:
    """
    Generate URI for an agent observation sub-entity.

    Args:
        session_id: The session UUID.
        iteration_num: 1-based iteration number.

    Returns:
        URN in format: urn:trustgraph:agent:{uuid}/i{num}/observation
    """
    return "urn:trustgraph:agent:{0}/i{1}/observation".format(
        session_id, iteration_num
    )
def agent_final_uri(session_id: str) -> str:
"""
Generate URI for an agent final answer.
@ -205,6 +247,19 @@ def docrag_question_uri(session_id: str = None) -> str:
return f"urn:trustgraph:docrag:{session_id}"
def docrag_grounding_uri(session_id: str) -> str:
    """
    Generate URI for a document RAG grounding entity (concept decomposition).

    Args:
        session_id: The session UUID.

    Returns:
        URN in format: urn:trustgraph:docrag:{uuid}/grounding
    """
    return "urn:trustgraph:docrag:" + session_id + "/grounding"
def docrag_exploration_uri(session_id: str) -> str:
"""
Generate URI for a document RAG exploration entity (chunks retrieved).

View file

@ -25,6 +25,8 @@ from . namespaces import (
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_CONCEPT, TG_ENTITY, TG_GROUNDING,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
)
@ -80,6 +82,11 @@ TG_CLASS_LABELS = [
_label_triple(TG_PAGE_TYPE, "Page"),
_label_triple(TG_CHUNK_TYPE, "Chunk"),
_label_triple(TG_SUBGRAPH_TYPE, "Subgraph"),
_label_triple(TG_GROUNDING, "Grounding"),
_label_triple(TG_ANSWER_TYPE, "Answer"),
_label_triple(TG_REFLECTION_TYPE, "Reflection"),
_label_triple(TG_THOUGHT_TYPE, "Thought"),
_label_triple(TG_OBSERVATION_TYPE, "Observation"),
]
# TrustGraph predicate labels
@ -100,6 +107,8 @@ TG_PREDICATE_LABELS = [
_label_triple(TG_SOURCE_TEXT, "source text"),
_label_triple(TG_SOURCE_CHAR_OFFSET, "source character offset"),
_label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"),
_label_triple(TG_CONCEPT, "concept"),
_label_triple(TG_ENTITY, "entity"),
]

View file

@ -15,6 +15,7 @@ class GraphRagQuery:
triple_limit: int = 0
max_subgraph_size: int = 0
max_path_length: int = 0
edge_limit: int = 0
streaming: bool = False
@dataclass

View file

@ -202,16 +202,17 @@ def question_explainable(
elif isinstance(entity, Analysis):
print(f"\n [iteration] {prov_id}", file=sys.stderr)
if entity.thought:
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
print(f" Thought: {thought_short}", file=sys.stderr)
if entity.action:
print(f" Action: {entity.action}", file=sys.stderr)
if entity.thought_uri:
print(f" Thought: {entity.thought_uri}", file=sys.stderr)
if entity.observation_uri:
print(f" Observation: {entity.observation_uri}", file=sys.stderr)
elif isinstance(entity, Conclusion):
print(f"\n [conclusion] {prov_id}", file=sys.stderr)
if entity.answer:
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
if entity.document_uri:
print(f" Document: {entity.document_uri}", file=sys.stderr)
else:
if debug:

View file

@ -11,6 +11,7 @@ from trustgraph.api import (
RAGChunk,
ProvenanceEvent,
Question,
Grounding,
Exploration,
Synthesis,
)
@ -68,6 +69,12 @@ def question_explainable(
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Grounding):
print(f"\n [grounding] {prov_id}", file=sys.stderr)
if entity.concepts:
for concept in entity.concepts:
print(f" Concept: {concept}", file=sys.stderr)
elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.chunk_count:
@ -75,8 +82,8 @@ def question_explainable(
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
if entity.document_uri:
print(f" Document: {entity.document_uri}", file=sys.stderr)
else:
if debug:

View file

@ -14,6 +14,7 @@ from trustgraph.api import (
RAGChunk,
ProvenanceEvent,
Question,
Grounding,
Exploration,
Focus,
Synthesis,
@ -31,11 +32,13 @@ default_max_path_length = 2
# Provenance predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONTENT = TG + "content"
TG_DOCUMENT = TG + "document"
TG_CONTAINS = TG + "contains"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
@ -47,6 +50,8 @@ def _get_event_type(prov_id):
"""Extract event type from provenance_id"""
if "question" in prov_id:
return "question"
elif "grounding" in prov_id:
return "grounding"
elif "exploration" in prov_id:
return "exploration"
elif "focus" in prov_id:
@ -68,8 +73,16 @@ def _format_provenance_details(event_type, triples):
elif p == PROV_STARTED_AT_TIME:
lines.append(f" Time: {o}")
elif event_type == "grounding":
# Show extracted concepts
concepts = [o for s, p, o in triples if p == TG_CONCEPT]
if concepts:
lines.append(f" Concepts: {len(concepts)}")
for concept in concepts:
lines.append(f" - {concept}")
elif event_type == "exploration":
# Show edge count
# Show edge count (seed entities resolved separately with labels)
for s, p, o in triples:
if p == TG_EDGE_COUNT:
lines.append(f" Edges explored: {o}")
@ -85,10 +98,10 @@ def _format_provenance_details(event_type, triples):
lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
elif event_type == "synthesis":
# Show content length (not full content - it's already streamed)
# Show document reference (content already streamed)
for s, p, o in triples:
if p == TG_CONTENT:
lines.append(f" Synthesis length: {len(o)} chars")
if p == TG_DOCUMENT:
lines.append(f" Document: {o}")
return lines
@ -542,6 +555,18 @@ async def _question_explainable(
for line in details:
print(line, file=sys.stderr)
# For exploration events, resolve entity labels
if event_type == "exploration":
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
if entity_iris:
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
for iri in entity_iris:
label = await _query_label(
ws_url, flow_id, iri, user, collection,
label_cache, debug=debug
)
print(f" - {label}", file=sys.stderr)
# For focus events, query each edge selection for details
if event_type == "focus":
for s, p, o in triples:
@ -660,10 +685,22 @@ def _question_explainable_api(
if entity.timestamp:
print(f" Time: {entity.timestamp}", file=sys.stderr)
elif isinstance(entity, Grounding):
print(f"\n [grounding] {prov_id}", file=sys.stderr)
if entity.concepts:
print(f" Concepts: {len(entity.concepts)}", file=sys.stderr)
for concept in entity.concepts:
print(f" - {concept}", file=sys.stderr)
elif isinstance(entity, Exploration):
print(f"\n [exploration] {prov_id}", file=sys.stderr)
if entity.edge_count:
print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
if entity.entities:
print(f" Seed entities: {len(entity.entities)}", file=sys.stderr)
for ent in entity.entities:
label = explain_client.resolve_label(ent, user, collection)
print(f" - {label}", file=sys.stderr)
elif isinstance(entity, Focus):
print(f"\n [focus] {prov_id}", file=sys.stderr)
@ -691,8 +728,8 @@ def _question_explainable_api(
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
if entity.content:
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
if entity.document_uri:
print(f" Document: {entity.document_uri}", file=sys.stderr)
else:
if debug:
@ -848,7 +885,7 @@ def main():
parser.add_argument(
'-x', '--explainable',
action='store_true',
help='Show provenance events: Question, Exploration, Focus, Synthesis (implies streaming)'
help='Show provenance events: Question, Grounding, Exploration, Focus, Synthesis (implies streaming)'
)
parser.add_argument(

View file

@ -31,6 +31,8 @@ from ... schema import librarian_request_queue, librarian_response_queue
from trustgraph.provenance import (
agent_session_uri,
agent_iteration_uri,
agent_thought_uri,
agent_observation_uri,
agent_final_uri,
agent_session_triples,
agent_iteration_triples,
@ -624,11 +626,13 @@ class Processor(AgentService):
# Emit final answer provenance triples
final_uri = agent_final_uri(session_id)
# Parent is last iteration, or session if no iterations
# No iterations: link to question; otherwise: link to last iteration
if iteration_num > 1:
parent_uri = agent_iteration_uri(session_id, iteration_num - 1)
final_question_uri = None
final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
else:
parent_uri = session_uri
final_question_uri = session_uri
final_previous_uri = None
# Save answer to librarian
answer_doc_id = None
@ -648,8 +652,9 @@ class Processor(AgentService):
final_triples = set_graph(
agent_final_triples(
final_uri, parent_uri,
answer="" if answer_doc_id else f,
final_uri,
question_uri=final_question_uri,
previous_uri=final_previous_uri,
document_id=answer_doc_id,
),
GRAPH_RETRIEVAL
@ -707,11 +712,13 @@ class Processor(AgentService):
# Emit iteration provenance triples
iteration_uri = agent_iteration_uri(session_id, iteration_num)
# Parent is previous iteration, or session if this is first iteration
# First iteration links to question, subsequent to previous
if iteration_num > 1:
parent_uri = agent_iteration_uri(session_id, iteration_num - 1)
iter_question_uri = None
iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
else:
parent_uri = session_uri
iter_question_uri = session_uri
iter_previous_uri = None
# Save thought to librarian
thought_doc_id = None
@ -745,15 +752,19 @@ class Processor(AgentService):
logger.warning(f"Failed to save observation to librarian: {e}")
observation_doc_id = None
thought_entity_uri = agent_thought_uri(session_id, iteration_num)
observation_entity_uri = agent_observation_uri(session_id, iteration_num)
iter_triples = set_graph(
agent_iteration_triples(
iteration_uri,
parent_uri,
thought="" if thought_doc_id else act.thought,
question_uri=iter_question_uri,
previous_uri=iter_previous_uri,
action=act.name,
arguments=act.arguments,
observation="" if observation_doc_id else act.observation,
thought_uri=thought_entity_uri if thought_doc_id else None,
thought_document_id=thought_doc_id,
observation_uri=observation_entity_uri if observation_doc_id else None,
observation_document_id=observation_doc_id,
),
GRAPH_RETRIEVAL

View file

@ -7,9 +7,11 @@ from datetime import datetime
# Provenance imports
from trustgraph.provenance import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_synthesis_uri,
docrag_question_triples,
grounding_triples,
docrag_exploration_triples,
docrag_synthesis_triples,
set_graph,
@ -33,39 +35,79 @@ class Query:
self.verbose = verbose
self.doc_limit = doc_limit
async def get_vector(self, query):
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
# Fallback to raw query if no concepts extracted
if not concepts:
concepts = [query]
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
return concepts
async def get_vectors(self, concepts):
"""Compute embeddings for a list of concepts."""
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed([query])
qembeds = await self.rag.embeddings_client.embed(concepts)
if self.verbose:
logger.debug("Embeddings computed")
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
return qembeds
async def get_docs(self, query):
async def get_docs(self, concepts):
"""
Get documents (chunks) matching the query.
Get documents (chunks) matching the extracted concepts.
Returns:
tuple: (docs, chunk_ids) where:
- docs: list of document content strings
- chunk_ids: list of chunk IDs that were successfully fetched
"""
vectors = await self.get_vector(query)
vectors = await self.get_vectors(concepts)
if self.verbose:
logger.debug("Getting chunks from embeddings store...")
# Get chunk matches from embeddings store
chunk_matches = await self.rag.doc_embeddings_client.query(
vector=vectors, limit=self.doc_limit,
user=self.user, collection=self.collection,
# Query chunk matches for each concept concurrently
per_concept_limit = max(
1, self.doc_limit // len(vectors)
)
async def query_concept(vec):
return await self.rag.doc_embeddings_client.query(
vector=vec, limit=per_concept_limit,
user=self.user, collection=self.collection,
)
results = await asyncio.gather(
*[query_concept(v) for v in vectors]
)
# Deduplicate chunk matches by chunk_id
seen = set()
chunk_matches = []
for matches in results:
for match in matches:
if match.chunk_id and match.chunk_id not in seen:
seen.add(match.chunk_id)
chunk_matches.append(match)
if self.verbose:
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
@ -133,6 +175,7 @@ class DocumentRag:
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
q_uri = docrag_question_uri(session_id)
gnd_uri = docrag_grounding_uri(session_id)
exp_uri = docrag_exploration_uri(session_id)
syn_uri = docrag_synthesis_uri(session_id)
@ -151,12 +194,23 @@ class DocumentRag:
doc_limit=doc_limit
)
docs, chunk_ids = await q.get_docs(query)
# Extract concepts from query (grounding step)
concepts = await q.extract_concepts(query)
# Emit grounding explainability after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
docs, chunk_ids = await q.get_docs(concepts)
# Emit exploration explainability after chunks retrieved
if explain_callback:
exp_triples = set_graph(
docrag_exploration_triples(exp_uri, q_uri, len(chunk_ids), chunk_ids),
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
GRAPH_RETRIEVAL
)
await explain_callback(exp_triples, exp_uri)
@ -196,9 +250,8 @@ class DocumentRag:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian if callback provided
# Save answer to librarian
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
@ -206,13 +259,11 @@ class DocumentRag:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content
synthesis_doc_id = None
# Generate triples with document reference or inline content
syn_triples = set_graph(
docrag_synthesis_triples(
syn_uri, exp_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL

View file

@ -8,20 +8,23 @@ import uuid
from collections import OrderedDict
from datetime import datetime
from ... schema import IRI, LITERAL
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
# Provenance imports
from trustgraph.provenance import (
question_uri,
grounding_uri as make_grounding_uri,
exploration_uri as make_exploration_uri,
focus_uri as make_focus_uri,
synthesis_uri as make_synthesis_uri,
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
set_graph,
GRAPH_RETRIEVAL,
GRAPH_RETRIEVAL, GRAPH_SOURCE,
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
)
# Module logger
@ -47,6 +50,8 @@ def edge_id(s, p, o):
edge_str = f"{s}|{p}|{o}"
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@ -105,42 +110,88 @@ class Query:
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
async def get_vector(self, query):
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
# Fall back to raw query if extraction returns nothing
return concepts if concepts else [query]
async def get_vectors(self, concepts):
"""Embed multiple concepts concurrently."""
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed([query])
qembeds = await self.rag.embeddings_client.embed(concepts)
if self.verbose:
logger.debug("Done.")
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
return qembeds
async def get_entities(self, query):
"""
Extract concepts from query, embed them, and retrieve matching entities.
vectors = await self.get_vector(query)
Returns:
tuple: (entities, concepts) where entities is a list of entity URI
strings and concepts is the list of concept strings extracted
from the query.
"""
concepts = await self.extract_concepts(query)
vectors = await self.get_vectors(concepts)
if self.verbose:
logger.debug("Getting entities...")
entity_matches = await self.rag.graph_embeddings_client.query(
vector=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
# Query entity matches for each concept concurrently
per_concept_limit = max(
1, self.entity_limit // len(vectors)
)
entities = [
term_to_string(e.entity)
for e in entity_matches
entity_tasks = [
self.rag.graph_embeddings_client.query(
vector=v, limit=per_concept_limit,
user=self.user, collection=self.collection,
)
for v in vectors
]
results = await asyncio.gather(*entity_tasks, return_exceptions=True)
# Deduplicate while preserving order
seen = set()
entities = []
for result in results:
if isinstance(result, Exception) or not result:
continue
for e in result:
entity = term_to_string(e.entity)
if entity not in seen:
seen.add(entity)
entities.append(entity)
if self.verbose:
logger.debug("Entities:")
for ent in entities:
logger.debug(f" {ent}")
return entities
return entities, concepts
async def maybe_label(self, e):
@ -156,6 +207,7 @@ class Query:
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
g="",
)
if len(res) == 0:
@ -177,19 +229,19 @@ class Query:
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
batch_size=20,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
batch_size=20,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection,
batch_size=20,
batch_size=20, g="",
)
])
@ -262,8 +314,16 @@ class Query:
subgraph.update(batch_result)
async def get_subgraph(self, query):
"""
Get subgraph by extracting concepts, finding entities, and traversing.
entities = await self.get_entities(query)
Returns:
tuple: (subgraph, entities, concepts) where subgraph is a list of
(s, p, o) tuples, entities is the seed entity list, and concepts
is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose:
logger.debug("Getting subgraph...")
@ -271,7 +331,7 @@ class Query:
# Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph)
return list(subgraph), entities, concepts
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
@ -286,11 +346,13 @@ class Query:
Get subgraph with labels resolved for display.
Returns:
tuple: (labeled_edges, uri_map) where:
tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
"""
subgraph = await self.get_subgraph(query)
subgraph, entities, concepts = await self.get_subgraph(query)
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
@ -338,8 +400,125 @@ class Query:
if self.verbose:
logger.debug("Done.")
return labeled_edges, uri_map
return labeled_edges, uri_map, entities, concepts
async def trace_source_documents(self, edge_uris):
"""
Trace selected edges back to their source documents via provenance.
Follows the chain: edge subgraph (via tg:contains) chunk
page document (via prov:wasDerivedFrom), all in urn:graph:source.
Args:
edge_uris: List of (s, p, o) URI string tuples
Returns:
List of unique document titles
"""
# Step 1: Find subgraphs containing these edges via tg:contains
subgraph_tasks = []
for s, p, o in edge_uris:
quoted = Term(
type=TRIPLE,
triple=SchemaTriple(
s=Term(type=IRI, iri=s),
p=Term(type=IRI, iri=p),
o=Term(type=IRI, iri=o),
)
)
subgraph_tasks.append(
self.rag.triples_client.query(
s=None, p=TG_CONTAINS, o=quoted, limit=1,
user=self.user, collection=self.collection,
g=GRAPH_SOURCE,
)
)
subgraph_results = await asyncio.gather(
*subgraph_tasks, return_exceptions=True
)
# Collect unique subgraph URIs
subgraph_uris = set()
for result in subgraph_results:
if isinstance(result, Exception) or not result:
continue
for triple in result:
subgraph_uris.add(str(triple.s))
if not subgraph_uris:
return []
# Step 2: Walk prov:wasDerivedFrom chain to find documents
# Each level: query ?entity prov:wasDerivedFrom ?parent
# Stop when we find entities typed tg:Document
current_uris = subgraph_uris
doc_uris = set()
for depth in range(4): # Max depth: subgraph → chunk → page → doc
if not current_uris:
break
derivation_tasks = [
self.rag.triples_client.query(
s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
user=self.user, collection=self.collection,
g=GRAPH_SOURCE,
)
for uri in current_uris
]
derivation_results = await asyncio.gather(
*derivation_tasks, return_exceptions=True
)
# URIs with no parent are root documents
next_uris = set()
for uri, result in zip(current_uris, derivation_results):
if isinstance(result, Exception) or not result:
doc_uris.add(uri)
continue
for triple in result:
next_uris.add(str(triple.o))
current_uris = next_uris - doc_uris
if not doc_uris:
return []
# Step 3: Get all document metadata properties
# Skip structural predicates that aren't useful context
SKIP_PREDICATES = {
PROV_WAS_DERIVED_FROM,
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
}
metadata_tasks = [
self.rag.triples_client.query(
s=uri, p=None, o=None, limit=50,
user=self.user, collection=self.collection,
)
for uri in doc_uris
]
metadata_results = await asyncio.gather(
*metadata_tasks, return_exceptions=True
)
doc_edges = []
for result in metadata_results:
if isinstance(result, Exception) or not result:
continue
for triple in result:
p = str(triple.p)
if p in SKIP_PREDICATES:
continue
doc_edges.append((
str(triple.s), p, str(triple.o)
))
return doc_edges
class GraphRag:
"""
CRITICAL SECURITY:
@ -371,7 +550,8 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, streaming = False, chunk_callback = None,
max_path_length = 2, edge_limit = 25, streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
"""
@ -399,6 +579,7 @@ class GraphRag:
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
q_uri = question_uri(session_id)
gnd_uri = make_grounding_uri(session_id)
exp_uri = make_exploration_uri(session_id)
foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id)
@ -421,12 +602,23 @@ class GraphRag:
max_path_length = max_path_length,
)
kg, uri_map = await q.get_labelgraph(query)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Emit grounding explain after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes
if explain_callback:
exp_triples = set_graph(
exploration_triples(exp_uri, q_uri, len(kg)),
exploration_triples(
exp_uri, gnd_uri, len(kg),
entities=seed_entities,
),
GRAPH_RETRIEVAL
)
await explain_callback(exp_triples, exp_uri)
@ -453,9 +645,9 @@ class GraphRag:
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
selection_response = await self.prompt_client.prompt(
"kg-edge-selection",
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
@ -463,52 +655,44 @@ class GraphRag:
)
if self.verbose:
logger.debug(f"Edge selection response: {selection_response}")
logger.debug(f"Edge scoring response: {scoring_response}")
# Parse response to get selected edge IDs and reasoning
# Response can be a string (JSONL) or a list (JSON array)
selected_ids = set()
selected_edges_with_reasoning = [] # For explain
# Parse scoring response to get edge IDs with scores
scored_edges = []
if isinstance(selection_response, list):
# JSON array response
for obj in selection_response:
if isinstance(obj, dict) and "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
elif isinstance(selection_response, str):
# JSONL string response
for line in selection_response.strip().split('\n'):
def parse_scored_edge(obj):
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
except (ValueError, TypeError):
score = 0
scored_edges.append({"id": obj["id"], "score": score})
if isinstance(scoring_response, list):
for obj in scoring_response:
parse_scored_edge(obj)
elif isinstance(scoring_response, str):
for line in scoring_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
if "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
parse_scored_edge(json.loads(line))
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge selection line: {line}")
continue
logger.warning(
f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose:
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges
selected_edges = []
@ -516,6 +700,82 @@ class GraphRag:
if eid in edge_map:
selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
if isinstance(reasoning_response, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_response}"
)
reasoning_response = ""
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}")
# Parse reasoning response and build explainability data
reasoning_map = {}
def parse_reasoning(obj):
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
selected_edges_with_reasoning = []
for eid in selected_ids:
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
})
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
@ -534,6 +794,18 @@ class GraphRag:
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
})
synthesis_variables = {
"query": query,
"knowledge": selected_edge_dicts,
}
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
@ -544,10 +816,7 @@ class GraphRag:
await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
},
variables=synthesis_variables,
streaming=True,
chunk_callback=accumulating_callback
)
@ -556,10 +825,7 @@ class GraphRag:
else:
resp = await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
}
variables=synthesis_variables,
)
if self.verbose:
@ -570,9 +836,8 @@ class GraphRag:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian if callback provided
# Save answer to librarian
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
@ -580,13 +845,11 @@ class GraphRag:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content
synthesis_doc_id = None
# Generate triples with document reference or inline content
syn_triples = set_graph(
synthesis_triples(
syn_uri, foc_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL

View file

@ -39,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
**params | {
@ -48,6 +49,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_limit": edge_limit,
}
)
@ -55,6 +57,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
@ -292,6 +295,11 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_limit:
edge_limit = v.edge_limit
else:
edge_limit = self.default_edge_limit
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
@ -322,6 +330,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
explain_callback = send_explainability,
@ -335,6 +344,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)