mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding (#697)
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding, consistent PROV-O GraphRAG: - Split retrieval into 4 prompt stages: extract-concepts, kg-edge-scoring, kg-edge-reasoning, kg-synthesis (was single-stage) - Add concept extraction (grounding) for per-concept embedding - Filter main query to default graph, ignoring provenance/explainability edges - Add source document edges to knowledge graph DocumentRAG: - Add grounding step with concept extraction, matching GraphRAG's pattern: Question → Grounding → Exploration → Synthesis - Per-concept embedding and chunk retrieval with deduplication Cross-pipeline: - Make PROV-O derivation links consistent: wasGeneratedBy for first entity from Activity, wasDerivedFrom for entity-to-entity chains - Update CLIs (tg-invoke-agent, tg-invoke-graph-rag, tg-invoke-document-rag) for new explainability structure - Fix all affected unit and integration tests
This commit is contained in:
parent
29b4300808
commit
a115ec06ab
25 changed files with 1537 additions and 1008 deletions
|
|
@ -15,10 +15,11 @@ from trustgraph.provenance.agent import (
|
|||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -110,84 +111,107 @@ class TestAgentSessionTriples:
|
|||
class TestAgentIterationTriples:
|
||||
|
||||
ITER_URI = "urn:trustgraph:agent:test-session/i1"
|
||||
PARENT_URI = "urn:trustgraph:agent:test-session"
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
PREV_URI = "urn:trustgraph:agent:test-session/i0"
|
||||
|
||||
def test_iteration_types(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
thought="thinking", action="search", observation="found it",
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
|
||||
|
||||
def test_iteration_derived_from_parent(self):
|
||||
def test_first_iteration_generated_by_question(self):
|
||||
"""First iteration uses wasGeneratedBy to link to question activity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
# Should NOT have wasDerivedFrom
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is None
|
||||
|
||||
def test_subsequent_iteration_derived_from_previous(self):
|
||||
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, previous_uri=self.PREV_URI,
|
||||
action="search",
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PARENT_URI
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
# Should NOT have wasGeneratedBy
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_iteration_label_includes_action(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="graph-rag-query",
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
|
||||
assert label is not None
|
||||
assert "graph-rag-query" in label.o.value
|
||||
|
||||
def test_iteration_thought_inline(self):
|
||||
def test_iteration_thought_sub_entity(self):
|
||||
"""Thought is a sub-entity with Reflection and Thought types."""
|
||||
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
|
||||
thought_doc = "urn:doc:thought-1"
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
thought="I need to search for info",
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
thought_uri=thought_uri,
|
||||
thought_document_id=thought_doc,
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought is not None
|
||||
assert thought.o.value == "I need to search for info"
|
||||
# Iteration links to thought sub-entity
|
||||
thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought_link is not None
|
||||
assert thought_link.o.iri == thought_uri
|
||||
# Thought has correct types
|
||||
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
|
||||
# Thought was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Thought has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == thought_doc
|
||||
|
||||
def test_iteration_thought_document_preferred(self):
|
||||
"""When thought_document_id is provided, inline thought is not stored."""
|
||||
def test_iteration_observation_sub_entity(self):
|
||||
"""Observation is a sub-entity with Reflection and Observation types."""
|
||||
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
obs_doc = "urn:doc:obs-1"
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
thought="inline thought",
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
thought_document_id="urn:doc:thought-1",
|
||||
observation_uri=obs_uri,
|
||||
observation_document_id=obs_doc,
|
||||
)
|
||||
thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
|
||||
assert thought_doc is not None
|
||||
assert thought_doc.o.iri == "urn:doc:thought-1"
|
||||
thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought_inline is None
|
||||
|
||||
def test_iteration_observation_inline(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
action="search",
|
||||
observation="Found 3 results",
|
||||
)
|
||||
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs is not None
|
||||
assert obs.o.value == "Found 3 results"
|
||||
|
||||
def test_iteration_observation_document_preferred(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
action="search",
|
||||
observation="inline obs",
|
||||
observation_document_id="urn:doc:obs-1",
|
||||
)
|
||||
obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
|
||||
assert obs_doc is not None
|
||||
assert obs_doc.o.iri == "urn:doc:obs-1"
|
||||
obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_inline is None
|
||||
# Iteration links to observation sub-entity
|
||||
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_link is not None
|
||||
assert obs_link.o.iri == obs_uri
|
||||
# Observation has correct types
|
||||
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
|
||||
# Observation was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Observation has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == obs_doc
|
||||
|
||||
def test_iteration_action_recorded(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="graph-rag-query",
|
||||
)
|
||||
action = find_triple(triples, TG_ACTION, self.ITER_URI)
|
||||
|
|
@ -197,7 +221,7 @@ class TestAgentIterationTriples:
|
|||
def test_iteration_arguments_json_encoded(self):
|
||||
args = {"query": "test query", "limit": 10}
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
arguments=args,
|
||||
)
|
||||
|
|
@ -208,7 +232,7 @@ class TestAgentIterationTriples:
|
|||
|
||||
def test_iteration_default_arguments_empty_dict(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
|
||||
|
|
@ -219,7 +243,7 @@ class TestAgentIterationTriples:
|
|||
def test_iteration_no_thought_or_observation(self):
|
||||
"""Minimal iteration with just action — no thought or observation triples."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="noop",
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
|
|
@ -228,19 +252,19 @@ class TestAgentIterationTriples:
|
|||
assert obs is None
|
||||
|
||||
def test_iteration_chaining(self):
|
||||
"""Second iteration derives from first iteration, not session."""
|
||||
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
|
||||
iter1_uri = "urn:trustgraph:agent:sess/i1"
|
||||
iter2_uri = "urn:trustgraph:agent:sess/i2"
|
||||
|
||||
triples1 = agent_iteration_triples(
|
||||
iter1_uri, self.PARENT_URI, action="step1",
|
||||
iter1_uri, question_uri=self.SESSION_URI, action="step1",
|
||||
)
|
||||
triples2 = agent_iteration_triples(
|
||||
iter2_uri, iter1_uri, action="step2",
|
||||
iter2_uri, previous_uri=iter1_uri, action="step2",
|
||||
)
|
||||
|
||||
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
|
||||
assert derived1.o.iri == self.PARENT_URI
|
||||
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
|
||||
assert gen1.o.iri == self.SESSION_URI
|
||||
|
||||
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
|
||||
assert derived2.o.iri == iter1_uri
|
||||
|
|
@ -253,42 +277,50 @@ class TestAgentIterationTriples:
|
|||
class TestAgentFinalTriples:
|
||||
|
||||
FINAL_URI = "urn:trustgraph:agent:test-session/final"
|
||||
PARENT_URI = "urn:trustgraph:agent:test-session/i3"
|
||||
PREV_URI = "urn:trustgraph:agent:test-session/i3"
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
|
||||
def test_final_types(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="42"
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
|
||||
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_final_derived_from_parent(self):
|
||||
def test_final_derived_from_previous(self):
|
||||
"""Conclusion with iterations uses wasDerivedFrom."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="42"
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PARENT_URI
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_final_generated_by_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasGeneratedBy."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, question_uri=self.SESSION_URI,
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is None
|
||||
|
||||
def test_final_label(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="42"
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Conclusion"
|
||||
|
||||
def test_final_inline_answer(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
|
||||
)
|
||||
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
|
||||
assert answer is not None
|
||||
assert answer.o.value == "The answer is 42"
|
||||
|
||||
def test_final_document_reference(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI,
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
document_id="urn:trustgraph:agent:sess/answer",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
|
|
@ -296,29 +328,9 @@ class TestAgentFinalTriples:
|
|||
assert doc.o.type == IRI
|
||||
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
|
||||
|
||||
def test_final_document_takes_precedence(self):
|
||||
def test_final_no_document(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI,
|
||||
answer="inline",
|
||||
document_id="urn:doc:123",
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is not None
|
||||
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
|
||||
assert answer is None
|
||||
|
||||
def test_final_no_answer_or_document(self):
|
||||
triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
|
||||
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert answer is None
|
||||
assert doc is None
|
||||
|
||||
def test_final_derives_from_session_when_no_iterations(self):
|
||||
"""When agent answers immediately, final derives from session."""
|
||||
session_uri = "urn:trustgraph:agent:test-session"
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, session_uri, answer="direct answer"
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived.o.iri == session_uri
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ from trustgraph.api.explainability import (
|
|||
EdgeSelection,
|
||||
ExplainEntity,
|
||||
Question,
|
||||
Grounding,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
parse_edge_selection_triples,
|
||||
|
|
@ -20,11 +22,11 @@ from trustgraph.api.explainability import (
|
|||
wire_triples_to_tuples,
|
||||
ExplainabilityClient,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
|
|
@ -71,6 +73,18 @@ class TestExplainEntityFromTriples:
|
|||
assert isinstance(entity, Question)
|
||||
assert entity.question_type == "agent"
|
||||
|
||||
def test_grounding(self):
|
||||
triples = [
|
||||
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
|
||||
("urn:gnd:1", TG_CONCEPT, "machine learning"),
|
||||
("urn:gnd:1", TG_CONCEPT, "neural networks"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
|
||||
assert isinstance(entity, Grounding)
|
||||
assert len(entity.concepts) == 2
|
||||
assert "machine learning" in entity.concepts
|
||||
assert "neural networks" in entity.concepts
|
||||
|
||||
def test_exploration(self):
|
||||
triples = [
|
||||
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
|
||||
|
|
@ -89,6 +103,17 @@ class TestExplainEntityFromTriples:
|
|||
assert isinstance(entity, Exploration)
|
||||
assert entity.chunk_count == 5
|
||||
|
||||
def test_exploration_with_entities(self):
|
||||
triples = [
|
||||
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
|
||||
("urn:exp:3", TG_EDGE_COUNT, "10"),
|
||||
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
|
||||
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:exp:3", triples)
|
||||
assert isinstance(entity, Exploration)
|
||||
assert len(entity.entities) == 2
|
||||
|
||||
def test_exploration_invalid_count(self):
|
||||
triples = [
|
||||
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
|
||||
|
|
@ -110,69 +135,76 @@ class TestExplainEntityFromTriples:
|
|||
assert "urn:edge:1" in entity.selected_edge_uris
|
||||
assert "urn:edge:2" in entity.selected_edge_uris
|
||||
|
||||
def test_synthesis_with_content(self):
|
||||
def test_synthesis_with_document(self):
|
||||
triples = [
|
||||
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:syn:1", TG_CONTENT, "The answer is 42"),
|
||||
("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:syn:1", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.content == "The answer is 42"
|
||||
assert entity.document_uri == ""
|
||||
assert entity.document_uri == "urn:doc:answer-1"
|
||||
|
||||
def test_synthesis_with_document(self):
|
||||
def test_synthesis_no_document(self):
|
||||
triples = [
|
||||
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:syn:2", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.document_uri == "urn:doc:answer-1"
|
||||
assert entity.document_uri == ""
|
||||
|
||||
def test_reflection_thought(self):
|
||||
triples = [
|
||||
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
|
||||
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
|
||||
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ref:1", triples)
|
||||
assert isinstance(entity, Reflection)
|
||||
assert entity.reflection_type == "thought"
|
||||
assert entity.document_uri == "urn:doc:thought-1"
|
||||
|
||||
def test_reflection_observation(self):
|
||||
triples = [
|
||||
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
|
||||
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ref:2", triples)
|
||||
assert isinstance(entity, Reflection)
|
||||
assert entity.reflection_type == "observation"
|
||||
assert entity.document_uri == "urn:doc:obs-1"
|
||||
|
||||
def test_analysis(self):
|
||||
triples = [
|
||||
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:ana:1", TG_THOUGHT, "I should search"),
|
||||
("urn:ana:1", TG_ACTION, "graph-rag-query"),
|
||||
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
|
||||
("urn:ana:1", TG_OBSERVATION, "Found results"),
|
||||
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
|
||||
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:1", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.thought == "I should search"
|
||||
assert entity.action == "graph-rag-query"
|
||||
assert entity.arguments == '{"query": "test"}'
|
||||
assert entity.observation == "Found results"
|
||||
|
||||
def test_analysis_with_document_refs(self):
|
||||
triples = [
|
||||
("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:ana:2", TG_ACTION, "search"),
|
||||
("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
|
||||
("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:2", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.thought_document_uri == "urn:doc:thought-1"
|
||||
assert entity.observation_document_uri == "urn:doc:obs-1"
|
||||
|
||||
def test_conclusion_with_answer(self):
|
||||
triples = [
|
||||
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:1", TG_ANSWER, "The final answer"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:1", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.answer == "The final answer"
|
||||
assert entity.thought_uri == "urn:ref:thought-1"
|
||||
assert entity.observation_uri == "urn:ref:obs-1"
|
||||
|
||||
def test_conclusion_with_document(self):
|
||||
triples = [
|
||||
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:1", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.document_uri == "urn:doc:final"
|
||||
|
||||
def test_conclusion_no_document(self):
|
||||
triples = [
|
||||
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:2", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.document_uri == "urn:doc:final"
|
||||
assert entity.document_uri == ""
|
||||
|
||||
def test_unknown_type(self):
|
||||
triples = [
|
||||
|
|
@ -457,25 +489,7 @@ class TestExplainabilityClientResolveLabel:
|
|||
|
||||
class TestExplainabilityClientContentFetching:
|
||||
|
||||
def test_fetch_synthesis_inline_content(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
|
||||
result = client.fetch_synthesis_content(synthesis, api=None)
|
||||
assert result == "inline answer"
|
||||
|
||||
def test_fetch_synthesis_truncated_content(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
long_content = "x" * 20000
|
||||
synthesis = Synthesis(uri="urn:syn:1", content=long_content)
|
||||
result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
|
||||
assert len(result) < 20000
|
||||
assert result.endswith("... [truncated]")
|
||||
|
||||
def test_fetch_synthesis_from_librarian(self):
|
||||
def test_fetch_document_content_from_librarian(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
mock_library = MagicMock()
|
||||
|
|
@ -483,66 +497,32 @@ class TestExplainabilityClientContentFetching:
|
|||
mock_library.get_document_content.return_value = b"librarian content"
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
synthesis = Synthesis(
|
||||
uri="urn:syn:1",
|
||||
document_uri="urn:document:abc123"
|
||||
result = client.fetch_document_content(
|
||||
"urn:document:abc123", api=mock_api
|
||||
)
|
||||
result = client.fetch_synthesis_content(synthesis, api=mock_api)
|
||||
assert result == "librarian content"
|
||||
|
||||
def test_fetch_synthesis_no_content_or_document(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
synthesis = Synthesis(uri="urn:syn:1")
|
||||
result = client.fetch_synthesis_content(synthesis, api=None)
|
||||
assert result == ""
|
||||
|
||||
def test_fetch_conclusion_inline(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
conclusion = Conclusion(uri="urn:conc:1", answer="42")
|
||||
result = client.fetch_conclusion_content(conclusion, api=None)
|
||||
assert result == "42"
|
||||
|
||||
def test_fetch_analysis_content_from_librarian(self):
|
||||
def test_fetch_document_content_truncated(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
mock_library = MagicMock()
|
||||
mock_api.library.return_value = mock_library
|
||||
mock_library.get_document_content.side_effect = [
|
||||
b"thought content",
|
||||
b"observation content",
|
||||
]
|
||||
mock_library.get_document_content.return_value = b"x" * 20000
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
analysis = Analysis(
|
||||
uri="urn:ana:1",
|
||||
action="search",
|
||||
thought_document_uri="urn:doc:thought",
|
||||
observation_document_uri="urn:doc:obs",
|
||||
result = client.fetch_document_content(
|
||||
"urn:doc:1", api=mock_api, max_content=100
|
||||
)
|
||||
client.fetch_analysis_content(analysis, api=mock_api)
|
||||
assert analysis.thought == "thought content"
|
||||
assert analysis.observation == "observation content"
|
||||
assert len(result) < 20000
|
||||
assert result.endswith("... [truncated]")
|
||||
|
||||
def test_fetch_analysis_skips_when_inline_exists(self):
|
||||
def test_fetch_document_content_empty_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
analysis = Analysis(
|
||||
uri="urn:ana:1",
|
||||
action="search",
|
||||
thought="already have thought",
|
||||
observation="already have observation",
|
||||
thought_document_uri="urn:doc:thought",
|
||||
observation_document_uri="urn:doc:obs",
|
||||
)
|
||||
client.fetch_analysis_content(analysis, api=mock_api)
|
||||
# Should not call librarian since inline content exists
|
||||
mock_api.library.assert_not_called()
|
||||
result = client.fetch_document_content("", api=mock_api)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestExplainabilityClientDetectSessionType:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from trustgraph.provenance.triples import (
|
|||
derived_entity_triples,
|
||||
subgraph_provenance_triples,
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
|
|
@ -32,10 +33,12 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_CONTENT, TG_DOCUMENT,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT,
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
GRAPH_SOURCE, GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
|
@ -530,36 +533,77 @@ class TestQuestionTriples:
|
|||
assert len(triples) == 6
|
||||
|
||||
|
||||
class TestExplorationTriples:
|
||||
class TestGroundingTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
|
||||
GND_URI = "urn:trustgraph:prov:grounding:test-session"
|
||||
Q_URI = "urn:trustgraph:question:test-session"
|
||||
|
||||
def test_exploration_types(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
def test_grounding_types(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
|
||||
assert has_type(triples, self.GND_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.GND_URI, TG_GROUNDING)
|
||||
|
||||
def test_exploration_generated_by_question(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
|
||||
def test_grounding_generated_by_question(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.Q_URI
|
||||
|
||||
def test_grounding_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
|
||||
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
|
||||
assert len(concepts) == 3
|
||||
values = {t.o.value for t in concepts}
|
||||
assert values == {"AI", "ML", "robots"}
|
||||
|
||||
def test_grounding_empty_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
|
||||
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
|
||||
assert len(concepts) == 0
|
||||
|
||||
def test_grounding_label(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
|
||||
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Grounding"
|
||||
|
||||
|
||||
class TestExplorationTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
|
||||
GND_URI = "urn:trustgraph:prov:grounding:test-session"
|
||||
|
||||
def test_exploration_types(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
|
||||
def test_exploration_derived_from_grounding(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.GND_URI
|
||||
|
||||
def test_exploration_edge_count(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
|
||||
assert ec is not None
|
||||
assert ec.o.value == "15"
|
||||
|
||||
def test_exploration_zero_edges(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
|
||||
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
|
||||
assert ec is not None
|
||||
assert ec.o.value == "0"
|
||||
|
||||
def test_exploration_with_entities(self):
|
||||
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
|
||||
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
|
||||
assert len(ent_triples) == 2
|
||||
|
||||
def test_exploration_triple_count(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
|
||||
assert len(triples) == 5
|
||||
|
||||
|
||||
|
|
@ -652,6 +696,7 @@ class TestSynthesisTriples:
|
|||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
|
||||
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_synthesis_derived_from_focus(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
|
|
@ -659,12 +704,6 @@ class TestSynthesisTriples:
|
|||
assert derived is not None
|
||||
assert derived.o.iri == self.FOC_URI
|
||||
|
||||
def test_synthesis_with_inline_content(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content is not None
|
||||
assert content.o.value == "The answer is 42"
|
||||
|
||||
def test_synthesis_with_document_reference(self):
|
||||
triples = synthesis_triples(
|
||||
self.SYN_URI, self.FOC_URI,
|
||||
|
|
@ -675,23 +714,9 @@ class TestSynthesisTriples:
|
|||
assert doc.o.type == IRI
|
||||
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
|
||||
|
||||
def test_synthesis_document_takes_precedence(self):
|
||||
"""When both document_id and answer_text are provided, document_id wins."""
|
||||
triples = synthesis_triples(
|
||||
self.SYN_URI, self.FOC_URI,
|
||||
answer_text="inline",
|
||||
document_id="urn:doc:123",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is not None
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content is None
|
||||
|
||||
def test_synthesis_no_content_or_document(self):
|
||||
def test_synthesis_no_document(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert content is None
|
||||
assert doc is None
|
||||
|
||||
|
||||
|
|
@ -723,31 +748,31 @@ class TestDocRagQuestionTriples:
|
|||
class TestDocRagExplorationTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:docrag:test/exploration"
|
||||
Q_URI = "urn:trustgraph:docrag:test"
|
||||
GND_URI = "urn:trustgraph:docrag:test/grounding"
|
||||
|
||||
def test_docrag_exploration_types(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
|
||||
def test_docrag_exploration_generated_by(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
|
||||
assert gen.o.iri == self.Q_URI
|
||||
def test_docrag_exploration_derived_from_grounding(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
|
||||
assert derived.o.iri == self.GND_URI
|
||||
|
||||
def test_docrag_exploration_chunk_count(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
|
||||
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
|
||||
assert cc.o.value == "7"
|
||||
|
||||
def test_docrag_exploration_without_chunk_ids(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
|
||||
chunks = find_triples(triples, TG_SELECTED_CHUNK)
|
||||
assert len(chunks) == 0
|
||||
|
||||
def test_docrag_exploration_with_chunk_ids(self):
|
||||
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
|
||||
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
|
||||
assert len(chunks) == 3
|
||||
chunk_uris = {t.o.iri for t in chunks}
|
||||
|
|
@ -770,10 +795,9 @@ class TestDocRagSynthesisTriples:
|
|||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
|
||||
assert derived.o.iri == self.EXP_URI
|
||||
|
||||
def test_docrag_synthesis_with_inline(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer")
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content.o.value == "answer"
|
||||
def test_docrag_synthesis_has_answer_type(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_docrag_synthesis_with_document(self):
|
||||
triples = docrag_synthesis_triples(
|
||||
|
|
@ -781,5 +805,8 @@ class TestDocRagSynthesisTriples:
|
|||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc.o.iri == "urn:doc:ans"
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content is None
|
||||
|
||||
def test_docrag_synthesis_no_document(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is None
|
||||
|
|
|
|||
|
|
@ -125,19 +125,15 @@ class TestQuery:
|
|||
assert query.doc_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
async def test_extract_concepts(self):
|
||||
"""Test Query.extract_concepts extracts concepts from query"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock the embed method to return test vectors in batch format
|
||||
# New format: [[[vectors_for_text1]]] - returns first text's vector set
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
# Mock the prompt response with concept lines
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -145,20 +141,62 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What documents are relevant?"
|
||||
result = await query.get_vector(test_query)
|
||||
result = await query.extract_concepts("What is machine learning?")
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": "What is machine learning?"}
|
||||
)
|
||||
assert result == ["machine learning", "artificial intelligence", "data patterns"]
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts_fallback_to_raw_query(self):
|
||||
"""Test Query.extract_concepts falls back to raw query when no concepts extracted"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock empty response
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = await query.extract_concepts("What is ML?")
|
||||
|
||||
assert result == ["What is ML?"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vectors_method(self):
|
||||
"""Test Query.get_vectors method calls embeddings client correctly"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method - returns vectors for each concept
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
concepts = ["machine learning", "data patterns"]
|
||||
result = await query.get_vectors(concepts)
|
||||
|
||||
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_method(self):
|
||||
"""Test Query.get_docs method retrieves documents correctly"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
|
@ -170,10 +208,8 @@ class TestQuery:
|
|||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Mock the embedding and document query responses
|
||||
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
# Mock embeddings - one vector per concept
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
# Mock document embeddings returns ChunkMatch objects
|
||||
mock_match1 = MagicMock()
|
||||
|
|
@ -184,7 +220,6 @@ class TestQuery:
|
|||
mock_match2.score = 0.85
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -193,16 +228,16 @@ class TestQuery:
|
|||
doc_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
test_query = "Find relevant documents"
|
||||
result = await query.get_docs(test_query)
|
||||
# Call get_docs with concepts list
|
||||
concepts = ["test concept"]
|
||||
result = await query.get_docs(concepts)
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
# Verify embeddings client was called with concepts
|
||||
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
||||
|
||||
# Verify doc embeddings client was called correctly (with extracted vector)
|
||||
# Verify doc embeddings client was called
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=test_vectors,
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=15,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
|
|
@ -218,14 +253,17 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_method(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock embeddings and document embeddings responses
|
||||
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "test concept"
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
mock_match1 = MagicMock()
|
||||
mock_match1.chunk_id = "doc/c3"
|
||||
mock_match1.score = 0.9
|
||||
|
|
@ -234,11 +272,9 @@ class TestQuery:
|
|||
mock_match2.score = 0.8
|
||||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -247,7 +283,6 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
|
|
@ -255,12 +290,18 @@ class TestQuery:
|
|||
doc_limit=10
|
||||
)
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["test query"])
|
||||
# Verify concept extraction was called
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": "test query"}
|
||||
)
|
||||
|
||||
# Verify doc embeddings client was called (with extracted vector)
|
||||
# Verify embeddings called with extracted concepts
|
||||
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
|
||||
|
||||
# Verify doc embeddings client was called
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=test_vectors,
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
|
|
@ -270,23 +311,23 @@ class TestQuery:
|
|||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.document_prompt.call_args
|
||||
assert call_args.kwargs["query"] == "test query"
|
||||
# Documents should be fetched content, not chunk_ids
|
||||
docs = call_args.kwargs["documents"]
|
||||
assert "Relevant document content" in docs
|
||||
assert "Another document" in docs
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method with default parameters"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses (batch format)
|
||||
# Mock concept extraction fallback (empty → raw query)
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = "doc/c5"
|
||||
|
|
@ -294,7 +335,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -302,10 +342,9 @@ class TestQuery:
|
|||
fetch_chunk=mock_fetch_chunk
|
||||
)
|
||||
|
||||
# Call DocumentRag.query with minimal parameters
|
||||
result = await document_rag.query("simple query")
|
||||
|
||||
# Verify default parameters were used (vector extracted from batch)
|
||||
# Verify default parameters were used
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2]],
|
||||
limit=20, # Default doc_limit
|
||||
|
|
@ -318,7 +357,6 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
"""Test Query.get_docs method with verbose logging"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
|
@ -330,14 +368,13 @@ class TestQuery:
|
|||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Mock responses (batch format)
|
||||
# Mock responses - one vector per concept
|
||||
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = "doc/c6"
|
||||
mock_match.score = 0.88
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -346,14 +383,12 @@ class TestQuery:
|
|||
doc_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("verbose test")
|
||||
# Call get_docs with concepts
|
||||
result = await query.get_docs(["verbose test"])
|
||||
|
||||
# Verify calls were made (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify result is tuple of (docs, chunk_ids) with fetched content
|
||||
docs, chunk_ids = result
|
||||
assert "Verbose test doc" in docs
|
||||
assert "doc/c6" in chunk_ids
|
||||
|
|
@ -361,12 +396,14 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method with verbose logging enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses (batch format)
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "verbose query test"
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = "doc/c7"
|
||||
|
|
@ -374,7 +411,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
|
||||
# Initialize DocumentRag with verbose=True
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -383,14 +419,11 @@ class TestQuery:
|
|||
verbose=True
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("verbose query test")
|
||||
|
||||
# Verify all clients were called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify prompt client was called with fetched content
|
||||
call_args = mock_prompt_client.document_prompt.call_args
|
||||
assert call_args.kwargs["query"] == "verbose query test"
|
||||
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||
|
|
@ -400,23 +433,20 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
"""Test Query.get_docs method when no documents are found"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock fetch_chunk (won't be called if no chunk_ids)
|
||||
async def mock_fetch(chunk_id, user):
|
||||
return f"Content for {chunk_id}"
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Mock responses - empty chunk_id list (batch format)
|
||||
# Mock responses - empty results
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -424,30 +454,27 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("query with no results")
|
||||
result = await query.get_docs(["query with no results"])
|
||||
|
||||
# Verify calls were made (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify empty result is returned (tuple of empty lists)
|
||||
assert result == ([], [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method when no documents are retrieved"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses - no chunk_ids found (batch format)
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "query with no matching docs"
|
||||
|
||||
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -456,10 +483,8 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("query with no matching docs")
|
||||
|
||||
# Verify prompt client was called with empty document list
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="query with no matching docs",
|
||||
documents=[]
|
||||
|
|
@ -468,18 +493,15 @@ class TestQuery:
|
|||
assert result == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose logging"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
async def test_get_vectors_with_verbose(self):
|
||||
"""Test Query.get_vectors method with verbose logging"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method (batch format)
|
||||
expected_vectors = [[0.9, 1.0, 1.1]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -487,40 +509,40 @@ class TestQuery:
|
|||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
result = await query.get_vector("verbose vector test")
|
||||
result = await query.get_vectors(["verbose vector test"])
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
|
||||
|
||||
# Verify result (extracted from batch)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
|
||||
"""Test complete DocumentRag integration with realistic data flow"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock realistic responses (batch format)
|
||||
query_text = "What is machine learning?"
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
|
||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
mock_embeddings_client.embed.return_value = [query_vectors]
|
||||
mock_matches = []
|
||||
for chunk_id in retrieved_chunk_ids:
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = chunk_id
|
||||
mock_match.score = 0.9
|
||||
mock_matches.append(mock_match)
|
||||
mock_doc_embeddings_client.query.return_value = mock_matches
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
|
||||
mock_embeddings_client.embed.return_value = query_vectors
|
||||
|
||||
# Each concept query returns some matches
|
||||
mock_matches_1 = [
|
||||
MagicMock(chunk_id="doc/ml1", score=0.9),
|
||||
MagicMock(chunk_id="doc/ml2", score=0.85),
|
||||
]
|
||||
mock_matches_2 = [
|
||||
MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate
|
||||
MagicMock(chunk_id="doc/ml3", score=0.82),
|
||||
]
|
||||
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -529,7 +551,6 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Execute full pipeline
|
||||
result = await document_rag.query(
|
||||
query=query_text,
|
||||
user="research_user",
|
||||
|
|
@ -537,26 +558,69 @@ class TestQuery:
|
|||
doc_limit=25
|
||||
)
|
||||
|
||||
# Verify complete pipeline execution (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([query_text])
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=query_vectors,
|
||||
limit=25,
|
||||
user="research_user",
|
||||
collection="ml_knowledge"
|
||||
# Verify concept extraction
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": query_text}
|
||||
)
|
||||
|
||||
# Verify embeddings called with concepts
|
||||
mock_embeddings_client.embed.assert_called_once_with(
|
||||
["machine learning", "artificial intelligence"]
|
||||
)
|
||||
|
||||
# Verify two per-concept queries were made (25 // 2 = 12 per concept)
|
||||
assert mock_doc_embeddings_client.query.call_count == 2
|
||||
|
||||
# Verify prompt client was called with fetched document content
|
||||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.document_prompt.call_args
|
||||
assert call_args.kwargs["query"] == query_text
|
||||
|
||||
# Verify documents were fetched from chunk_ids
|
||||
# Verify documents were fetched and deduplicated
|
||||
docs = call_args.kwargs["documents"]
|
||||
assert "Machine learning is a subset of artificial intelligence..." in docs
|
||||
assert "ML algorithms learn patterns from data to make predictions..." in docs
|
||||
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||
assert len(docs) == 3 # doc/ml2 deduplicated
|
||||
|
||||
# Verify final result
|
||||
assert result == final_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_deduplicates_across_concepts(self):
|
||||
"""Test that get_docs deduplicates chunks across multiple concepts"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
async def mock_fetch(chunk_id, user):
|
||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Two concepts → two vectors
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Both queries return overlapping chunks
|
||||
match_a = MagicMock(chunk_id="doc/c1", score=0.9)
|
||||
match_b = MagicMock(chunk_id="doc/c2", score=0.8)
|
||||
match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate
|
||||
mock_doc_embeddings_client.query.side_effect = [
|
||||
[match_a, match_b],
|
||||
[match_c],
|
||||
]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
)
|
||||
|
||||
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
||||
|
||||
assert len(chunk_ids) == 2 # doc/c1 only counted once
|
||||
assert "doc/c1" in chunk_ids
|
||||
assert "doc/c2" in chunk_ids
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class TestGraphRag:
|
|||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -27,7 +27,7 @@ class TestGraphRag:
|
|||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
|
|
@ -45,7 +45,7 @@ class TestGraphRag:
|
|||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -54,7 +54,7 @@ class TestGraphRag:
|
|||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
|
|
@ -73,7 +73,7 @@ class TestQuery:
|
|||
"""Test Query initialization with default parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -81,7 +81,7 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
|
|
@ -96,7 +96,7 @@ class TestQuery:
|
|||
"""Test Query initialization with custom parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
# Initialize Query with custom parameters
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -108,7 +108,7 @@ class TestQuery:
|
|||
max_subgraph_size=2000,
|
||||
max_path_length=3
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
|
|
@ -120,18 +120,16 @@ class TestQuery:
|
|||
assert query.max_path_length == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
async def test_get_vectors_method(self):
|
||||
"""Test Query.get_vectors method calls embeddings client correctly"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors (batch format)
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
|
||||
# Initialize Query
|
||||
# Mock embed to return vectors for a list of concepts
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -139,29 +137,22 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What is the capital of France?"
|
||||
result = await query.get_vector(test_query)
|
||||
concepts = ["machine learning", "neural networks"]
|
||||
result = await query.get_vectors(concepts)
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose output"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
async def test_get_vectors_method_with_verbose(self):
|
||||
"""Test Query.get_vectors method with verbose output"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method (batch format)
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -169,48 +160,87 @@ class TestQuery:
|
|||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "Test query for embeddings"
|
||||
result = await query.get_vector(test_query)
|
||||
result = await query.get_vectors(["test concept"])
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities_method(self):
|
||||
"""Test Query.get_entities method retrieves entities correctly"""
|
||||
# Create mock GraphRag with clients
|
||||
async def test_extract_concepts(self):
|
||||
"""Test Query.extract_concepts parses LLM response into concept list"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = await query.extract_concepts("What is machine learning?")
|
||||
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": "What is machine learning?"}
|
||||
)
|
||||
assert result == ["machine learning", "neural networks"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts_fallback_to_raw_query(self):
|
||||
"""Test extract_concepts falls back to raw query when LLM returns empty"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = await query.extract_concepts("test query")
|
||||
assert result == ["test query"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities_method(self):
|
||||
"""Test Query.get_entities extracts concepts, embeds, and retrieves entities"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# Mock the embedding and entity query responses (batch format)
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
|
||||
# Mock EntityMatch objects with entity as Term-like object
|
||||
# extract_concepts returns empty -> falls back to [query]
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
# embed returns one vector set for the single concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
# Mock entity matches
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.type = "i" # IRI type
|
||||
mock_entity1.type = "i"
|
||||
mock_entity1.iri = "entity1"
|
||||
mock_match1 = MagicMock()
|
||||
mock_match1.entity = mock_entity1
|
||||
mock_match1.score = 0.95
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.type = "i" # IRI type
|
||||
mock_entity2.type = "i"
|
||||
mock_entity2.iri = "entity2"
|
||||
mock_match2 = MagicMock()
|
||||
mock_match2.entity = mock_entity2
|
||||
mock_match2.score = 0.85
|
||||
|
||||
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -219,35 +249,23 @@ class TestQuery:
|
|||
entity_limit=25
|
||||
)
|
||||
|
||||
# Call get_entities
|
||||
test_query = "Find related entities"
|
||||
result = await query.get_entities(test_query)
|
||||
entities, concepts = await query.get_entities("Find related entities")
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
# Verify embeddings client was called with the fallback concept
|
||||
mock_embeddings_client.embed.assert_called_once_with(["Find related entities"])
|
||||
|
||||
# Verify graph embeddings client was called correctly (with extracted vector)
|
||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||
vector=test_vectors,
|
||||
limit=25,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is list of entity strings
|
||||
assert result == ["entity1", "entity2"]
|
||||
# Verify result
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["Find related entities"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_cached_label(self):
|
||||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
# Create mock LRUCacheWithTTL
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = "Entity One Label"
|
||||
mock_rag.label_cache = mock_cache
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -255,32 +273,25 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
# Verify cache was checked with proper key format (user:collection:entity)
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -288,20 +299,18 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
# Verify result and cache update with proper key
|
||||
assert result == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
|
@ -309,40 +318,34 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock empty result (no label found)
|
||||
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
# Initialize Query
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
|
||||
result = await query.maybe_label("unlabeled_entity")
|
||||
|
||||
# Verify triples client was called
|
||||
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="unlabeled_entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
# Verify result is entity itself and cache is updated
|
||||
|
||||
assert result == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
|
@ -350,29 +353,25 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple results for different query patterns
|
||||
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
|
||||
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent, p=None, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple3], # s=None, p=None, o=ent
|
||||
[mock_triple1], # s=ent
|
||||
[mock_triple2], # p=ent
|
||||
[mock_triple3], # o=ent
|
||||
]
|
||||
|
||||
# Initialize Query
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -380,29 +379,25 @@ class TestQuery:
|
|||
verbose=False,
|
||||
triple_limit=10
|
||||
)
|
||||
|
||||
# Call follow_edges
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify all three query patterns were called
|
||||
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
# Verify query_stream calls
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
|
||||
# Verify subgraph contains discovered triples
|
||||
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
|
|
@ -413,38 +408,30 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call follow_edges with path_length=0
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
# Verify no queries were made
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph remains empty
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query with small max_subgraph_size
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -452,23 +439,17 @@ class TestQuery:
|
|||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
)
|
||||
|
||||
# Pre-populate subgraph to exceed limit
|
||||
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
# Call follow_edges
|
||||
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify no queries were made due to size limit
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph unchanged
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges_batch
|
||||
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
|
|
@ -479,130 +460,119 @@ class TestQuery:
|
|||
max_path_length=1
|
||||
)
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
# Mock get_entities to return (entities, concepts) tuple
|
||||
query.get_entities = AsyncMock(
|
||||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
# Mock follow_edges_batch to return test triples
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
subgraph, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges_batch was called with entities and max_path_length
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
# Verify result is list format and contains expected triples
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert ("entity1", "predicate1", "object1") in result
|
||||
assert ("entity2", "predicate2", "object2") in result
|
||||
assert isinstance(subgraph, list)
|
||||
assert len(subgraph) == 2
|
||||
assert ("entity1", "predicate1", "object1") in subgraph
|
||||
assert ("entity2", "predicate2", "object2") in subgraph
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["concept1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph method converts entities to labels"""
|
||||
# Create mock Query
|
||||
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
# Mock get_subgraph to return test triples
|
||||
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
query.get_subgraph = AsyncMock(return_value=test_subgraph)
|
||||
|
||||
# Mock maybe_label to return human-readable labels
|
||||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
# Call get_labelgraph
|
||||
labeled_edges, uri_map = await query.get_labelgraph("test query")
|
||||
|
||||
# Verify get_subgraph was called
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
|
||||
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Verify label triples are filtered out
|
||||
assert len(labeled_edges) == 2 # Label triple should be excluded
|
||||
# Label triples filtered out
|
||||
assert len(labeled_edges) == 2
|
||||
|
||||
# Verify maybe_label was called for non-label triples
|
||||
expected_calls = [
|
||||
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
|
||||
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
|
||||
]
|
||||
# maybe_label called for non-label triples
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
# Verify result contains human-readable labels
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
# Verify uri_map maps labeled edges back to original URIs
|
||||
assert len(uri_map) == 2
|
||||
assert entities == test_entities
|
||||
assert concepts == test_concepts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance"""
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
|
||||
# Mock prompt client responses for two-step process
|
||||
expected_response = "This is the RAG response"
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
# Compute the edge ID for the test edge
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
|
||||
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
test_entities = ["http://example.org/subject"]
|
||||
test_concepts = ["test concept"]
|
||||
|
||||
# Mock edge selection response (JSONL format)
|
||||
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
|
||||
# Configure prompt mock to return different responses based on prompt name
|
||||
# Mock prompt responses for the multi-step process
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return edge_selection_response
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return json.dumps({"id": test_edge_id, "score": 0.9})
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -611,27 +581,20 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# We need to patch the Query class's get_labelgraph method
|
||||
original_query_init = Query.__init__
|
||||
# Patch Query.get_labelgraph to return test data
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
|
||||
def mock_query_init(self, *args, **kwargs):
|
||||
original_query_init(self, *args, **kwargs)
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph, test_uri_map
|
||||
return test_labelgraph, test_uri_map, test_entities, test_concepts
|
||||
|
||||
Query.__init__ = mock_query_init
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
|
||||
# Collect provenance emitted via callback
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
try:
|
||||
# Call GraphRag.query with provenance callback
|
||||
response = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
|
|
@ -641,25 +604,22 @@ class TestQuery:
|
|||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Verify response text
|
||||
assert response == expected_response
|
||||
|
||||
# Verify provenance was emitted incrementally (4 events: question, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 4
|
||||
# 5 events: question, grounding, exploration, focus, synthesis
|
||||
assert len(provenance_events) == 5
|
||||
|
||||
# Verify each event has triples and a URN
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order: question, exploration, focus, synthesis
|
||||
# Verify order
|
||||
assert "question" in provenance_events[0][1]
|
||||
assert "exploration" in provenance_events[1][1]
|
||||
assert "focus" in provenance_events[2][1]
|
||||
assert "synthesis" in provenance_events[3][1]
|
||||
assert "grounding" in provenance_events[1][1]
|
||||
assert "exploration" in provenance_events[2][1]
|
||||
assert "focus" in provenance_events[3][1]
|
||||
assert "synthesis" in provenance_events[4][1]
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
Query.__init__ = original_query_init
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue