Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding (#697)

Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding,
consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts,
  kg-edge-scoring,
  kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring
  provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's
  pattern:
  Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first
  entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag,
  tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
This commit is contained in:
cybermaggedon 2026-03-16 12:12:13 +00:00 committed by GitHub
parent 29b4300808
commit a115ec06ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1537 additions and 1008 deletions

View file

@ -15,10 +15,11 @@ from trustgraph.provenance.agent import (
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_AGENT_QUESTION,
)
@ -110,84 +111,107 @@ class TestAgentSessionTriples:
class TestAgentIterationTriples:
ITER_URI = "urn:trustgraph:agent:test-session/i1"
PARENT_URI = "urn:trustgraph:agent:test-session"
SESSION_URI = "urn:trustgraph:agent:test-session"
PREV_URI = "urn:trustgraph:agent:test-session/i0"
def test_iteration_types(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="thinking", action="search", observation="found it",
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
def test_iteration_derived_from_parent(self):
def test_first_iteration_generated_by_question(self):
"""First iteration uses wasGeneratedBy to link to question activity."""
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
# Should NOT have wasDerivedFrom
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is None
def test_subsequent_iteration_derived_from_previous(self):
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
triples = agent_iteration_triples(
self.ITER_URI, previous_uri=self.PREV_URI,
action="search",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
assert derived.o.iri == self.PREV_URI
# Should NOT have wasGeneratedBy
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is None
def test_iteration_label_includes_action(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query",
)
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
assert label is not None
assert "graph-rag-query" in label.o.value
def test_iteration_thought_inline(self):
def test_iteration_thought_sub_entity(self):
"""Thought is a sub-entity with Reflection and Thought types."""
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
thought_doc = "urn:doc:thought-1"
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="I need to search for info",
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
thought_uri=thought_uri,
thought_document_id=thought_doc,
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought is not None
assert thought.o.value == "I need to search for info"
# Iteration links to thought sub-entity
thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought_link is not None
assert thought_link.o.iri == thought_uri
# Thought has correct types
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
# Thought was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Thought has document reference
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
assert doc is not None
assert doc.o.iri == thought_doc
def test_iteration_thought_document_preferred(self):
"""When thought_document_id is provided, inline thought is not stored."""
def test_iteration_observation_sub_entity(self):
"""Observation is a sub-entity with Reflection and Observation types."""
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
obs_doc = "urn:doc:obs-1"
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
thought="inline thought",
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
thought_document_id="urn:doc:thought-1",
observation_uri=obs_uri,
observation_document_id=obs_doc,
)
thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
assert thought_doc is not None
assert thought_doc.o.iri == "urn:doc:thought-1"
thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought_inline is None
def test_iteration_observation_inline(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="Found 3 results",
)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs is not None
assert obs.o.value == "Found 3 results"
def test_iteration_observation_document_preferred(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
action="search",
observation="inline obs",
observation_document_id="urn:doc:obs-1",
)
obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
assert obs_doc is not None
assert obs_doc.o.iri == "urn:doc:obs-1"
obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_inline is None
# Iteration links to observation sub-entity
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_link is not None
assert obs_link.o.iri == obs_uri
# Observation has correct types
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
# Observation was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Observation has document reference
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
assert doc is not None
assert doc.o.iri == obs_doc
def test_iteration_action_recorded(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query",
)
action = find_triple(triples, TG_ACTION, self.ITER_URI)
@ -197,7 +221,7 @@ class TestAgentIterationTriples:
def test_iteration_arguments_json_encoded(self):
args = {"query": "test query", "limit": 10}
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
arguments=args,
)
@ -208,7 +232,7 @@ class TestAgentIterationTriples:
def test_iteration_default_arguments_empty_dict(self):
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
@ -219,7 +243,7 @@ class TestAgentIterationTriples:
def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples."""
triples = agent_iteration_triples(
self.ITER_URI, self.PARENT_URI,
self.ITER_URI, question_uri=self.SESSION_URI,
action="noop",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
@ -228,19 +252,19 @@ class TestAgentIterationTriples:
assert obs is None
def test_iteration_chaining(self):
"""Second iteration derives from first iteration, not session."""
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2"
triples1 = agent_iteration_triples(
iter1_uri, self.PARENT_URI, action="step1",
iter1_uri, question_uri=self.SESSION_URI, action="step1",
)
triples2 = agent_iteration_triples(
iter2_uri, iter1_uri, action="step2",
iter2_uri, previous_uri=iter1_uri, action="step2",
)
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
assert derived1.o.iri == self.PARENT_URI
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
assert gen1.o.iri == self.SESSION_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri
@ -253,42 +277,50 @@ class TestAgentIterationTriples:
class TestAgentFinalTriples:
FINAL_URI = "urn:trustgraph:agent:test-session/final"
PARENT_URI = "urn:trustgraph:agent:test-session/i3"
PREV_URI = "urn:trustgraph:agent:test-session/i3"
SESSION_URI = "urn:trustgraph:agent:test-session"
def test_final_types(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
self.FINAL_URI, previous_uri=self.PREV_URI,
)
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
def test_final_derived_from_parent(self):
def test_final_derived_from_previous(self):
"""Conclusion with iterations uses wasDerivedFrom."""
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
self.FINAL_URI, previous_uri=self.PREV_URI,
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
assert derived.o.iri == self.PREV_URI
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is None
def test_final_generated_by_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasGeneratedBy."""
triples = agent_final_triples(
self.FINAL_URI, question_uri=self.SESSION_URI,
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is None
def test_final_label(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="42"
self.FINAL_URI, previous_uri=self.PREV_URI,
)
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
assert label is not None
assert label.o.value == "Conclusion"
def test_final_inline_answer(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is not None
assert answer.o.value == "The answer is 42"
def test_final_document_reference(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI,
self.FINAL_URI, previous_uri=self.PREV_URI,
document_id="urn:trustgraph:agent:sess/answer",
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
@ -296,29 +328,9 @@ class TestAgentFinalTriples:
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
def test_final_document_takes_precedence(self):
def test_final_no_document(self):
triples = agent_final_triples(
self.FINAL_URI, self.PARENT_URI,
answer="inline",
document_id="urn:doc:123",
self.FINAL_URI, previous_uri=self.PREV_URI,
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is not None
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
assert answer is None
def test_final_no_answer_or_document(self):
triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert answer is None
assert doc is None
def test_final_derives_from_session_when_no_iterations(self):
"""When agent answers immediately, final derives from session."""
session_uri = "urn:trustgraph:agent:test-session"
triples = agent_final_triples(
self.FINAL_URI, session_uri, answer="direct answer"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived.o.iri == session_uri

View file

@ -10,9 +10,11 @@ from trustgraph.api.explainability import (
EdgeSelection,
ExplainEntity,
Question,
Grounding,
Exploration,
Focus,
Synthesis,
Reflection,
Analysis,
Conclusion,
parse_edge_selection_triples,
@ -20,11 +22,11 @@ from trustgraph.api.explainability import (
wire_triples_to_tuples,
ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
RDF_TYPE, RDFS_LABEL,
@ -71,6 +73,18 @@ class TestExplainEntityFromTriples:
assert isinstance(entity, Question)
assert entity.question_type == "agent"
def test_grounding(self):
triples = [
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
("urn:gnd:1", TG_CONCEPT, "machine learning"),
("urn:gnd:1", TG_CONCEPT, "neural networks"),
]
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
assert isinstance(entity, Grounding)
assert len(entity.concepts) == 2
assert "machine learning" in entity.concepts
assert "neural networks" in entity.concepts
def test_exploration(self):
triples = [
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
@ -89,6 +103,17 @@ class TestExplainEntityFromTriples:
assert isinstance(entity, Exploration)
assert entity.chunk_count == 5
def test_exploration_with_entities(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
("urn:exp:3", TG_EDGE_COUNT, "10"),
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
]
entity = ExplainEntity.from_triples("urn:exp:3", triples)
assert isinstance(entity, Exploration)
assert len(entity.entities) == 2
def test_exploration_invalid_count(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
@ -110,69 +135,76 @@ class TestExplainEntityFromTriples:
assert "urn:edge:1" in entity.selected_edge_uris
assert "urn:edge:2" in entity.selected_edge_uris
def test_synthesis_with_content(self):
def test_synthesis_with_document(self):
triples = [
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:1", TG_CONTENT, "The answer is 42"),
("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
]
entity = ExplainEntity.from_triples("urn:syn:1", triples)
assert isinstance(entity, Synthesis)
assert entity.content == "The answer is 42"
assert entity.document_uri == ""
assert entity.document_uri == "urn:doc:answer-1"
def test_synthesis_with_document(self):
def test_synthesis_no_document(self):
triples = [
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
]
entity = ExplainEntity.from_triples("urn:syn:2", triples)
assert isinstance(entity, Synthesis)
assert entity.document_uri == "urn:doc:answer-1"
assert entity.document_uri == ""
def test_reflection_thought(self):
triples = [
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
]
entity = ExplainEntity.from_triples("urn:ref:1", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "thought"
assert entity.document_uri == "urn:doc:thought-1"
def test_reflection_observation(self):
triples = [
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ref:2", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "observation"
assert entity.document_uri == "urn:doc:obs-1"
def test_analysis(self):
triples = [
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
("urn:ana:1", TG_THOUGHT, "I should search"),
("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_OBSERVATION, "Found results"),
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis)
assert entity.thought == "I should search"
assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}'
assert entity.observation == "Found results"
def test_analysis_with_document_refs(self):
triples = [
("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
("urn:ana:2", TG_ACTION, "search"),
("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:2", triples)
assert isinstance(entity, Analysis)
assert entity.thought_document_uri == "urn:doc:thought-1"
assert entity.observation_document_uri == "urn:doc:obs-1"
def test_conclusion_with_answer(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_ANSWER, "The final answer"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.answer == "The final answer"
assert entity.thought_uri == "urn:ref:thought-1"
assert entity.observation_uri == "urn:ref:obs-1"
def test_conclusion_with_document(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final"
def test_conclusion_no_document(self):
triples = [
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:2", triples)
assert isinstance(entity, Conclusion)
assert entity.document_uri == "urn:doc:final"
assert entity.document_uri == ""
def test_unknown_type(self):
triples = [
@ -457,25 +489,7 @@ class TestExplainabilityClientResolveLabel:
class TestExplainabilityClientContentFetching:
def test_fetch_synthesis_inline_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == "inline answer"
def test_fetch_synthesis_truncated_content(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
long_content = "x" * 20000
synthesis = Synthesis(uri="urn:syn:1", content=long_content)
result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_synthesis_from_librarian(self):
def test_fetch_document_content_from_librarian(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
@ -483,66 +497,32 @@ class TestExplainabilityClientContentFetching:
mock_library.get_document_content.return_value = b"librarian content"
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(
uri="urn:syn:1",
document_uri="urn:document:abc123"
result = client.fetch_document_content(
"urn:document:abc123", api=mock_api
)
result = client.fetch_synthesis_content(synthesis, api=mock_api)
assert result == "librarian content"
def test_fetch_synthesis_no_content_or_document(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
synthesis = Synthesis(uri="urn:syn:1")
result = client.fetch_synthesis_content(synthesis, api=None)
assert result == ""
def test_fetch_conclusion_inline(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
conclusion = Conclusion(uri="urn:conc:1", answer="42")
result = client.fetch_conclusion_content(conclusion, api=None)
assert result == "42"
def test_fetch_analysis_content_from_librarian(self):
def test_fetch_document_content_truncated(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
mock_api.library.return_value = mock_library
mock_library.get_document_content.side_effect = [
b"thought content",
b"observation content",
]
mock_library.get_document_content.return_value = b"x" * 20000
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis(
uri="urn:ana:1",
action="search",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
result = client.fetch_document_content(
"urn:doc:1", api=mock_api, max_content=100
)
client.fetch_analysis_content(analysis, api=mock_api)
assert analysis.thought == "thought content"
assert analysis.observation == "observation content"
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_analysis_skips_when_inline_exists(self):
def test_fetch_document_content_empty_uri(self):
mock_flow = MagicMock()
mock_api = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
analysis = Analysis(
uri="urn:ana:1",
action="search",
thought="already have thought",
observation="already have observation",
thought_document_uri="urn:doc:thought",
observation_document_uri="urn:doc:obs",
)
client.fetch_analysis_content(analysis, api=mock_api)
# Should not call librarian since inline content exists
mock_api.library.assert_not_called()
result = client.fetch_document_content("", api=mock_api)
assert result == ""
class TestExplainabilityClientDetectSessionType:

View file

@ -13,6 +13,7 @@ from trustgraph.provenance.triples import (
derived_entity_triples,
subgraph_provenance_triples,
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
@ -32,10 +33,12 @@ from trustgraph.provenance.namespaces import (
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_CONTENT, TG_DOCUMENT,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
GRAPH_SOURCE, GRAPH_RETRIEVAL,
)
@ -530,36 +533,77 @@ class TestQuestionTriples:
assert len(triples) == 6
class TestExplorationTriples:
class TestGroundingTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
GND_URI = "urn:trustgraph:prov:grounding:test-session"
Q_URI = "urn:trustgraph:question:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_grounding_types(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
assert has_type(triples, self.GND_URI, PROV_ENTITY)
assert has_type(triples, self.GND_URI, TG_GROUNDING)
def test_exploration_generated_by_question(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
def test_grounding_generated_by_question(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
assert gen is not None
assert gen.o.iri == self.Q_URI
def test_grounding_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 3
values = {t.o.value for t in concepts}
assert values == {"AI", "ML", "robots"}
def test_grounding_empty_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 0
def test_grounding_label(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
assert label is not None
assert label.o.value == "Grounding"
class TestExplorationTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
GND_URI = "urn:trustgraph:prov:grounding:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_exploration_derived_from_grounding(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived is not None
assert derived.o.iri == self.GND_URI
def test_exploration_edge_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "15"
def test_exploration_zero_edges(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "0"
def test_exploration_with_entities(self):
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
assert len(ent_triples) == 2
def test_exploration_triple_count(self):
triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
assert len(triples) == 5
@ -652,6 +696,7 @@ class TestSynthesisTriples:
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_synthesis_derived_from_focus(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
@ -659,12 +704,6 @@ class TestSynthesisTriples:
assert derived is not None
assert derived.o.iri == self.FOC_URI
def test_synthesis_with_inline_content(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is not None
assert content.o.value == "The answer is 42"
def test_synthesis_with_document_reference(self):
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
@ -675,23 +714,9 @@ class TestSynthesisTriples:
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
def test_synthesis_document_takes_precedence(self):
"""When both document_id and answer_text are provided, document_id wins."""
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
answer_text="inline",
document_id="urn:doc:123",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is not None
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None
def test_synthesis_no_content_or_document(self):
def test_synthesis_no_document(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert content is None
assert doc is None
@ -723,31 +748,31 @@ class TestDocRagQuestionTriples:
class TestDocRagExplorationTriples:
EXP_URI = "urn:trustgraph:docrag:test/exploration"
Q_URI = "urn:trustgraph:docrag:test"
GND_URI = "urn:trustgraph:docrag:test/grounding"
def test_docrag_exploration_types(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_docrag_exploration_generated_by(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
assert gen.o.iri == self.Q_URI
def test_docrag_exploration_derived_from_grounding(self):
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived.o.iri == self.GND_URI
def test_docrag_exploration_chunk_count(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
assert cc.o.value == "7"
def test_docrag_exploration_without_chunk_ids(self):
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
chunks = find_triples(triples, TG_SELECTED_CHUNK)
assert len(chunks) == 0
def test_docrag_exploration_with_chunk_ids(self):
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids)
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
assert len(chunks) == 3
chunk_uris = {t.o.iri for t in chunks}
@ -770,10 +795,9 @@ class TestDocRagSynthesisTriples:
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived.o.iri == self.EXP_URI
def test_docrag_synthesis_with_inline(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer")
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content.o.value == "answer"
def test_docrag_synthesis_has_answer_type(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_docrag_synthesis_with_document(self):
triples = docrag_synthesis_triples(
@ -781,5 +805,8 @@ class TestDocRagSynthesisTriples:
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc.o.iri == "urn:doc:ans"
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
assert content is None
def test_docrag_synthesis_no_document(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is None

View file

@ -125,19 +125,15 @@ class TestQuery:
assert query.doc_limit == 50
@pytest.mark.asyncio
async def test_get_vector_method(self):
"""Test Query.get_vector method calls embeddings client correctly"""
# Create mock DocumentRag with embeddings client
async def test_extract_concepts(self):
"""Test Query.extract_concepts extracts concepts from query"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
# Mock the embed method to return test vectors in batch format
# New format: [[[vectors_for_text1]]] - returns first text's vector set
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -145,20 +141,62 @@ class TestQuery:
verbose=False
)
# Call get_vector
test_query = "What documents are relevant?"
result = await query.get_vector(test_query)
result = await query.extract_concepts("What is machine learning?")
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "What is machine learning?"}
)
assert result == ["machine learning", "artificial intelligence", "data patterns"]
# Verify result matches expected vectors (extracted from batch)
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
    """An empty LLM response makes extract_concepts fall back to the raw query."""
    prompt_client = AsyncMock()
    prompt_client.prompt.return_value = ""  # LLM extracted nothing
    rag = MagicMock()
    rag.prompt_client = prompt_client
    q = Query(
        rag=rag,
        user="test_user",
        collection="test_collection",
        verbose=False
    )
    concepts = await q.extract_concepts("What is ML?")
    # The original question itself becomes the single concept
    assert concepts == ["What is ML?"]
@pytest.mark.asyncio
async def test_get_vectors_method(self):
    """get_vectors hands the concept list directly to the embeddings client."""
    embeddings = AsyncMock()
    # One embedding vector expected per concept
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    embeddings.embed.return_value = vectors
    rag = MagicMock()
    rag.embeddings_client = embeddings
    q = Query(
        rag=rag,
        user="test_user",
        collection="test_collection",
        verbose=False
    )
    concept_list = ["machine learning", "data patterns"]
    returned = await q.get_vectors(concept_list)
    embeddings.embed.assert_called_once_with(concept_list)
    assert returned == vectors
@pytest.mark.asyncio
async def test_get_docs_method(self):
"""Test Query.get_docs method retrieves documents correctly"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
@ -170,10 +208,8 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock embeddings - one vector per concept
mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
# Mock document embeddings returns ChunkMatch objects
mock_match1 = MagicMock()
@ -184,7 +220,6 @@ class TestQuery:
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -193,16 +228,16 @@ class TestQuery:
doc_limit=15
)
# Call get_docs
test_query = "Find relevant documents"
result = await query.get_docs(test_query)
# Call get_docs with concepts list
concepts = ["test concept"]
result = await query.get_docs(concepts)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify embeddings client was called with concepts
mock_embeddings_client.embed.assert_called_once_with(concepts)
# Verify doc embeddings client was called correctly (with extracted vector)
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
vector=[0.1, 0.2, 0.3],
limit=15,
user="test_user",
collection="test_collection"
@ -218,14 +253,17 @@ class TestQuery:
@pytest.mark.asyncio
async def test_document_rag_query_method(self, mock_fetch_chunk):
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
@ -234,11 +272,9 @@ class TestQuery:
mock_match2.score = 0.8
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -247,7 +283,6 @@ class TestQuery:
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query(
query="test query",
user="test_user",
@ -255,12 +290,18 @@ class TestQuery:
doc_limit=10
)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify concept extraction was called
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "test query"}
)
# Verify doc embeddings client was called (with extracted vector)
# Verify embeddings called with extracted concepts
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
@ -270,23 +311,23 @@ class TestQuery:
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "test query"
# Documents should be fetched content, not chunk_ids
docs = call_args.kwargs["documents"]
assert "Relevant document content" in docs
assert "Another document" in docs
# Verify result
assert result == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag.query method with default parameters"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format)
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
@ -294,7 +335,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -302,10 +342,9 @@ class TestQuery:
fetch_chunk=mock_fetch_chunk
)
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used (vector extracted from batch)
# Verify default parameters were used
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
@ -318,7 +357,6 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
"""Test Query.get_docs method with verbose logging"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
@ -330,14 +368,13 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock responses (batch format)
# Mock responses - one vector per concept
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
user="test_user",
@ -346,14 +383,12 @@ class TestQuery:
doc_limit=5
)
# Call get_docs
result = await query.get_docs("verbose test")
# Call get_docs with concepts
result = await query.get_docs(["verbose test"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify result is tuple of (docs, chunk_ids) with fetched content
docs, chunk_ids = result
assert "Verbose test doc" in docs
assert "doc/c6" in chunk_ids
@ -361,12 +396,14 @@ class TestQuery:
@pytest.mark.asyncio
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag.query method with verbose logging enabled"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format)
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
@ -374,7 +411,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -383,14 +419,11 @@ class TestQuery:
verbose=True
)
# Call DocumentRag.query
result = await document_rag.query("verbose query test")
# Verify all clients were called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
# Verify prompt client was called with fetched content
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
@ -400,23 +433,20 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
"""Test Query.get_docs method when no documents are found"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk (won't be called if no chunk_ids)
async def mock_fetch(chunk_id, user):
return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list (batch format)
# Mock responses - empty results
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_doc_embeddings_client.query.return_value = []
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -424,30 +454,27 @@ class TestQuery:
verbose=False
)
# Call get_docs
result = await query.get_docs("query with no results")
result = await query.get_docs(["query with no results"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify empty result is returned (tuple of empty lists)
assert result == ([], [])
@pytest.mark.asyncio
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
"""Test DocumentRag.query method when no documents are retrieved"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses - no chunk_ids found (batch format)
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -456,10 +483,8 @@ class TestQuery:
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query("query with no matching docs")
# Verify prompt client was called with empty document list
mock_prompt_client.document_prompt.assert_called_once_with(
query="query with no matching docs",
documents=[]
@ -468,18 +493,15 @@ class TestQuery:
assert result == "No documents found response"
@pytest.mark.asyncio
async def test_get_vector_with_verbose(self):
"""Test Query.get_vector method with verbose logging"""
# Create mock DocumentRag with embeddings client
async def test_get_vectors_with_verbose(self):
"""Test Query.get_vectors method with verbose logging"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.9, 1.0, 1.1]]
mock_embeddings_client.embed.return_value = [expected_vectors]
mock_embeddings_client.embed.return_value = expected_vectors
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
user="test_user",
@ -487,40 +509,40 @@ class TestQuery:
verbose=True
)
# Call get_vector
result = await query.get_vector("verbose vector test")
result = await query.get_vectors(["verbose vector test"])
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
# Verify result (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
"""Test complete DocumentRag integration with realistic data flow"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock realistic responses (batch format)
query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors]
mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
mock_embeddings_client.embed.return_value = query_vectors
# Each concept query returns some matches
mock_matches_1 = [
MagicMock(chunk_id="doc/ml1", score=0.9),
MagicMock(chunk_id="doc/ml2", score=0.85),
]
mock_matches_2 = [
MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -529,7 +551,6 @@ class TestQuery:
verbose=False
)
# Execute full pipeline
result = await document_rag.query(
query=query_text,
user="research_user",
@ -537,26 +558,69 @@ class TestQuery:
doc_limit=25
)
# Verify complete pipeline execution (now expects list)
mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with(
vector=query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"
# Verify concept extraction
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": query_text}
)
# Verify embeddings called with concepts
mock_embeddings_client.embed.assert_called_once_with(
["machine learning", "artificial intelligence"]
)
# Verify two per-concept queries were made (25 // 2 = 12 per concept)
assert mock_doc_embeddings_client.query.call_count == 2
# Verify prompt client was called with fetched document content
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == query_text
# Verify documents were fetched from chunk_ids
# Verify documents were fetched and deduplicated
docs = call_args.kwargs["documents"]
assert "Machine learning is a subset of artificial intelligence..." in docs
assert "ML algorithms learn patterns from data to make predictions..." in docs
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
# Verify final result
assert result == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):
    """A chunk retrieved by several concept queries must appear only once."""
    embeddings = AsyncMock()
    doc_embeddings = AsyncMock()
    rag = MagicMock()
    rag.embeddings_client = embeddings
    rag.doc_embeddings_client = doc_embeddings
    async def fetch(chunk_id, user):
        return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
    rag.fetch_chunk = fetch
    # Two concepts produce two embedding vectors
    embeddings.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
    # Per-concept results overlap on doc/c1
    first_hits = [
        MagicMock(chunk_id="doc/c1", score=0.9),
        MagicMock(chunk_id="doc/c2", score=0.8),
    ]
    second_hits = [MagicMock(chunk_id="doc/c1", score=0.85)]
    doc_embeddings.query.side_effect = [first_hits, second_hits]
    q = Query(
        rag=rag,
        user="test_user",
        collection="test_collection",
        verbose=False,
        doc_limit=10
    )
    docs, chunk_ids = await q.get_docs(["concept A", "concept B"])
    # doc/c1 appears in both result sets but is counted once
    assert len(chunk_ids) == 2
    assert "doc/c1" in chunk_ids
    assert "doc/c2" in chunk_ids

View file

@ -19,7 +19,7 @@ class TestGraphRag:
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
@ -27,7 +27,7 @@ class TestGraphRag:
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
@ -45,7 +45,7 @@ class TestGraphRag:
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
@ -54,7 +54,7 @@ class TestGraphRag:
triples_client=mock_triples_client,
verbose=True
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
@ -73,7 +73,7 @@ class TestQuery:
"""Test Query initialization with default parameters"""
# Create mock GraphRag
mock_rag = MagicMock()
# Initialize Query with defaults
query = Query(
rag=mock_rag,
@ -81,7 +81,7 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
@ -96,7 +96,7 @@ class TestQuery:
"""Test Query initialization with custom parameters"""
# Create mock GraphRag
mock_rag = MagicMock()
# Initialize Query with custom parameters
query = Query(
rag=mock_rag,
@ -108,7 +108,7 @@ class TestQuery:
max_subgraph_size=2000,
max_path_length=3
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
@ -120,18 +120,16 @@ class TestQuery:
assert query.max_path_length == 3
@pytest.mark.asyncio
async def test_get_vector_method(self):
"""Test Query.get_vector method calls embeddings client correctly"""
# Create mock GraphRag with embeddings client
async def test_get_vectors_method(self):
"""Test Query.get_vectors method calls embeddings client correctly"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method to return test vectors (batch format)
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query
# Mock embed to return vectors for a list of concepts
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
@ -139,29 +137,22 @@ class TestQuery:
verbose=False
)
# Call get_vector
test_query = "What is the capital of France?"
result = await query.get_vector(test_query)
concepts = ["machine learning", "neural networks"]
result = await query.get_vectors(concepts)
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
mock_embeddings_client.embed.assert_called_once_with(concepts)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_get_vector_method_with_verbose(self):
"""Test Query.get_vector method with verbose output"""
# Create mock GraphRag with embeddings client
async def test_get_vectors_method_with_verbose(self):
"""Test Query.get_vectors method with verbose output"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Initialize Query with verbose=True
expected_vectors = [[0.7, 0.8, 0.9]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
@ -169,48 +160,87 @@ class TestQuery:
verbose=True
)
# Call get_vector
test_query = "Test query for embeddings"
result = await query.get_vector(test_query)
result = await query.get_vectors(["test concept"])
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify result matches expected vectors (extracted from batch)
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
assert result == expected_vectors
@pytest.mark.asyncio
async def test_get_entities_method(self):
"""Test Query.get_entities method retrieves entities correctly"""
# Create mock GraphRag with clients
async def test_extract_concepts(self):
    """extract_concepts splits the LLM response into one concept per line."""
    prompt_client = AsyncMock()
    # Trailing newline should not yield an empty concept
    prompt_client.prompt.return_value = "machine learning\nneural networks\n"
    rag = MagicMock()
    rag.prompt_client = prompt_client
    q = Query(
        rag=rag,
        user="test_user",
        collection="test_collection",
        verbose=False
    )
    concepts = await q.extract_concepts("What is machine learning?")
    # The extract-concepts prompt template receives the raw question
    prompt_client.prompt.assert_called_once_with(
        "extract-concepts",
        variables={"query": "What is machine learning?"}
    )
    assert concepts == ["machine learning", "neural networks"]
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
    """If the LLM returns nothing, the raw query is used as the only concept."""
    rag = MagicMock()
    rag.prompt_client = AsyncMock()
    rag.prompt_client.prompt.return_value = ""  # no concepts extracted
    q = Query(
        rag=rag,
        user="test_user",
        collection="test_collection",
        verbose=False
    )
    outcome = await q.extract_concepts("test query")
    assert outcome == ["test query"]
@pytest.mark.asyncio
async def test_get_entities_method(self):
"""Test Query.get_entities extracts concepts, embeds, and retrieves entities"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# Mock the embedding and entity query responses (batch format)
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock EntityMatch objects with entity as Term-like object
# extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
# Mock entity matches
mock_entity1 = MagicMock()
mock_entity1.type = "i" # IRI type
mock_entity1.type = "i"
mock_entity1.iri = "entity1"
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.type = "i" # IRI type
mock_entity2.type = "i"
mock_entity2.iri = "entity2"
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -219,35 +249,23 @@ class TestQuery:
entity_limit=25
)
# Call get_entities
test_query = "Find related entities"
result = await query.get_entities(test_query)
entities, concepts = await query.get_entities("Find related entities")
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify embeddings client was called with the fallback concept
mock_embeddings_client.embed.assert_called_once_with(["Find related entities"])
# Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
limit=25,
user="test_user",
collection="test_collection"
)
# Verify result is list of entity strings
assert result == ["entity1", "entity2"]
# Verify result
assert entities == ["entity1", "entity2"]
assert concepts == ["Find related entities"]
@pytest.mark.asyncio
async def test_maybe_label_with_cached_label(self):
"""Test Query.maybe_label method with cached label"""
# Create mock GraphRag with label cache
mock_rag = MagicMock()
# Create mock LRUCacheWithTTL
mock_cache = MagicMock()
mock_cache.get.return_value = "Entity One Label"
mock_rag.label_cache = mock_cache
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -255,32 +273,25 @@ class TestQuery:
verbose=False
)
# Call maybe_label with cached entity
result = await query.maybe_label("entity1")
# Verify cached label is returned
assert result == "Entity One Label"
# Verify cache was checked with proper key format (user:collection:entity)
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
"""Test Query.maybe_label method with database label lookup"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple result with label
mock_triple = MagicMock()
mock_triple.o = "Human Readable Label"
mock_triples_client.query.return_value = [mock_triple]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -288,20 +299,18 @@ class TestQuery:
verbose=False
)
# Call maybe_label
result = await query.maybe_label("http://example.com/entity")
# Verify triples client was called correctly
mock_triples_client.query.assert_called_once_with(
s="http://example.com/entity",
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection"
collection="test_collection",
g=""
)
# Verify result and cache update with proper key
assert result == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@ -309,40 +318,34 @@ class TestQuery:
@pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self):
"""Test Query.maybe_label method when no label is found"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock empty result (no label found)
mock_triples_client.query.return_value = []
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
# Call maybe_label
result = await query.maybe_label("unlabeled_entity")
# Verify triples client was called
mock_triples_client.query.assert_called_once_with(
s="unlabeled_entity",
p="http://www.w3.org/2000/01/rdf-schema#label",
o=None,
limit=1,
user="test_user",
collection="test_collection"
collection="test_collection",
g=""
)
# Verify result is entity itself and cache is updated
assert result == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@ -350,29 +353,25 @@ class TestQuery:
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple results for different query patterns
mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent, p=None, o=None
[mock_triple2], # s=None, p=ent, o=None
[mock_triple3], # s=None, p=None, o=ent
[mock_triple1], # s=ent
[mock_triple2], # p=ent
[mock_triple3], # o=ent
]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -380,29 +379,25 @@ class TestQuery:
verbose=False,
triple_limit=10
)
# Call follow_edges
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify all three query patterns were called
assert mock_triples_client.query_stream.call_count == 3
# Verify query_stream calls
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20
user="test_user", collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection", batch_size=20
user="test_user", collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection", batch_size=20
user="test_user", collection="test_collection", batch_size=20, g=""
)
# Verify subgraph contains discovered triples
expected_subgraph = {
("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"),
@ -413,38 +408,30 @@ class TestQuery:
@pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self):
    """Test Query.follow_edges method with path_length=0"""
    # A path length of zero means "do not traverse at all": the triples
    # client must never be consulted and the subgraph must stay empty.
    triples_client = AsyncMock()
    rag = MagicMock()
    rag.triples_client = triples_client

    query = Query(
        rag=rag,
        user="test_user",
        collection="test_collection",
        verbose=False
    )

    discovered = set()
    await query.follow_edges("entity1", discovered, path_length=0)

    # No traversal queries were issued and nothing was collected.
    triples_client.query_stream.assert_not_called()
    assert discovered == set()
@pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self):
    """Test Query.follow_edges method respects max_subgraph_size"""
    # Create mock GraphRag
    mock_rag = MagicMock()
    mock_triples_client = AsyncMock()
    mock_rag.triples_client = mock_triples_client
    # Initialize Query with small max_subgraph_size
    # NOTE(review): the collection= argument was lost to a mangled diff
    # hunk here; restored to match the sibling tests in this class.
    query = Query(
        rag=mock_rag,
        user="test_user",
        collection="test_collection",
        verbose=False,
        max_subgraph_size=2
    )
    # Pre-populate subgraph to exceed limit
    subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
    # Call follow_edges
    await query.follow_edges("entity1", subgraph, path_length=1)
    # Verify no queries were made due to size limit
    mock_triples_client.query_stream.assert_not_called()
    # Verify subgraph unchanged
    assert len(subgraph) == 3
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
    """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
    # NOTE(review): the constructor arguments below were lost to a mangled
    # diff hunk; restored to match the sibling tests in this class.
    mock_rag = MagicMock()
    query = Query(
        rag=mock_rag,
        user="test_user",
        collection="test_collection",
        verbose=False,
        max_path_length=1
    )
    # Mock get_entities to return (entities, concepts) tuple
    query.get_entities = AsyncMock(
        return_value=(["entity1", "entity2"], ["concept1"])
    )
    # Mock follow_edges_batch to return test triples
    query.follow_edges_batch = AsyncMock(return_value={
        ("entity1", "predicate1", "object1"),
        ("entity2", "predicate2", "object2")
    })
    # Call get_subgraph
    subgraph, entities, concepts = await query.get_subgraph("test query")
    # Verify get_entities was called
    query.get_entities.assert_called_once_with("test query")
    # Verify follow_edges_batch was called with entities and max_path_length
    query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
    # Verify subgraph is list format and contains expected triples
    assert isinstance(subgraph, list)
    assert len(subgraph) == 2
    assert ("entity1", "predicate1", "object1") in subgraph
    assert ("entity2", "predicate2", "object2") in subgraph
    assert entities == ["entity1", "entity2"]
    assert concepts == ["concept1"]
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):
    """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
    mock_rag = MagicMock()
    query = Query(
        rag=mock_rag,
        user="test_user",
        collection="test_collection",
        verbose=False,
        max_subgraph_size=100
    )
    # Mock get_subgraph to return test triples; the rdfs:label triple
    # should be filtered out of the labeled edges.
    test_subgraph = [
        ("entity1", "predicate1", "object1"),
        ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
        ("entity3", "predicate3", "object3")
    ]
    test_entities = ["entity1", "entity3"]
    test_concepts = ["concept1"]
    query.get_subgraph = AsyncMock(
        return_value=(test_subgraph, test_entities, test_concepts)
    )

    # Map URIs to human-readable labels; unknown values pass through.
    async def mock_maybe_label(entity):
        label_map = {
            "entity1": "Human Entity One",
            "predicate1": "Human Predicate One",
            "object1": "Human Object One",
            "entity3": "Human Entity Three",
            "predicate3": "Human Predicate Three",
            "object3": "Human Object Three"
        }
        return label_map.get(entity, entity)

    query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
    labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
    query.get_subgraph.assert_called_once_with("test query")
    # Label triples filtered out
    assert len(labeled_edges) == 2
    # maybe_label called for non-label triples (3 positions x 2 triples)
    assert query.maybe_label.call_count == 6
    # Verify result contains human-readable labels
    expected_edges = [
        ("Human Entity One", "Human Predicate One", "Human Object One"),
        ("Human Entity Three", "Human Predicate Three", "Human Object Three")
    ]
    assert labeled_edges == expected_edges
    # Verify uri_map maps labeled edges back to original URIs
    assert len(uri_map) == 2
    assert entities == test_entities
    assert concepts == test_concepts
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
    """Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
    import json
    from trustgraph.retrieval.graph_rag.graph_rag import edge_id
    # Create mock clients
    mock_prompt_client = AsyncMock()
    mock_embeddings_client = AsyncMock()
    mock_graph_embeddings_client = AsyncMock()
    mock_triples_client = AsyncMock()
    expected_response = "This is the RAG response"
    test_labelgraph = [("Subject", "Predicate", "Object")]
    # Compute the edge ID for the test edge
    test_edge_id = edge_id("Subject", "Predicate", "Object")
    # uri_map maps the labeled edge ID back to the original URIs
    test_uri_map = {
        test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
    }
    test_entities = ["http://example.org/subject"]
    test_concepts = ["test concept"]

    # Mock prompt responses for the multi-step process
    async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
        if prompt_name == "extract-concepts":
            return ""  # Falls back to raw query
        elif prompt_name == "kg-edge-scoring":
            return json.dumps({"id": test_edge_id, "score": 0.9})
        elif prompt_name == "kg-edge-reasoning":
            return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
        elif prompt_name == "kg-synthesis":
            return expected_response
        return ""

    mock_prompt_client.prompt = mock_prompt
    # NOTE(review): the client keyword arguments below were partly lost to
    # a mangled diff hunk; restored from the mocks created above.
    graph_rag = GraphRag(
        prompt_client=mock_prompt_client,
        embeddings_client=mock_embeddings_client,
        graph_embeddings_client=mock_graph_embeddings_client,
        triples_client=mock_triples_client,
        verbose=False
    )
    # Patch Query.get_labelgraph to return test data
    original_get_labelgraph = Query.get_labelgraph

    async def mock_get_labelgraph(self, query_text):
        return test_labelgraph, test_uri_map, test_entities, test_concepts

    Query.get_labelgraph = mock_get_labelgraph
    # Collect provenance emitted via callback
    provenance_events = []

    async def collect_provenance(triples, prov_id):
        provenance_events.append((triples, prov_id))

    try:
        # Call GraphRag.query with provenance callback
        response = await graph_rag.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            explain_callback=collect_provenance
        )
        # Verify response text
        assert response == expected_response
        # 5 events: question, grounding, exploration, focus, synthesis
        assert len(provenance_events) == 5
        # Verify each event has triples and a URN
        for triples, prov_id in provenance_events:
            assert isinstance(triples, list)
            assert len(triples) > 0
            assert prov_id.startswith("urn:trustgraph:")
        # Verify order
        assert "question" in provenance_events[0][1]
        assert "grounding" in provenance_events[1][1]
        assert "exploration" in provenance_events[2][1]
        assert "focus" in provenance_events[3][1]
        assert "synthesis" in provenance_events[4][1]
    finally:
        # Restore original method
        Query.get_labelgraph = original_get_labelgraph