mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding, consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts, kg-edge-scoring, kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's pattern: Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag, tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
This commit is contained in:
parent
29b4300808
commit
20bb645b9a
25 changed files with 1537 additions and 1008 deletions
|
|
@ -86,13 +86,18 @@ class TestGraphRagIntegration:
|
|||
"""Mock prompt client that generates realistic responses for two-step process"""
|
||||
client = AsyncMock()
|
||||
|
||||
# Mock responses for the two-step process:
|
||||
# 1. kg-edge-selection returns JSONL with edge IDs
|
||||
# 2. kg-synthesis returns the final answer
|
||||
# Mock responses for the multi-step process:
|
||||
# 1. extract-concepts extracts key concepts from the query
|
||||
# 2. kg-edge-scoring scores edges for relevance
|
||||
# 3. kg-edge-reasoning provides reasoning for selected edges
|
||||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Return empty selection (no edges selected) - valid JSONL
|
||||
return ""
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
|
|
@ -160,16 +165,16 @@ class TestGraphRagIntegration:
|
|||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt twice (edge selection + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 2
|
||||
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (4 events: question, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 4
|
||||
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 5
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
|
@ -243,10 +248,10 @@ class TestGraphRagIntegration:
|
|||
)
|
||||
|
||||
# Assert
|
||||
# Should still call prompt client (twice: edge selection + synthesis)
|
||||
# Should still call prompt client
|
||||
assert response is not None
|
||||
# Provenance should still be emitted (4 events)
|
||||
assert len(provenance_events) == 4
|
||||
# Provenance should still be emitted (5 events)
|
||||
assert len(provenance_events) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
|
||||
|
|
|
|||
|
|
@ -60,8 +60,12 @@ class TestGraphRagStreaming:
|
|||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "kg-edge-selection":
|
||||
# Edge selection returns JSONL with IDs - simulate selecting first edge
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -132,8 +136,8 @@ class TestGraphRagStreaming:
|
|||
# Verify content is reasonable
|
||||
assert "machine" in response.lower() or "learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (4 events)
|
||||
assert len(provenance_events) == 4
|
||||
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 5
|
||||
for triples, prov_id in provenance_events:
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
|
|
|
|||
|
|
@ -15,10 +15,11 @@ from trustgraph.provenance.agent import (
|
|||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -110,84 +111,107 @@ class TestAgentSessionTriples:
|
|||
class TestAgentIterationTriples:
|
||||
|
||||
ITER_URI = "urn:trustgraph:agent:test-session/i1"
|
||||
PARENT_URI = "urn:trustgraph:agent:test-session"
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
PREV_URI = "urn:trustgraph:agent:test-session/i0"
|
||||
|
||||
def test_iteration_types(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
thought="thinking", action="search", observation="found it",
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
|
||||
|
||||
def test_iteration_derived_from_parent(self):
|
||||
def test_first_iteration_generated_by_question(self):
|
||||
"""First iteration uses wasGeneratedBy to link to question activity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
# Should NOT have wasDerivedFrom
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is None
|
||||
|
||||
def test_subsequent_iteration_derived_from_previous(self):
|
||||
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, previous_uri=self.PREV_URI,
|
||||
action="search",
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PARENT_URI
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
# Should NOT have wasGeneratedBy
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_iteration_label_includes_action(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="graph-rag-query",
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
|
||||
assert label is not None
|
||||
assert "graph-rag-query" in label.o.value
|
||||
|
||||
def test_iteration_thought_inline(self):
|
||||
def test_iteration_thought_sub_entity(self):
|
||||
"""Thought is a sub-entity with Reflection and Thought types."""
|
||||
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
|
||||
thought_doc = "urn:doc:thought-1"
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
thought="I need to search for info",
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
thought_uri=thought_uri,
|
||||
thought_document_id=thought_doc,
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought is not None
|
||||
assert thought.o.value == "I need to search for info"
|
||||
# Iteration links to thought sub-entity
|
||||
thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought_link is not None
|
||||
assert thought_link.o.iri == thought_uri
|
||||
# Thought has correct types
|
||||
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
|
||||
# Thought was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Thought has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == thought_doc
|
||||
|
||||
def test_iteration_thought_document_preferred(self):
|
||||
"""When thought_document_id is provided, inline thought is not stored."""
|
||||
def test_iteration_observation_sub_entity(self):
|
||||
"""Observation is a sub-entity with Reflection and Observation types."""
|
||||
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
obs_doc = "urn:doc:obs-1"
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
thought="inline thought",
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
thought_document_id="urn:doc:thought-1",
|
||||
observation_uri=obs_uri,
|
||||
observation_document_id=obs_doc,
|
||||
)
|
||||
thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI)
|
||||
assert thought_doc is not None
|
||||
assert thought_doc.o.iri == "urn:doc:thought-1"
|
||||
thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought_inline is None
|
||||
|
||||
def test_iteration_observation_inline(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
action="search",
|
||||
observation="Found 3 results",
|
||||
)
|
||||
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs is not None
|
||||
assert obs.o.value == "Found 3 results"
|
||||
|
||||
def test_iteration_observation_document_preferred(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
action="search",
|
||||
observation="inline obs",
|
||||
observation_document_id="urn:doc:obs-1",
|
||||
)
|
||||
obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI)
|
||||
assert obs_doc is not None
|
||||
assert obs_doc.o.iri == "urn:doc:obs-1"
|
||||
obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_inline is None
|
||||
# Iteration links to observation sub-entity
|
||||
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_link is not None
|
||||
assert obs_link.o.iri == obs_uri
|
||||
# Observation has correct types
|
||||
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
|
||||
# Observation was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Observation has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == obs_doc
|
||||
|
||||
def test_iteration_action_recorded(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="graph-rag-query",
|
||||
)
|
||||
action = find_triple(triples, TG_ACTION, self.ITER_URI)
|
||||
|
|
@ -197,7 +221,7 @@ class TestAgentIterationTriples:
|
|||
def test_iteration_arguments_json_encoded(self):
|
||||
args = {"query": "test query", "limit": 10}
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
arguments=args,
|
||||
)
|
||||
|
|
@ -208,7 +232,7 @@ class TestAgentIterationTriples:
|
|||
|
||||
def test_iteration_default_arguments_empty_dict(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
|
||||
|
|
@ -219,7 +243,7 @@ class TestAgentIterationTriples:
|
|||
def test_iteration_no_thought_or_observation(self):
|
||||
"""Minimal iteration with just action — no thought or observation triples."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, self.PARENT_URI,
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="noop",
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
|
|
@ -228,19 +252,19 @@ class TestAgentIterationTriples:
|
|||
assert obs is None
|
||||
|
||||
def test_iteration_chaining(self):
|
||||
"""Second iteration derives from first iteration, not session."""
|
||||
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
|
||||
iter1_uri = "urn:trustgraph:agent:sess/i1"
|
||||
iter2_uri = "urn:trustgraph:agent:sess/i2"
|
||||
|
||||
triples1 = agent_iteration_triples(
|
||||
iter1_uri, self.PARENT_URI, action="step1",
|
||||
iter1_uri, question_uri=self.SESSION_URI, action="step1",
|
||||
)
|
||||
triples2 = agent_iteration_triples(
|
||||
iter2_uri, iter1_uri, action="step2",
|
||||
iter2_uri, previous_uri=iter1_uri, action="step2",
|
||||
)
|
||||
|
||||
derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri)
|
||||
assert derived1.o.iri == self.PARENT_URI
|
||||
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
|
||||
assert gen1.o.iri == self.SESSION_URI
|
||||
|
||||
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
|
||||
assert derived2.o.iri == iter1_uri
|
||||
|
|
@ -253,42 +277,50 @@ class TestAgentIterationTriples:
|
|||
class TestAgentFinalTriples:
|
||||
|
||||
FINAL_URI = "urn:trustgraph:agent:test-session/final"
|
||||
PARENT_URI = "urn:trustgraph:agent:test-session/i3"
|
||||
PREV_URI = "urn:trustgraph:agent:test-session/i3"
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
|
||||
def test_final_types(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="42"
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
|
||||
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_final_derived_from_parent(self):
|
||||
def test_final_derived_from_previous(self):
|
||||
"""Conclusion with iterations uses wasDerivedFrom."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="42"
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PARENT_URI
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_final_generated_by_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasGeneratedBy."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, question_uri=self.SESSION_URI,
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is None
|
||||
|
||||
def test_final_label(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="42"
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Conclusion"
|
||||
|
||||
def test_final_inline_answer(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI, answer="The answer is 42"
|
||||
)
|
||||
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
|
||||
assert answer is not None
|
||||
assert answer.o.value == "The answer is 42"
|
||||
|
||||
def test_final_document_reference(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI,
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
document_id="urn:trustgraph:agent:sess/answer",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
|
|
@ -296,29 +328,9 @@ class TestAgentFinalTriples:
|
|||
assert doc.o.type == IRI
|
||||
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
|
||||
|
||||
def test_final_document_takes_precedence(self):
|
||||
def test_final_no_document(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, self.PARENT_URI,
|
||||
answer="inline",
|
||||
document_id="urn:doc:123",
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is not None
|
||||
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
|
||||
assert answer is None
|
||||
|
||||
def test_final_no_answer_or_document(self):
|
||||
triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI)
|
||||
answer = find_triple(triples, TG_ANSWER, self.FINAL_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert answer is None
|
||||
assert doc is None
|
||||
|
||||
def test_final_derives_from_session_when_no_iterations(self):
|
||||
"""When agent answers immediately, final derives from session."""
|
||||
session_uri = "urn:trustgraph:agent:test-session"
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, session_uri, answer="direct answer"
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived.o.iri == session_uri
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ from trustgraph.api.explainability import (
|
|||
EdgeSelection,
|
||||
ExplainEntity,
|
||||
Question,
|
||||
Grounding,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
parse_edge_selection_triples,
|
||||
|
|
@ -20,11 +22,11 @@ from trustgraph.api.explainability import (
|
|||
wire_triples_to_tuples,
|
||||
ExplainabilityClient,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
|
|
@ -71,6 +73,18 @@ class TestExplainEntityFromTriples:
|
|||
assert isinstance(entity, Question)
|
||||
assert entity.question_type == "agent"
|
||||
|
||||
def test_grounding(self):
|
||||
triples = [
|
||||
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
|
||||
("urn:gnd:1", TG_CONCEPT, "machine learning"),
|
||||
("urn:gnd:1", TG_CONCEPT, "neural networks"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
|
||||
assert isinstance(entity, Grounding)
|
||||
assert len(entity.concepts) == 2
|
||||
assert "machine learning" in entity.concepts
|
||||
assert "neural networks" in entity.concepts
|
||||
|
||||
def test_exploration(self):
|
||||
triples = [
|
||||
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
|
||||
|
|
@ -89,6 +103,17 @@ class TestExplainEntityFromTriples:
|
|||
assert isinstance(entity, Exploration)
|
||||
assert entity.chunk_count == 5
|
||||
|
||||
def test_exploration_with_entities(self):
|
||||
triples = [
|
||||
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
|
||||
("urn:exp:3", TG_EDGE_COUNT, "10"),
|
||||
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
|
||||
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:exp:3", triples)
|
||||
assert isinstance(entity, Exploration)
|
||||
assert len(entity.entities) == 2
|
||||
|
||||
def test_exploration_invalid_count(self):
|
||||
triples = [
|
||||
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
|
||||
|
|
@ -110,69 +135,76 @@ class TestExplainEntityFromTriples:
|
|||
assert "urn:edge:1" in entity.selected_edge_uris
|
||||
assert "urn:edge:2" in entity.selected_edge_uris
|
||||
|
||||
def test_synthesis_with_content(self):
|
||||
def test_synthesis_with_document(self):
|
||||
triples = [
|
||||
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:syn:1", TG_CONTENT, "The answer is 42"),
|
||||
("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:syn:1", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.content == "The answer is 42"
|
||||
assert entity.document_uri == ""
|
||||
assert entity.document_uri == "urn:doc:answer-1"
|
||||
|
||||
def test_synthesis_with_document(self):
|
||||
def test_synthesis_no_document(self):
|
||||
triples = [
|
||||
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:syn:2", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.document_uri == "urn:doc:answer-1"
|
||||
assert entity.document_uri == ""
|
||||
|
||||
def test_reflection_thought(self):
|
||||
triples = [
|
||||
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
|
||||
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
|
||||
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ref:1", triples)
|
||||
assert isinstance(entity, Reflection)
|
||||
assert entity.reflection_type == "thought"
|
||||
assert entity.document_uri == "urn:doc:thought-1"
|
||||
|
||||
def test_reflection_observation(self):
|
||||
triples = [
|
||||
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
|
||||
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ref:2", triples)
|
||||
assert isinstance(entity, Reflection)
|
||||
assert entity.reflection_type == "observation"
|
||||
assert entity.document_uri == "urn:doc:obs-1"
|
||||
|
||||
def test_analysis(self):
|
||||
triples = [
|
||||
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:ana:1", TG_THOUGHT, "I should search"),
|
||||
("urn:ana:1", TG_ACTION, "graph-rag-query"),
|
||||
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
|
||||
("urn:ana:1", TG_OBSERVATION, "Found results"),
|
||||
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
|
||||
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:1", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.thought == "I should search"
|
||||
assert entity.action == "graph-rag-query"
|
||||
assert entity.arguments == '{"query": "test"}'
|
||||
assert entity.observation == "Found results"
|
||||
|
||||
def test_analysis_with_document_refs(self):
|
||||
triples = [
|
||||
("urn:ana:2", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:ana:2", TG_ACTION, "search"),
|
||||
("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"),
|
||||
("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:2", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.thought_document_uri == "urn:doc:thought-1"
|
||||
assert entity.observation_document_uri == "urn:doc:obs-1"
|
||||
|
||||
def test_conclusion_with_answer(self):
|
||||
triples = [
|
||||
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:1", TG_ANSWER, "The final answer"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:1", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.answer == "The final answer"
|
||||
assert entity.thought_uri == "urn:ref:thought-1"
|
||||
assert entity.observation_uri == "urn:ref:obs-1"
|
||||
|
||||
def test_conclusion_with_document(self):
|
||||
triples = [
|
||||
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:1", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.document_uri == "urn:doc:final"
|
||||
|
||||
def test_conclusion_no_document(self):
|
||||
triples = [
|
||||
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:2", TG_DOCUMENT, "urn:doc:final"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:2", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.document_uri == "urn:doc:final"
|
||||
assert entity.document_uri == ""
|
||||
|
||||
def test_unknown_type(self):
|
||||
triples = [
|
||||
|
|
@ -457,25 +489,7 @@ class TestExplainabilityClientResolveLabel:
|
|||
|
||||
class TestExplainabilityClientContentFetching:
|
||||
|
||||
def test_fetch_synthesis_inline_content(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
synthesis = Synthesis(uri="urn:syn:1", content="inline answer")
|
||||
result = client.fetch_synthesis_content(synthesis, api=None)
|
||||
assert result == "inline answer"
|
||||
|
||||
def test_fetch_synthesis_truncated_content(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
long_content = "x" * 20000
|
||||
synthesis = Synthesis(uri="urn:syn:1", content=long_content)
|
||||
result = client.fetch_synthesis_content(synthesis, api=None, max_content=100)
|
||||
assert len(result) < 20000
|
||||
assert result.endswith("... [truncated]")
|
||||
|
||||
def test_fetch_synthesis_from_librarian(self):
|
||||
def test_fetch_document_content_from_librarian(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
mock_library = MagicMock()
|
||||
|
|
@ -483,66 +497,32 @@ class TestExplainabilityClientContentFetching:
|
|||
mock_library.get_document_content.return_value = b"librarian content"
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
synthesis = Synthesis(
|
||||
uri="urn:syn:1",
|
||||
document_uri="urn:document:abc123"
|
||||
result = client.fetch_document_content(
|
||||
"urn:document:abc123", api=mock_api
|
||||
)
|
||||
result = client.fetch_synthesis_content(synthesis, api=mock_api)
|
||||
assert result == "librarian content"
|
||||
|
||||
def test_fetch_synthesis_no_content_or_document(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
synthesis = Synthesis(uri="urn:syn:1")
|
||||
result = client.fetch_synthesis_content(synthesis, api=None)
|
||||
assert result == ""
|
||||
|
||||
def test_fetch_conclusion_inline(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
|
||||
conclusion = Conclusion(uri="urn:conc:1", answer="42")
|
||||
result = client.fetch_conclusion_content(conclusion, api=None)
|
||||
assert result == "42"
|
||||
|
||||
def test_fetch_analysis_content_from_librarian(self):
|
||||
def test_fetch_document_content_truncated(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
mock_library = MagicMock()
|
||||
mock_api.library.return_value = mock_library
|
||||
mock_library.get_document_content.side_effect = [
|
||||
b"thought content",
|
||||
b"observation content",
|
||||
]
|
||||
mock_library.get_document_content.return_value = b"x" * 20000
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
analysis = Analysis(
|
||||
uri="urn:ana:1",
|
||||
action="search",
|
||||
thought_document_uri="urn:doc:thought",
|
||||
observation_document_uri="urn:doc:obs",
|
||||
result = client.fetch_document_content(
|
||||
"urn:doc:1", api=mock_api, max_content=100
|
||||
)
|
||||
client.fetch_analysis_content(analysis, api=mock_api)
|
||||
assert analysis.thought == "thought content"
|
||||
assert analysis.observation == "observation content"
|
||||
assert len(result) < 20000
|
||||
assert result.endswith("... [truncated]")
|
||||
|
||||
def test_fetch_analysis_skips_when_inline_exists(self):
|
||||
def test_fetch_document_content_empty_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
analysis = Analysis(
|
||||
uri="urn:ana:1",
|
||||
action="search",
|
||||
thought="already have thought",
|
||||
observation="already have observation",
|
||||
thought_document_uri="urn:doc:thought",
|
||||
observation_document_uri="urn:doc:obs",
|
||||
)
|
||||
client.fetch_analysis_content(analysis, api=mock_api)
|
||||
# Should not call librarian since inline content exists
|
||||
mock_api.library.assert_not_called()
|
||||
result = client.fetch_document_content("", api=mock_api)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestExplainabilityClientDetectSessionType:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from trustgraph.provenance.triples import (
|
|||
derived_entity_triples,
|
||||
subgraph_provenance_triples,
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
|
|
@ -32,10 +33,12 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_CONTENT, TG_DOCUMENT,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT,
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
GRAPH_SOURCE, GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
|
@ -530,36 +533,77 @@ class TestQuestionTriples:
|
|||
assert len(triples) == 6
|
||||
|
||||
|
||||
class TestExplorationTriples:
|
||||
class TestGroundingTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
|
||||
GND_URI = "urn:trustgraph:prov:grounding:test-session"
|
||||
Q_URI = "urn:trustgraph:question:test-session"
|
||||
|
||||
def test_exploration_types(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
def test_grounding_types(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
|
||||
assert has_type(triples, self.GND_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.GND_URI, TG_GROUNDING)
|
||||
|
||||
def test_exploration_generated_by_question(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
|
||||
def test_grounding_generated_by_question(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.Q_URI
|
||||
|
||||
def test_grounding_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
|
||||
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
|
||||
assert len(concepts) == 3
|
||||
values = {t.o.value for t in concepts}
|
||||
assert values == {"AI", "ML", "robots"}
|
||||
|
||||
def test_grounding_empty_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
|
||||
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
|
||||
assert len(concepts) == 0
|
||||
|
||||
def test_grounding_label(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
|
||||
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Grounding"
|
||||
|
||||
|
||||
class TestExplorationTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
|
||||
GND_URI = "urn:trustgraph:prov:grounding:test-session"
|
||||
|
||||
def test_exploration_types(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
|
||||
def test_exploration_derived_from_grounding(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.GND_URI
|
||||
|
||||
def test_exploration_edge_count(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 15)
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
|
||||
assert ec is not None
|
||||
assert ec.o.value == "15"
|
||||
|
||||
def test_exploration_zero_edges(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 0)
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
|
||||
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
|
||||
assert ec is not None
|
||||
assert ec.o.value == "0"
|
||||
|
||||
def test_exploration_with_entities(self):
|
||||
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
|
||||
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
|
||||
assert len(ent_triples) == 2
|
||||
|
||||
def test_exploration_triple_count(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.Q_URI, 10)
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
|
||||
assert len(triples) == 5
|
||||
|
||||
|
||||
|
|
@ -652,6 +696,7 @@ class TestSynthesisTriples:
|
|||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
|
||||
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_synthesis_derived_from_focus(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
|
|
@ -659,12 +704,6 @@ class TestSynthesisTriples:
|
|||
assert derived is not None
|
||||
assert derived.o.iri == self.FOC_URI
|
||||
|
||||
def test_synthesis_with_inline_content(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42")
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content is not None
|
||||
assert content.o.value == "The answer is 42"
|
||||
|
||||
def test_synthesis_with_document_reference(self):
|
||||
triples = synthesis_triples(
|
||||
self.SYN_URI, self.FOC_URI,
|
||||
|
|
@ -675,23 +714,9 @@ class TestSynthesisTriples:
|
|||
assert doc.o.type == IRI
|
||||
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
|
||||
|
||||
def test_synthesis_document_takes_precedence(self):
|
||||
"""When both document_id and answer_text are provided, document_id wins."""
|
||||
triples = synthesis_triples(
|
||||
self.SYN_URI, self.FOC_URI,
|
||||
answer_text="inline",
|
||||
document_id="urn:doc:123",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is not None
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content is None
|
||||
|
||||
def test_synthesis_no_content_or_document(self):
|
||||
def test_synthesis_no_document(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert content is None
|
||||
assert doc is None
|
||||
|
||||
|
||||
|
|
@ -723,31 +748,31 @@ class TestDocRagQuestionTriples:
|
|||
class TestDocRagExplorationTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:docrag:test/exploration"
|
||||
Q_URI = "urn:trustgraph:docrag:test"
|
||||
GND_URI = "urn:trustgraph:docrag:test/grounding"
|
||||
|
||||
def test_docrag_exploration_types(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
|
||||
def test_docrag_exploration_generated_by(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI)
|
||||
assert gen.o.iri == self.Q_URI
|
||||
def test_docrag_exploration_derived_from_grounding(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
|
||||
assert derived.o.iri == self.GND_URI
|
||||
|
||||
def test_docrag_exploration_chunk_count(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
|
||||
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
|
||||
assert cc.o.value == "7"
|
||||
|
||||
def test_docrag_exploration_without_chunk_ids(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
|
||||
chunks = find_triples(triples, TG_SELECTED_CHUNK)
|
||||
assert len(chunks) == 0
|
||||
|
||||
def test_docrag_exploration_with_chunk_ids(self):
|
||||
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids)
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
|
||||
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
|
||||
assert len(chunks) == 3
|
||||
chunk_uris = {t.o.iri for t in chunks}
|
||||
|
|
@ -770,10 +795,9 @@ class TestDocRagSynthesisTriples:
|
|||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
|
||||
assert derived.o.iri == self.EXP_URI
|
||||
|
||||
def test_docrag_synthesis_with_inline(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer")
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content.o.value == "answer"
|
||||
def test_docrag_synthesis_has_answer_type(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_docrag_synthesis_with_document(self):
|
||||
triples = docrag_synthesis_triples(
|
||||
|
|
@ -781,5 +805,8 @@ class TestDocRagSynthesisTriples:
|
|||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc.o.iri == "urn:doc:ans"
|
||||
content = find_triple(triples, TG_CONTENT, self.SYN_URI)
|
||||
assert content is None
|
||||
|
||||
def test_docrag_synthesis_no_document(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is None
|
||||
|
|
|
|||
|
|
@ -125,19 +125,15 @@ class TestQuery:
|
|||
assert query.doc_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
async def test_extract_concepts(self):
|
||||
"""Test Query.extract_concepts extracts concepts from query"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock the embed method to return test vectors in batch format
|
||||
# New format: [[[vectors_for_text1]]] - returns first text's vector set
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
# Mock the prompt response with concept lines
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -145,20 +141,62 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What documents are relevant?"
|
||||
result = await query.get_vector(test_query)
|
||||
result = await query.extract_concepts("What is machine learning?")
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": "What is machine learning?"}
|
||||
)
|
||||
assert result == ["machine learning", "artificial intelligence", "data patterns"]
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts_fallback_to_raw_query(self):
|
||||
"""Test Query.extract_concepts falls back to raw query when no concepts extracted"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock empty response
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = await query.extract_concepts("What is ML?")
|
||||
|
||||
assert result == ["What is ML?"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vectors_method(self):
|
||||
"""Test Query.get_vectors method calls embeddings client correctly"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method - returns vectors for each concept
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
concepts = ["machine learning", "data patterns"]
|
||||
result = await query.get_vectors(concepts)
|
||||
|
||||
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_method(self):
|
||||
"""Test Query.get_docs method retrieves documents correctly"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
|
@ -170,10 +208,8 @@ class TestQuery:
|
|||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Mock the embedding and document query responses
|
||||
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
# Mock embeddings - one vector per concept
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
# Mock document embeddings returns ChunkMatch objects
|
||||
mock_match1 = MagicMock()
|
||||
|
|
@ -184,7 +220,6 @@ class TestQuery:
|
|||
mock_match2.score = 0.85
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -193,16 +228,16 @@ class TestQuery:
|
|||
doc_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
test_query = "Find relevant documents"
|
||||
result = await query.get_docs(test_query)
|
||||
# Call get_docs with concepts list
|
||||
concepts = ["test concept"]
|
||||
result = await query.get_docs(concepts)
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
# Verify embeddings client was called with concepts
|
||||
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
||||
|
||||
# Verify doc embeddings client was called correctly (with extracted vector)
|
||||
# Verify doc embeddings client was called
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=test_vectors,
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=15,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
|
|
@ -218,14 +253,17 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_method(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock embeddings and document embeddings responses
|
||||
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "test concept"
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
mock_match1 = MagicMock()
|
||||
mock_match1.chunk_id = "doc/c3"
|
||||
mock_match1.score = 0.9
|
||||
|
|
@ -234,11 +272,9 @@ class TestQuery:
|
|||
mock_match2.score = 0.8
|
||||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -247,7 +283,6 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
|
|
@ -255,12 +290,18 @@ class TestQuery:
|
|||
doc_limit=10
|
||||
)
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["test query"])
|
||||
# Verify concept extraction was called
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": "test query"}
|
||||
)
|
||||
|
||||
# Verify doc embeddings client was called (with extracted vector)
|
||||
# Verify embeddings called with extracted concepts
|
||||
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
|
||||
|
||||
# Verify doc embeddings client was called
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=test_vectors,
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
|
|
@ -270,23 +311,23 @@ class TestQuery:
|
|||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.document_prompt.call_args
|
||||
assert call_args.kwargs["query"] == "test query"
|
||||
# Documents should be fetched content, not chunk_ids
|
||||
docs = call_args.kwargs["documents"]
|
||||
assert "Relevant document content" in docs
|
||||
assert "Another document" in docs
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method with default parameters"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses (batch format)
|
||||
# Mock concept extraction fallback (empty → raw query)
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = "doc/c5"
|
||||
|
|
@ -294,7 +335,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -302,10 +342,9 @@ class TestQuery:
|
|||
fetch_chunk=mock_fetch_chunk
|
||||
)
|
||||
|
||||
# Call DocumentRag.query with minimal parameters
|
||||
result = await document_rag.query("simple query")
|
||||
|
||||
# Verify default parameters were used (vector extracted from batch)
|
||||
# Verify default parameters were used
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=[[0.1, 0.2]],
|
||||
limit=20, # Default doc_limit
|
||||
|
|
@ -318,7 +357,6 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
"""Test Query.get_docs method with verbose logging"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
|
@ -330,14 +368,13 @@ class TestQuery:
|
|||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Mock responses (batch format)
|
||||
# Mock responses - one vector per concept
|
||||
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = "doc/c6"
|
||||
mock_match.score = 0.88
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -346,14 +383,12 @@ class TestQuery:
|
|||
doc_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("verbose test")
|
||||
# Call get_docs with concepts
|
||||
result = await query.get_docs(["verbose test"])
|
||||
|
||||
# Verify calls were made (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify result is tuple of (docs, chunk_ids) with fetched content
|
||||
docs, chunk_ids = result
|
||||
assert "Verbose test doc" in docs
|
||||
assert "doc/c6" in chunk_ids
|
||||
|
|
@ -361,12 +396,14 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method with verbose logging enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses (batch format)
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "verbose query test"
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = "doc/c7"
|
||||
|
|
@ -374,7 +411,6 @@ class TestQuery:
|
|||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
|
||||
# Initialize DocumentRag with verbose=True
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -383,14 +419,11 @@ class TestQuery:
|
|||
verbose=True
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("verbose query test")
|
||||
|
||||
# Verify all clients were called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify prompt client was called with fetched content
|
||||
call_args = mock_prompt_client.document_prompt.call_args
|
||||
assert call_args.kwargs["query"] == "verbose query test"
|
||||
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||
|
|
@ -400,23 +433,20 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
"""Test Query.get_docs method when no documents are found"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock fetch_chunk (won't be called if no chunk_ids)
|
||||
async def mock_fetch(chunk_id, user):
|
||||
return f"Content for {chunk_id}"
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Mock responses - empty chunk_id list (batch format)
|
||||
# Mock responses - empty results
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -424,30 +454,27 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("query with no results")
|
||||
result = await query.get_docs(["query with no results"])
|
||||
|
||||
# Verify calls were made (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify empty result is returned (tuple of empty lists)
|
||||
assert result == ([], [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
|
||||
"""Test DocumentRag.query method when no documents are retrieved"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses - no chunk_ids found (batch format)
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "query with no matching docs"
|
||||
|
||||
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -456,10 +483,8 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("query with no matching docs")
|
||||
|
||||
# Verify prompt client was called with empty document list
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="query with no matching docs",
|
||||
documents=[]
|
||||
|
|
@ -468,18 +493,15 @@ class TestQuery:
|
|||
assert result == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose logging"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
async def test_get_vectors_with_verbose(self):
|
||||
"""Test Query.get_vectors method with verbose logging"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method (batch format)
|
||||
expected_vectors = [[0.9, 1.0, 1.1]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -487,40 +509,40 @@ class TestQuery:
|
|||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
result = await query.get_vector("verbose vector test")
|
||||
result = await query.get_vectors(["verbose vector test"])
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
|
||||
|
||||
# Verify result (extracted from batch)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
|
||||
"""Test complete DocumentRag integration with realistic data flow"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock realistic responses (batch format)
|
||||
query_text = "What is machine learning?"
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
|
||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
mock_embeddings_client.embed.return_value = [query_vectors]
|
||||
mock_matches = []
|
||||
for chunk_id in retrieved_chunk_ids:
|
||||
mock_match = MagicMock()
|
||||
mock_match.chunk_id = chunk_id
|
||||
mock_match.score = 0.9
|
||||
mock_matches.append(mock_match)
|
||||
mock_doc_embeddings_client.query.return_value = mock_matches
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
|
||||
mock_embeddings_client.embed.return_value = query_vectors
|
||||
|
||||
# Each concept query returns some matches
|
||||
mock_matches_1 = [
|
||||
MagicMock(chunk_id="doc/ml1", score=0.9),
|
||||
MagicMock(chunk_id="doc/ml2", score=0.85),
|
||||
]
|
||||
mock_matches_2 = [
|
||||
MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate
|
||||
MagicMock(chunk_id="doc/ml3", score=0.82),
|
||||
]
|
||||
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -529,7 +551,6 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Execute full pipeline
|
||||
result = await document_rag.query(
|
||||
query=query_text,
|
||||
user="research_user",
|
||||
|
|
@ -537,26 +558,69 @@ class TestQuery:
|
|||
doc_limit=25
|
||||
)
|
||||
|
||||
# Verify complete pipeline execution (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([query_text])
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
vector=query_vectors,
|
||||
limit=25,
|
||||
user="research_user",
|
||||
collection="ml_knowledge"
|
||||
# Verify concept extraction
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": query_text}
|
||||
)
|
||||
|
||||
# Verify embeddings called with concepts
|
||||
mock_embeddings_client.embed.assert_called_once_with(
|
||||
["machine learning", "artificial intelligence"]
|
||||
)
|
||||
|
||||
# Verify two per-concept queries were made (25 // 2 = 12 per concept)
|
||||
assert mock_doc_embeddings_client.query.call_count == 2
|
||||
|
||||
# Verify prompt client was called with fetched document content
|
||||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.document_prompt.call_args
|
||||
assert call_args.kwargs["query"] == query_text
|
||||
|
||||
# Verify documents were fetched from chunk_ids
|
||||
# Verify documents were fetched and deduplicated
|
||||
docs = call_args.kwargs["documents"]
|
||||
assert "Machine learning is a subset of artificial intelligence..." in docs
|
||||
assert "ML algorithms learn patterns from data to make predictions..." in docs
|
||||
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||
assert len(docs) == 3 # doc/ml2 deduplicated
|
||||
|
||||
# Verify final result
|
||||
assert result == final_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_deduplicates_across_concepts(self):
|
||||
"""Test that get_docs deduplicates chunks across multiple concepts"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
async def mock_fetch(chunk_id, user):
|
||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
mock_rag.fetch_chunk = mock_fetch
|
||||
|
||||
# Two concepts → two vectors
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Both queries return overlapping chunks
|
||||
match_a = MagicMock(chunk_id="doc/c1", score=0.9)
|
||||
match_b = MagicMock(chunk_id="doc/c2", score=0.8)
|
||||
match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate
|
||||
mock_doc_embeddings_client.query.side_effect = [
|
||||
[match_a, match_b],
|
||||
[match_c],
|
||||
]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
)
|
||||
|
||||
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
||||
|
||||
assert len(chunk_ids) == 2 # doc/c1 only counted once
|
||||
assert "doc/c1" in chunk_ids
|
||||
assert "doc/c2" in chunk_ids
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class TestGraphRag:
|
|||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -27,7 +27,7 @@ class TestGraphRag:
|
|||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
|
|
@ -45,7 +45,7 @@ class TestGraphRag:
|
|||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -54,7 +54,7 @@ class TestGraphRag:
|
|||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
|
|
@ -73,7 +73,7 @@ class TestQuery:
|
|||
"""Test Query initialization with default parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -81,7 +81,7 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
|
|
@ -96,7 +96,7 @@ class TestQuery:
|
|||
"""Test Query initialization with custom parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
# Initialize Query with custom parameters
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -108,7 +108,7 @@ class TestQuery:
|
|||
max_subgraph_size=2000,
|
||||
max_path_length=3
|
||||
)
|
||||
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
|
|
@ -120,18 +120,16 @@ class TestQuery:
|
|||
assert query.max_path_length == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
async def test_get_vectors_method(self):
|
||||
"""Test Query.get_vectors method calls embeddings client correctly"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors (batch format)
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
|
||||
# Initialize Query
|
||||
# Mock embed to return vectors for a list of concepts
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -139,29 +137,22 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What is the capital of France?"
|
||||
result = await query.get_vector(test_query)
|
||||
concepts = ["machine learning", "neural networks"]
|
||||
result = await query.get_vectors(concepts)
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose output"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
async def test_get_vectors_method_with_verbose(self):
|
||||
"""Test Query.get_vectors method with verbose output"""
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method (batch format)
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = [expected_vectors]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -169,48 +160,87 @@ class TestQuery:
|
|||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "Test query for embeddings"
|
||||
result = await query.get_vector(test_query)
|
||||
result = await query.get_vectors(["test concept"])
|
||||
|
||||
# Verify embeddings client was called correctly (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
|
||||
# Verify result matches expected vectors (extracted from batch)
|
||||
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities_method(self):
|
||||
"""Test Query.get_entities method retrieves entities correctly"""
|
||||
# Create mock GraphRag with clients
|
||||
async def test_extract_concepts(self):
|
||||
"""Test Query.extract_concepts parses LLM response into concept list"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = await query.extract_concepts("What is machine learning?")
|
||||
|
||||
mock_prompt_client.prompt.assert_called_once_with(
|
||||
"extract-concepts",
|
||||
variables={"query": "What is machine learning?"}
|
||||
)
|
||||
assert result == ["machine learning", "neural networks"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts_fallback_to_raw_query(self):
|
||||
"""Test extract_concepts falls back to raw query when LLM returns empty"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
result = await query.extract_concepts("test query")
|
||||
assert result == ["test query"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities_method(self):
|
||||
"""Test Query.get_entities extracts concepts, embeds, and retrieves entities"""
|
||||
mock_rag = MagicMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# Mock the embedding and entity query responses (batch format)
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||
|
||||
# Mock EntityMatch objects with entity as Term-like object
|
||||
# extract_concepts returns empty -> falls back to [query]
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
|
||||
# embed returns one vector set for the single concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
# Mock entity matches
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.type = "i" # IRI type
|
||||
mock_entity1.type = "i"
|
||||
mock_entity1.iri = "entity1"
|
||||
mock_match1 = MagicMock()
|
||||
mock_match1.entity = mock_entity1
|
||||
mock_match1.score = 0.95
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.type = "i" # IRI type
|
||||
mock_entity2.type = "i"
|
||||
mock_entity2.iri = "entity2"
|
||||
mock_match2 = MagicMock()
|
||||
mock_match2.entity = mock_entity2
|
||||
mock_match2.score = 0.85
|
||||
|
||||
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -219,35 +249,23 @@ class TestQuery:
|
|||
entity_limit=25
|
||||
)
|
||||
|
||||
# Call get_entities
|
||||
test_query = "Find related entities"
|
||||
result = await query.get_entities(test_query)
|
||||
entities, concepts = await query.get_entities("Find related entities")
|
||||
|
||||
# Verify embeddings client was called (now expects list)
|
||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||
# Verify embeddings client was called with the fallback concept
|
||||
mock_embeddings_client.embed.assert_called_once_with(["Find related entities"])
|
||||
|
||||
# Verify graph embeddings client was called correctly (with extracted vector)
|
||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||
vector=test_vectors,
|
||||
limit=25,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is list of entity strings
|
||||
assert result == ["entity1", "entity2"]
|
||||
# Verify result
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["Find related entities"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_cached_label(self):
|
||||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
# Create mock LRUCacheWithTTL
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = "Entity One Label"
|
||||
mock_rag.label_cache = mock_cache
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -255,32 +273,25 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
# Verify cache was checked with proper key format (user:collection:entity)
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -288,20 +299,18 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
# Verify result and cache update with proper key
|
||||
assert result == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
|
@ -309,40 +318,34 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock empty result (no label found)
|
||||
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
# Initialize Query
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
|
||||
result = await query.maybe_label("unlabeled_entity")
|
||||
|
||||
# Verify triples client was called
|
||||
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="unlabeled_entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
g=""
|
||||
)
|
||||
|
||||
# Verify result is entity itself and cache is updated
|
||||
|
||||
assert result == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
|
@ -350,29 +353,25 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple results for different query patterns
|
||||
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
|
||||
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent, p=None, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple3], # s=None, p=None, o=ent
|
||||
[mock_triple1], # s=ent
|
||||
[mock_triple2], # p=ent
|
||||
[mock_triple3], # o=ent
|
||||
]
|
||||
|
||||
# Initialize Query
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -380,29 +379,25 @@ class TestQuery:
|
|||
verbose=False,
|
||||
triple_limit=10
|
||||
)
|
||||
|
||||
# Call follow_edges
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify all three query patterns were called
|
||||
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
# Verify query_stream calls
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection", batch_size=20
|
||||
user="test_user", collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
|
||||
# Verify subgraph contains discovered triples
|
||||
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
|
|
@ -413,38 +408,30 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call follow_edges with path_length=0
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
# Verify no queries were made
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph remains empty
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query with small max_subgraph_size
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
|
|
@ -452,23 +439,17 @@ class TestQuery:
|
|||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
)
|
||||
|
||||
# Pre-populate subgraph to exceed limit
|
||||
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
# Call follow_edges
|
||||
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify no queries were made due to size limit
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
|
||||
# Verify subgraph unchanged
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges_batch
|
||||
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
|
|
@ -479,130 +460,119 @@ class TestQuery:
|
|||
max_path_length=1
|
||||
)
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
# Mock get_entities to return (entities, concepts) tuple
|
||||
query.get_entities = AsyncMock(
|
||||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
# Mock follow_edges_batch to return test triples
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
subgraph, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges_batch was called with entities and max_path_length
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
# Verify result is list format and contains expected triples
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert ("entity1", "predicate1", "object1") in result
|
||||
assert ("entity2", "predicate2", "object2") in result
|
||||
assert isinstance(subgraph, list)
|
||||
assert len(subgraph) == 2
|
||||
assert ("entity1", "predicate1", "object1") in subgraph
|
||||
assert ("entity2", "predicate2", "object2") in subgraph
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["concept1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph method converts entities to labels"""
|
||||
# Create mock Query
|
||||
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
# Mock get_subgraph to return test triples
|
||||
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
query.get_subgraph = AsyncMock(return_value=test_subgraph)
|
||||
|
||||
# Mock maybe_label to return human-readable labels
|
||||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
# Call get_labelgraph
|
||||
labeled_edges, uri_map = await query.get_labelgraph("test query")
|
||||
|
||||
# Verify get_subgraph was called
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
|
||||
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Verify label triples are filtered out
|
||||
assert len(labeled_edges) == 2 # Label triple should be excluded
|
||||
# Label triples filtered out
|
||||
assert len(labeled_edges) == 2
|
||||
|
||||
# Verify maybe_label was called for non-label triples
|
||||
expected_calls = [
|
||||
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
|
||||
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
|
||||
]
|
||||
# maybe_label called for non-label triples
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
# Verify result contains human-readable labels
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
# Verify uri_map maps labeled edges back to original URIs
|
||||
assert len(uri_map) == 2
|
||||
assert entities == test_entities
|
||||
assert concepts == test_concepts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance"""
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
|
||||
# Mock prompt client responses for two-step process
|
||||
expected_response = "This is the RAG response"
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
# Compute the edge ID for the test edge
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
|
||||
# Create uri_map for the test edge (maps labeled edge ID to original URIs)
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
test_entities = ["http://example.org/subject"]
|
||||
test_concepts = ["test concept"]
|
||||
|
||||
# Mock edge selection response (JSONL format)
|
||||
edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
|
||||
# Configure prompt mock to return different responses based on prompt name
|
||||
# Mock prompt responses for the multi-step process
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return edge_selection_response
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return json.dumps({"id": test_edge_id, "score": 0.9})
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -611,27 +581,20 @@ class TestQuery:
|
|||
verbose=False
|
||||
)
|
||||
|
||||
# We need to patch the Query class's get_labelgraph method
|
||||
original_query_init = Query.__init__
|
||||
# Patch Query.get_labelgraph to return test data
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
|
||||
def mock_query_init(self, *args, **kwargs):
|
||||
original_query_init(self, *args, **kwargs)
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph, test_uri_map
|
||||
return test_labelgraph, test_uri_map, test_entities, test_concepts
|
||||
|
||||
Query.__init__ = mock_query_init
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
|
||||
# Collect provenance emitted via callback
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
try:
|
||||
# Call GraphRag.query with provenance callback
|
||||
response = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
|
|
@ -641,25 +604,22 @@ class TestQuery:
|
|||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Verify response text
|
||||
assert response == expected_response
|
||||
|
||||
# Verify provenance was emitted incrementally (4 events: question, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 4
|
||||
# 5 events: question, grounding, exploration, focus, synthesis
|
||||
assert len(provenance_events) == 5
|
||||
|
||||
# Verify each event has triples and a URN
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order: question, exploration, focus, synthesis
|
||||
# Verify order
|
||||
assert "question" in provenance_events[0][1]
|
||||
assert "exploration" in provenance_events[1][1]
|
||||
assert "focus" in provenance_events[2][1]
|
||||
assert "synthesis" in provenance_events[3][1]
|
||||
assert "grounding" in provenance_events[1][1]
|
||||
assert "exploration" in provenance_events[2][1]
|
||||
assert "focus" in provenance_events[3][1]
|
||||
assert "synthesis" in provenance_events[4][1]
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
Query.__init__ = original_query_init
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
|
|
|
|||
|
|
@ -75,9 +75,11 @@ from .explainability import (
|
|||
ExplainabilityClient,
|
||||
ExplainEntity,
|
||||
Question,
|
||||
Grounding,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
EdgeSelection,
|
||||
|
|
|
|||
|
|
@ -18,25 +18,28 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_CONTENT = TG + "content"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
TG_THOUGHT = TG + "thought"
|
||||
TG_ACTION = TG + "action"
|
||||
TG_ARGUMENTS = TG + "arguments"
|
||||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
|
||||
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
|
||||
|
||||
# Entity types
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_GROUNDING = TG + "Grounding"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
TG_FOCUS = TG + "Focus"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
TG_ANSWER_TYPE = TG + "Answer"
|
||||
TG_REFLECTION_TYPE = TG + "Reflection"
|
||||
TG_THOUGHT_TYPE = TG + "Thought"
|
||||
TG_OBSERVATION_TYPE = TG + "Observation"
|
||||
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
|
||||
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
|
||||
TG_AGENT_QUESTION = TG + "AgentQuestion"
|
||||
|
|
@ -73,12 +76,16 @@ class ExplainEntity:
|
|||
|
||||
if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types:
|
||||
return Question.from_triples(uri, triples, types)
|
||||
elif TG_GROUNDING in types:
|
||||
return Grounding.from_triples(uri, triples)
|
||||
elif TG_EXPLORATION in types:
|
||||
return Exploration.from_triples(uri, triples)
|
||||
elif TG_FOCUS in types:
|
||||
return Focus.from_triples(uri, triples)
|
||||
elif TG_SYNTHESIS in types:
|
||||
return Synthesis.from_triples(uri, triples)
|
||||
elif TG_REFLECTION_TYPE in types:
|
||||
return Reflection.from_triples(uri, triples)
|
||||
elif TG_ANALYSIS in types:
|
||||
return Analysis.from_triples(uri, triples)
|
||||
elif TG_CONCLUSION in types:
|
||||
|
|
@ -124,16 +131,38 @@ class Question(ExplainEntity):
|
|||
)
|
||||
|
||||
|
||||
@dataclass
class Grounding(ExplainEntity):
    """Grounding entity: the concepts a query was decomposed into."""

    # Object values of every tg:concept triple, in the order encountered.
    concepts: List[str] = field(default_factory=list)

    @classmethod
    def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Grounding":
        """
        Build a Grounding from a list of (subject, predicate, object) triples.

        Every triple whose predicate equals TG_CONCEPT contributes its
        object to ``concepts``; all other triples are ignored.  The
        subject is not inspected.

        Args:
            uri: URI recorded on the resulting entity
            triples: (s, p, o) tuples to scan

        Returns:
            A Grounding with entity_type "grounding"
        """
        found = [obj for _subj, pred, obj in triples if pred == TG_CONCEPT]

        return cls(uri=uri, entity_type="grounding", concepts=found)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Exploration(ExplainEntity):
|
||||
"""Exploration entity - edges/chunks retrieved from the knowledge store."""
|
||||
edge_count: int = 0
|
||||
chunk_count: int = 0
|
||||
entities: List[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration":
|
||||
edge_count = 0
|
||||
chunk_count = 0
|
||||
entities = []
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_EDGE_COUNT:
|
||||
|
|
@ -146,12 +175,15 @@ class Exploration(ExplainEntity):
|
|||
chunk_count = int(o)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
elif p == TG_ENTITY:
|
||||
entities.append(o)
|
||||
|
||||
return cls(
|
||||
uri=uri,
|
||||
entity_type="exploration",
|
||||
edge_count=edge_count,
|
||||
chunk_count=chunk_count
|
||||
chunk_count=chunk_count,
|
||||
entities=entities
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -180,94 +212,104 @@ class Focus(ExplainEntity):
|
|||
@dataclass
|
||||
class Synthesis(ExplainEntity):
|
||||
"""Synthesis entity - the final answer."""
|
||||
content: str = ""
|
||||
document_uri: str = "" # Reference to librarian document
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis":
|
||||
content = ""
|
||||
document_uri = ""
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_CONTENT:
|
||||
content = o
|
||||
elif p == TG_DOCUMENT:
|
||||
if p == TG_DOCUMENT:
|
||||
document_uri = o
|
||||
|
||||
return cls(
|
||||
uri=uri,
|
||||
entity_type="synthesis",
|
||||
content=content,
|
||||
document_uri=document_uri
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class Reflection(ExplainEntity):
    """Reflection entity: intermediate commentary (a Thought or an Observation)."""

    document_uri: str = ""  # Reference to content in librarian
    reflection_type: str = ""  # "thought" or "observation"

    @classmethod
    def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Reflection":
        """
        Build a Reflection from a list of (subject, predicate, object) triples.

        The rdf:type objects decide ``reflection_type``: "thought" when
        TG_THOUGHT_TYPE is present, otherwise "observation" when
        TG_OBSERVATION_TYPE is present, otherwise the empty string.
        ``document_uri`` is the object of the last TG_DOCUMENT triple,
        or "" when there is none.

        Args:
            uri: URI recorded on the resulting entity
            triples: (s, p, o) tuples to scan

        Returns:
            A Reflection with entity_type "reflection"
        """
        rdf_types = []
        doc_ref = ""

        # One pass collects both the type objects and the document reference.
        for _subj, pred, obj in triples:
            if pred == RDF_TYPE:
                rdf_types.append(obj)
            if pred == TG_DOCUMENT:
                doc_ref = obj

        # Thought takes precedence when both types are declared.
        if TG_THOUGHT_TYPE in rdf_types:
            kind = "thought"
        elif TG_OBSERVATION_TYPE in rdf_types:
            kind = "observation"
        else:
            kind = ""

        return cls(
            uri=uri,
            entity_type="reflection",
            document_uri=doc_ref,
            reflection_type=kind,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class Analysis(ExplainEntity):
|
||||
"""Analysis entity - one think/act/observe cycle (Agent only)."""
|
||||
thought: str = ""
|
||||
action: str = ""
|
||||
arguments: str = "" # JSON string
|
||||
observation: str = ""
|
||||
thought_document_uri: str = "" # Reference to thought in librarian
|
||||
observation_document_uri: str = "" # Reference to observation in librarian
|
||||
thought_uri: str = "" # URI of thought sub-entity
|
||||
observation_uri: str = "" # URI of observation sub-entity
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis":
|
||||
thought = ""
|
||||
action = ""
|
||||
arguments = ""
|
||||
observation = ""
|
||||
thought_document_uri = ""
|
||||
observation_document_uri = ""
|
||||
thought_uri = ""
|
||||
observation_uri = ""
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_THOUGHT:
|
||||
thought = o
|
||||
elif p == TG_ACTION:
|
||||
if p == TG_ACTION:
|
||||
action = o
|
||||
elif p == TG_ARGUMENTS:
|
||||
arguments = o
|
||||
elif p == TG_THOUGHT:
|
||||
thought_uri = o
|
||||
elif p == TG_OBSERVATION:
|
||||
observation = o
|
||||
elif p == TG_THOUGHT_DOCUMENT:
|
||||
thought_document_uri = o
|
||||
elif p == TG_OBSERVATION_DOCUMENT:
|
||||
observation_document_uri = o
|
||||
observation_uri = o
|
||||
|
||||
return cls(
|
||||
uri=uri,
|
||||
entity_type="analysis",
|
||||
thought=thought,
|
||||
action=action,
|
||||
arguments=arguments,
|
||||
observation=observation,
|
||||
thought_document_uri=thought_document_uri,
|
||||
observation_document_uri=observation_document_uri
|
||||
thought_uri=thought_uri,
|
||||
observation_uri=observation_uri
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conclusion(ExplainEntity):
|
||||
"""Conclusion entity - final answer (Agent only)."""
|
||||
answer: str = ""
|
||||
document_uri: str = "" # Reference to librarian document
|
||||
|
||||
@classmethod
|
||||
def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion":
|
||||
answer = ""
|
||||
document_uri = ""
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_ANSWER:
|
||||
answer = o
|
||||
elif p == TG_DOCUMENT:
|
||||
if p == TG_DOCUMENT:
|
||||
document_uri = o
|
||||
|
||||
return cls(
|
||||
uri=uri,
|
||||
entity_type="conclusion",
|
||||
answer=answer,
|
||||
document_uri=document_uri
|
||||
)
|
||||
|
||||
|
|
@ -543,42 +585,29 @@ class ExplainabilityClient:
|
|||
o_label = self.resolve_label(edge.get("o", ""), user, collection)
|
||||
return (s_label, p_label, o_label)
|
||||
|
||||
def fetch_synthesis_content(
|
||||
def fetch_document_content(
|
||||
self,
|
||||
synthesis: Synthesis,
|
||||
document_uri: str,
|
||||
api: Any,
|
||||
user: Optional[str] = None,
|
||||
max_content: int = 10000
|
||||
) -> str:
|
||||
"""
|
||||
Fetch the content for a Synthesis entity.
|
||||
|
||||
If synthesis has inline content, returns that.
|
||||
If synthesis has a document_uri, fetches from librarian with retry.
|
||||
Fetch content from the librarian by document URI.
|
||||
|
||||
Args:
|
||||
synthesis: The Synthesis entity
|
||||
document_uri: The document URI in the librarian
|
||||
api: TrustGraph Api instance for librarian access
|
||||
user: User identifier for librarian
|
||||
max_content: Maximum content length to return
|
||||
|
||||
Returns:
|
||||
The synthesis content as a string
|
||||
The document content as a string
|
||||
"""
|
||||
# If inline content exists, use it
|
||||
if synthesis.content:
|
||||
if len(synthesis.content) > max_content:
|
||||
return synthesis.content[:max_content] + "... [truncated]"
|
||||
return synthesis.content
|
||||
|
||||
# Otherwise fetch from librarian
|
||||
if not synthesis.document_uri:
|
||||
if not document_uri:
|
||||
return ""
|
||||
|
||||
# Extract document ID from URI (e.g., "urn:document:abc123" -> "abc123")
|
||||
doc_id = synthesis.document_uri
|
||||
if doc_id.startswith("urn:document:"):
|
||||
doc_id = doc_id[len("urn:document:"):]
|
||||
doc_id = document_uri
|
||||
|
||||
# Retry fetching from librarian for eventual consistency
|
||||
for attempt in range(self.max_retries):
|
||||
|
|
@ -603,129 +632,6 @@ class ExplainabilityClient:
|
|||
|
||||
return ""
|
||||
|
||||
def fetch_conclusion_content(
    self,
    conclusion: Conclusion,
    api: Any,
    user: Optional[str] = None,
    max_content: int = 10000
) -> str:
    """
    Return the answer text for a Conclusion entity (Agent final answer).

    An inline ``conclusion.answer`` is returned directly.  Otherwise the
    document named by ``conclusion.document_uri`` is read through
    ``api.library().get_document_content(...)``, retrying up to
    ``self.max_retries`` times with ``self.retry_delay`` seconds between
    attempts.

    Args:
        conclusion: The Conclusion entity
        api: TrustGraph Api instance for librarian access
        user: User identifier for librarian
        max_content: Maximum content length to return; longer text is
            cut and suffixed with "... [truncated]"

    Returns:
        The answer text; "" when there is neither an inline answer nor
        a document URI (or no attempt was made); a "[Binary: N bytes]"
        marker when the document is not valid UTF-8; or an
        "[Error fetching content: ...]" marker once retries are exhausted.
    """

    def _clip(text):
        # Apply the max_content limit, marking any cut.
        if len(text) > max_content:
            return text[:max_content] + "... [truncated]"
        return text

    # Inline answer wins over any librarian reference.
    if conclusion.answer:
        return _clip(conclusion.answer)

    # The document URI is used as the librarian document id unchanged.
    doc_id = conclusion.document_uri
    if not doc_id:
        return ""

    # Retry fetching from librarian for eventual consistency
    for attempt in range(self.max_retries):
        try:
            raw = api.library().get_document_content(user=user, id=doc_id)

            try:
                return _clip(raw.decode('utf-8'))
            except UnicodeDecodeError:
                return f"[Binary: {len(raw)} bytes]"

        except Exception as e:
            # Last attempt: report the failure instead of sleeping again.
            if attempt + 1 >= self.max_retries:
                return f"[Error fetching content: {e}]"
            time.sleep(self.retry_delay)

    return ""
|
||||
|
||||
def fetch_analysis_content(
    self,
    analysis: Analysis,
    api: Any,
    user: Optional[str] = None,
    max_content: int = 10000
) -> None:
    """
    Fill in missing thought and observation text on an Analysis entity.

    For each of ``thought`` and ``observation``: when the inline value
    is empty and the matching ``*_document_uri`` is set, the document is
    read through ``api.library().get_document_content(...)`` and stored
    on the entity.  Reads are retried up to ``self.max_retries`` times
    with ``self.retry_delay`` seconds between attempts.  The analysis
    object is modified in place.

    Stored values are the decoded text (cut to ``max_content`` and
    suffixed with "... [truncated]" when longer), a "[Binary: N bytes]"
    marker when the document is not valid UTF-8, or an
    "[Error fetching <thought|observation>: ...]" marker once retries
    are exhausted.

    Args:
        analysis: The Analysis entity (modified in place)
        api: TrustGraph Api instance for librarian access
        user: User identifier for librarian
        max_content: Maximum content length to store
    """

    def _load_into(attr, doc_id, label):
        # Shared fetch/decode/retry loop; previously duplicated verbatim
        # for thought and observation.  The attribute is assigned inside
        # the try block, as in the original.
        for attempt in range(self.max_retries):
            try:
                library = api.library()
                content_bytes = library.get_document_content(user=user, id=doc_id)
                try:
                    content = content_bytes.decode('utf-8')
                    if len(content) > max_content:
                        content = content[:max_content] + "... [truncated]"
                    setattr(analysis, attr, content)
                    return
                except UnicodeDecodeError:
                    setattr(analysis, attr, f"[Binary: {len(content_bytes)} bytes]")
                    return
            except Exception as e:
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                    continue
                # Retries exhausted: record the failure on the entity.
                setattr(analysis, attr, f"[Error fetching {label}: {e}]")

    # Fetch thought if needed
    if not analysis.thought and analysis.thought_document_uri:
        _load_into("thought", analysis.thought_document_uri, "thought")

    # Fetch observation if needed
    if not analysis.observation and analysis.observation_document_uri:
        _load_into("observation", analysis.observation_document_uri, "observation")
|
||||
|
||||
def fetch_graphrag_trace(
|
||||
self,
|
||||
|
|
@ -739,7 +645,7 @@ class ExplainabilityClient:
|
|||
"""
|
||||
Fetch the complete GraphRAG trace starting from a question URI.
|
||||
|
||||
Follows the provenance chain: Question -> Exploration -> Focus -> Synthesis
|
||||
Follows the provenance chain: Question -> Grounding -> Exploration -> Focus -> Synthesis
|
||||
|
||||
Args:
|
||||
question_uri: The question entity URI
|
||||
|
|
@ -750,13 +656,14 @@ class ExplainabilityClient:
|
|||
max_content: Maximum content length for synthesis
|
||||
|
||||
Returns:
|
||||
Dict with question, exploration, focus, synthesis entities
|
||||
Dict with question, grounding, exploration, focus, synthesis entities
|
||||
"""
|
||||
if graph is None:
|
||||
graph = "urn:graph:retrieval"
|
||||
|
||||
trace = {
|
||||
"question": None,
|
||||
"grounding": None,
|
||||
"exploration": None,
|
||||
"focus": None,
|
||||
"synthesis": None,
|
||||
|
|
@ -768,8 +675,8 @@ class ExplainabilityClient:
|
|||
return trace
|
||||
trace["question"] = question
|
||||
|
||||
# Find exploration: ?exploration prov:wasGeneratedBy question_uri
|
||||
exploration_triples = self.flow.triples_query(
|
||||
# Find grounding: ?grounding prov:wasGeneratedBy question_uri
|
||||
grounding_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_GENERATED_BY,
|
||||
o=question_uri,
|
||||
g=graph,
|
||||
|
|
@ -778,6 +685,30 @@ class ExplainabilityClient:
|
|||
limit=10
|
||||
)
|
||||
|
||||
if grounding_triples:
|
||||
grounding_uris = [
|
||||
extract_term_value(t.get("s", {}))
|
||||
for t in grounding_triples
|
||||
]
|
||||
for gnd_uri in grounding_uris:
|
||||
grounding = self.fetch_entity(gnd_uri, graph, user, collection)
|
||||
if isinstance(grounding, Grounding):
|
||||
trace["grounding"] = grounding
|
||||
break
|
||||
|
||||
if not trace["grounding"]:
|
||||
return trace
|
||||
|
||||
# Find exploration: ?exploration prov:wasDerivedFrom grounding_uri
|
||||
exploration_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=trace["grounding"].uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
|
||||
if exploration_triples:
|
||||
exploration_uris = [
|
||||
extract_term_value(t.get("s", {}))
|
||||
|
|
@ -834,11 +765,6 @@ class ExplainabilityClient:
|
|||
for synth_uri in synthesis_uris:
|
||||
synthesis = self.fetch_entity(synth_uri, graph, user, collection)
|
||||
if isinstance(synthesis, Synthesis):
|
||||
# Fetch content if needed
|
||||
if api and not synthesis.content and synthesis.document_uri:
|
||||
synthesis.content = self.fetch_synthesis_content(
|
||||
synthesis, api, user, max_content
|
||||
)
|
||||
trace["synthesis"] = synthesis
|
||||
break
|
||||
|
||||
|
|
@ -928,11 +854,6 @@ class ExplainabilityClient:
|
|||
for synth_uri in synthesis_uris:
|
||||
synthesis = self.fetch_entity(synth_uri, graph, user, collection)
|
||||
if isinstance(synthesis, Synthesis):
|
||||
# Fetch content if needed
|
||||
if api and not synthesis.content and synthesis.document_uri:
|
||||
synthesis.content = self.fetch_synthesis_content(
|
||||
synthesis, api, user, max_content
|
||||
)
|
||||
trace["synthesis"] = synthesis
|
||||
break
|
||||
|
||||
|
|
@ -978,20 +899,43 @@ class ExplainabilityClient:
|
|||
return trace
|
||||
trace["question"] = question
|
||||
|
||||
# Follow the chain of wasDerivedFrom
|
||||
# Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after
|
||||
current_uri = session_uri
|
||||
is_first = True
|
||||
max_iterations = 50 # Safety limit
|
||||
|
||||
for _ in range(max_iterations):
|
||||
# Find entity derived from current
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
# First hop uses wasGeneratedBy (entity←activity),
|
||||
# subsequent hops use wasDerivedFrom (entity←entity)
|
||||
if is_first:
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_GENERATED_BY,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
# Fall back to wasDerivedFrom for backwards compatibility
|
||||
if not derived_triples:
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
is_first = False
|
||||
else:
|
||||
derived_triples = self.flow.triples_query(
|
||||
p=PROV_WAS_DERIVED_FROM,
|
||||
o=current_uri,
|
||||
g=graph,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=10
|
||||
)
|
||||
|
||||
if not derived_triples:
|
||||
break
|
||||
|
|
@ -1003,19 +947,9 @@ class ExplainabilityClient:
|
|||
entity = self.fetch_entity(derived_uri, graph, user, collection)
|
||||
|
||||
if isinstance(entity, Analysis):
|
||||
# Fetch thought/observation content from librarian if needed
|
||||
if api:
|
||||
self.fetch_analysis_content(
|
||||
entity, api, user=user, max_content=max_content
|
||||
)
|
||||
trace["iterations"].append(entity)
|
||||
current_uri = derived_uri
|
||||
elif isinstance(entity, Conclusion):
|
||||
# Fetch answer content from librarian if needed
|
||||
if api and not entity.answer and entity.document_uri:
|
||||
entity.answer = self.fetch_conclusion_content(
|
||||
entity, api, user=user, max_content=max_content
|
||||
)
|
||||
trace["conclusion"] = entity
|
||||
break
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL, TRIPLE
|
||||
from .. knowledge import Uri, Literal
|
||||
|
||||
|
||||
|
|
@ -22,9 +22,11 @@ def to_value(x):
|
|||
|
||||
|
||||
def from_value(x):
|
||||
"""Convert Uri, Literal, or string to schema Term."""
|
||||
"""Convert Uri, Literal, string, or Term to schema Term."""
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, Term):
|
||||
return x
|
||||
if isinstance(x, Uri):
|
||||
return Term(type=IRI, iri=str(x))
|
||||
elif isinstance(x, Literal):
|
||||
|
|
@ -41,7 +43,7 @@ def from_value(x):
|
|||
class TriplesClient(RequestResponse):
|
||||
async def query(self, s=None, p=None, o=None, limit=20,
|
||||
user="trustgraph", collection="default",
|
||||
timeout=30):
|
||||
timeout=30, g=None):
|
||||
|
||||
resp = await self.request(
|
||||
TriplesQueryRequest(
|
||||
|
|
@ -51,6 +53,7 @@ class TriplesClient(RequestResponse):
|
|||
limit = limit,
|
||||
user = user,
|
||||
collection = collection,
|
||||
g = g,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
|
@ -68,7 +71,7 @@ class TriplesClient(RequestResponse):
|
|||
async def query_stream(self, s=None, p=None, o=None, limit=20,
|
||||
user="trustgraph", collection="default",
|
||||
batch_size=20, timeout=30,
|
||||
batch_callback=None):
|
||||
batch_callback=None, g=None):
|
||||
"""
|
||||
Streaming triple query - calls callback for each batch as it arrives.
|
||||
|
||||
|
|
@ -80,6 +83,8 @@ class TriplesClient(RequestResponse):
|
|||
batch_size: Triples per batch
|
||||
timeout: Request timeout in seconds
|
||||
batch_callback: Async callback(batch, is_final) called for each batch
|
||||
g: Graph filter. ""=default graph only, None=all graphs,
|
||||
or a specific graph IRI.
|
||||
|
||||
Returns:
|
||||
List[Triple]: All triples (flattened) if no callback provided
|
||||
|
|
@ -112,6 +117,7 @@ class TriplesClient(RequestResponse):
|
|||
collection=collection,
|
||||
streaming=True,
|
||||
batch_size=batch_size,
|
||||
g=g,
|
||||
),
|
||||
timeout=timeout,
|
||||
recipient=recipient,
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
triple_limit=int(data.get("triple-limit", 30)),
|
||||
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
|
||||
max_path_length=int(data.get("max-path-length", 2)),
|
||||
edge_limit=int(data.get("edge-limit", 25)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
|
@ -96,6 +97,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
"triple-limit": obj.triple_limit,
|
||||
"max-subgraph-size": obj.max_subgraph_size,
|
||||
"max-path-length": obj.max_path_length,
|
||||
"edge-limit": obj.edge_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,15 +42,19 @@ from . uris import (
|
|||
agent_uri,
|
||||
# Query-time provenance URIs (GraphRAG)
|
||||
question_uri,
|
||||
grounding_uri,
|
||||
exploration_uri,
|
||||
focus_uri,
|
||||
synthesis_uri,
|
||||
# Agent provenance URIs
|
||||
agent_session_uri,
|
||||
agent_iteration_uri,
|
||||
agent_thought_uri,
|
||||
agent_observation_uri,
|
||||
agent_final_uri,
|
||||
# Document RAG provenance URIs
|
||||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_synthesis_uri,
|
||||
)
|
||||
|
|
@ -74,18 +78,19 @@ from . namespaces import (
|
|||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
# Unifying types
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
# Question subtypes (to distinguish retrieval mechanism)
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
# Agent provenance predicates
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
# Agent document references
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
# Document reference predicate
|
||||
TG_DOCUMENT,
|
||||
# Named graphs
|
||||
|
|
@ -99,6 +104,7 @@ from . triples import (
|
|||
subgraph_provenance_triples,
|
||||
# Query-time provenance triple builders (GraphRAG)
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
|
|
@ -139,15 +145,19 @@ __all__ = [
|
|||
"agent_uri",
|
||||
# Query-time provenance URIs
|
||||
"question_uri",
|
||||
"grounding_uri",
|
||||
"exploration_uri",
|
||||
"focus_uri",
|
||||
"synthesis_uri",
|
||||
# Agent provenance URIs
|
||||
"agent_session_uri",
|
||||
"agent_iteration_uri",
|
||||
"agent_thought_uri",
|
||||
"agent_observation_uri",
|
||||
"agent_final_uri",
|
||||
# Document RAG provenance URIs
|
||||
"docrag_question_uri",
|
||||
"docrag_grounding_uri",
|
||||
"docrag_exploration_uri",
|
||||
"docrag_synthesis_uri",
|
||||
# Namespaces
|
||||
|
|
@ -164,18 +174,19 @@ __all__ = [
|
|||
# Extraction provenance entity types
|
||||
"TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
"TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT",
|
||||
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Explainability entity types
|
||||
"TG_QUESTION", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||
"TG_ANALYSIS", "TG_CONCLUSION",
|
||||
# Unifying types
|
||||
"TG_ANSWER_TYPE", "TG_REFLECTION_TYPE", "TG_THOUGHT_TYPE", "TG_OBSERVATION_TYPE",
|
||||
# Question subtypes
|
||||
"TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION",
|
||||
# Agent provenance predicates
|
||||
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER",
|
||||
# Agent document references
|
||||
"TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT",
|
||||
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
|
||||
# Document reference predicate
|
||||
"TG_DOCUMENT",
|
||||
# Named graphs
|
||||
|
|
@ -186,6 +197,7 @@ __all__ = [
|
|||
"subgraph_provenance_triples",
|
||||
# Query-time provenance triple builders (GraphRAG)
|
||||
"question_triples",
|
||||
"grounding_triples",
|
||||
"exploration_triples",
|
||||
"focus_triples",
|
||||
"synthesis_triples",
|
||||
|
|
|
|||
|
|
@ -15,10 +15,11 @@ from .. schema import Triple, Term, IRI, LITERAL
|
|||
|
||||
from . namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -73,12 +74,13 @@ def agent_session_triples(
|
|||
|
||||
def agent_iteration_triples(
|
||||
iteration_uri: str,
|
||||
parent_uri: str,
|
||||
thought: str = "",
|
||||
question_uri: Optional[str] = None,
|
||||
previous_uri: Optional[str] = None,
|
||||
action: str = "",
|
||||
arguments: Dict[str, Any] = None,
|
||||
observation: str = "",
|
||||
thought_uri: Optional[str] = None,
|
||||
thought_document_id: Optional[str] = None,
|
||||
observation_uri: Optional[str] = None,
|
||||
observation_document_id: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
|
|
@ -86,19 +88,22 @@ def agent_iteration_triples(
|
|||
|
||||
Creates:
|
||||
- Entity declaration with tg:Analysis type
|
||||
- wasDerivedFrom link to parent (previous iteration or session)
|
||||
- Thought, action, arguments, and observation data
|
||||
- Document references for thought/observation when stored in librarian
|
||||
- wasGeneratedBy link to question (if first iteration)
|
||||
- wasDerivedFrom link to previous iteration (if not first)
|
||||
- Action and arguments metadata
|
||||
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
|
||||
- Observation sub-entity (tg:Reflection, tg:Observation) with librarian document
|
||||
|
||||
Args:
|
||||
iteration_uri: URI of this iteration (from agent_iteration_uri)
|
||||
parent_uri: URI of the parent (previous iteration or session)
|
||||
thought: The agent's reasoning/thought (used if thought_document_id not provided)
|
||||
question_uri: URI of the question activity (for first iteration)
|
||||
previous_uri: URI of the previous iteration (for subsequent iterations)
|
||||
action: The tool/action name
|
||||
arguments: Arguments passed to the tool (will be JSON-encoded)
|
||||
observation: The result/observation from the tool (used if observation_document_id not provided)
|
||||
thought_document_id: Optional document URI for thought in librarian (preferred)
|
||||
observation_document_id: Optional document URI for observation in librarian (preferred)
|
||||
thought_uri: URI for the thought sub-entity
|
||||
thought_document_id: Document URI for thought in librarian
|
||||
observation_uri: URI for the observation sub-entity
|
||||
observation_document_id: Document URI for observation in librarian
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -110,45 +115,70 @@ def agent_iteration_triples(
|
|||
_triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)),
|
||||
_triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")),
|
||||
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
|
||||
_triple(iteration_uri, TG_ACTION, _literal(action)),
|
||||
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
|
||||
]
|
||||
|
||||
# Thought: use document reference or inline
|
||||
if thought_document_id:
|
||||
triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id)))
|
||||
elif thought:
|
||||
triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought)))
|
||||
if question_uri:
|
||||
triples.append(
|
||||
_triple(iteration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri))
|
||||
)
|
||||
elif previous_uri:
|
||||
triples.append(
|
||||
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri))
|
||||
)
|
||||
|
||||
# Observation: use document reference or inline
|
||||
if observation_document_id:
|
||||
triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id)))
|
||||
elif observation:
|
||||
triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation)))
|
||||
# Thought sub-entity
|
||||
if thought_uri:
|
||||
triples.extend([
|
||||
_triple(iteration_uri, TG_THOUGHT, _iri(thought_uri)),
|
||||
_triple(thought_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)),
|
||||
_triple(thought_uri, RDF_TYPE, _iri(TG_THOUGHT_TYPE)),
|
||||
_triple(thought_uri, RDFS_LABEL, _literal("Thought")),
|
||||
_triple(thought_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)),
|
||||
])
|
||||
if thought_document_id:
|
||||
triples.append(
|
||||
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
|
||||
)
|
||||
|
||||
# Observation sub-entity
|
||||
if observation_uri:
|
||||
triples.extend([
|
||||
_triple(iteration_uri, TG_OBSERVATION, _iri(observation_uri)),
|
||||
_triple(observation_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)),
|
||||
_triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)),
|
||||
_triple(observation_uri, RDFS_LABEL, _literal("Observation")),
|
||||
_triple(observation_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)),
|
||||
])
|
||||
if observation_document_id:
|
||||
triples.append(
|
||||
_triple(observation_uri, TG_DOCUMENT, _iri(observation_document_id))
|
||||
)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def agent_final_triples(
|
||||
final_uri: str,
|
||||
parent_uri: str,
|
||||
answer: str = "",
|
||||
question_uri: Optional[str] = None,
|
||||
previous_uri: Optional[str] = None,
|
||||
document_id: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for an agent final answer (Conclusion).
|
||||
|
||||
Creates:
|
||||
- Entity declaration with tg:Conclusion type
|
||||
- wasDerivedFrom link to parent (last iteration or session)
|
||||
- Either document reference (if document_id provided) or inline answer
|
||||
- Entity declaration with tg:Conclusion and tg:Answer types
|
||||
- wasGeneratedBy link to question (if no iterations)
|
||||
- wasDerivedFrom link to last iteration (if iterations exist)
|
||||
- Document reference to librarian
|
||||
|
||||
Args:
|
||||
final_uri: URI of the final answer (from agent_final_uri)
|
||||
parent_uri: URI of the parent (last iteration or session if no iterations)
|
||||
answer: The final answer text (used if document_id not provided)
|
||||
document_id: Optional document URI in librarian (preferred)
|
||||
question_uri: URI of the question activity (if no iterations)
|
||||
previous_uri: URI of the last iteration (if iterations exist)
|
||||
document_id: Librarian document ID for the answer content
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -156,15 +186,20 @@ def agent_final_triples(
|
|||
triples = [
|
||||
_triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)),
|
||||
_triple(final_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
|
||||
_triple(final_uri, RDFS_LABEL, _literal("Conclusion")),
|
||||
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)),
|
||||
]
|
||||
|
||||
if question_uri:
|
||||
triples.append(
|
||||
_triple(final_uri, PROV_WAS_GENERATED_BY, _iri(question_uri))
|
||||
)
|
||||
elif previous_uri:
|
||||
triples.append(
|
||||
_triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri))
|
||||
)
|
||||
|
||||
if document_id:
|
||||
# Store reference to document in librarian (as IRI)
|
||||
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
|
||||
elif answer:
|
||||
# Fallback: store inline answer
|
||||
triples.append(_triple(final_uri, TG_ANSWER, _literal(answer)))
|
||||
|
||||
return triples
|
||||
|
|
|
|||
|
|
@ -60,11 +60,12 @@ TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength"
|
|||
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY = TG + "query"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
TG_EDGE_COUNT = TG + "edgeCount"
|
||||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_CONTENT = TG + "content"
|
||||
TG_DOCUMENT = TG + "document" # Reference to document in librarian
|
||||
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
|
|
@ -79,27 +80,29 @@ TG_SUBGRAPH_TYPE = TG + "Subgraph"
|
|||
|
||||
# Explainability entity types (shared)
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_GROUNDING = TG + "Grounding"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
TG_FOCUS = TG + "Focus"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
|
||||
# Unifying types for answer and intermediate commentary
|
||||
TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion)
|
||||
TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation)
|
||||
TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning
|
||||
TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result
|
||||
|
||||
# Question subtypes (to distinguish retrieval mechanism)
|
||||
TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion"
|
||||
TG_DOC_RAG_QUESTION = TG + "DocRagQuestion"
|
||||
TG_AGENT_QUESTION = TG + "AgentQuestion"
|
||||
|
||||
# Agent provenance predicates
|
||||
TG_THOUGHT = TG + "thought"
|
||||
TG_THOUGHT = TG + "thought" # Links iteration to thought sub-entity
|
||||
TG_ACTION = TG + "action"
|
||||
TG_ARGUMENTS = TG + "arguments"
|
||||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
|
||||
# Agent document references (for librarian storage)
|
||||
TG_THOUGHT_DOCUMENT = TG + "thoughtDocument"
|
||||
TG_OBSERVATION_DOCUMENT = TG + "observationDocument"
|
||||
TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity
|
||||
|
||||
# Named graph URIs for RDF datasets
|
||||
# These separate different types of data while keeping them in the same collection
|
||||
|
|
|
|||
|
|
@ -20,12 +20,15 @@ from . namespaces import (
|
|||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
# Unifying types
|
||||
TG_ANSWER_TYPE,
|
||||
# Question subtypes
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
)
|
||||
|
|
@ -347,35 +350,78 @@ def question_triples(
|
|||
]
|
||||
|
||||
|
||||
def grounding_triples(
|
||||
grounding_uri: str,
|
||||
question_uri: str,
|
||||
concepts: List[str],
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a grounding entity (concept decomposition of query).
|
||||
|
||||
Creates:
|
||||
- Entity declaration for grounding
|
||||
- wasGeneratedBy link to question
|
||||
- Concept literals for each extracted concept
|
||||
|
||||
Args:
|
||||
grounding_uri: URI of the grounding entity (from grounding_uri)
|
||||
question_uri: URI of the parent question
|
||||
concepts: List of concept strings extracted from the query
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
triples = [
|
||||
_triple(grounding_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(grounding_uri, RDF_TYPE, _iri(TG_GROUNDING)),
|
||||
_triple(grounding_uri, RDFS_LABEL, _literal("Grounding")),
|
||||
_triple(grounding_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
|
||||
]
|
||||
|
||||
for concept in concepts:
|
||||
triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept)))
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def exploration_triples(
|
||||
exploration_uri: str,
|
||||
question_uri: str,
|
||||
grounding_uri: str,
|
||||
edge_count: int,
|
||||
entities: Optional[List[str]] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for an exploration entity (all edges retrieved from subgraph).
|
||||
|
||||
Creates:
|
||||
- Entity declaration for exploration
|
||||
- wasGeneratedBy link to question
|
||||
- wasDerivedFrom link to grounding
|
||||
- Edge count metadata
|
||||
- Entity IRIs for each seed entity
|
||||
|
||||
Args:
|
||||
exploration_uri: URI of the exploration entity (from exploration_uri)
|
||||
question_uri: URI of the parent question
|
||||
grounding_uri: URI of the parent grounding entity
|
||||
edge_count: Number of edges retrieved
|
||||
entities: Optional list of seed entity URIs
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
return [
|
||||
triples = [
|
||||
_triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)),
|
||||
_triple(exploration_uri, RDFS_LABEL, _literal("Exploration")),
|
||||
_triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
|
||||
_triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)),
|
||||
_triple(exploration_uri, TG_EDGE_COUNT, _literal(edge_count)),
|
||||
]
|
||||
|
||||
if entities:
|
||||
for entity in entities:
|
||||
triples.append(_triple(exploration_uri, TG_ENTITY, _iri(entity)))
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def _quoted_triple(s: str, p: str, o: str) -> Term:
|
||||
"""Create a quoted triple term (RDF-star) from string values."""
|
||||
|
|
@ -454,22 +500,20 @@ def focus_triples(
|
|||
def synthesis_triples(
|
||||
synthesis_uri: str,
|
||||
focus_uri: str,
|
||||
answer_text: str = "",
|
||||
document_id: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a synthesis entity (final answer text).
|
||||
Build triples for a synthesis entity (final answer).
|
||||
|
||||
Creates:
|
||||
- Entity declaration for synthesis
|
||||
- Entity declaration for synthesis with tg:Answer type
|
||||
- wasDerivedFrom link to focus
|
||||
- Either document reference (if document_id provided) or inline content
|
||||
- Document reference to librarian
|
||||
|
||||
Args:
|
||||
synthesis_uri: URI of the synthesis entity (from synthesis_uri)
|
||||
focus_uri: URI of the parent focus entity
|
||||
answer_text: The synthesized answer text (used if no document_id)
|
||||
document_id: Optional librarian document ID (preferred over inline content)
|
||||
document_id: Librarian document ID for the answer content
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -477,16 +521,13 @@ def synthesis_triples(
|
|||
triples = [
|
||||
_triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)),
|
||||
_triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
|
||||
_triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")),
|
||||
_triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(focus_uri)),
|
||||
]
|
||||
|
||||
if document_id:
|
||||
# Store reference to document in librarian (as IRI)
|
||||
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
|
||||
elif answer_text:
|
||||
# Fallback: store inline content
|
||||
triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text)))
|
||||
|
||||
return triples
|
||||
|
||||
|
|
@ -533,7 +574,7 @@ def docrag_question_triples(
|
|||
|
||||
def docrag_exploration_triples(
|
||||
exploration_uri: str,
|
||||
question_uri: str,
|
||||
grounding_uri: str,
|
||||
chunk_count: int,
|
||||
chunk_ids: Optional[List[str]] = None,
|
||||
) -> List[Triple]:
|
||||
|
|
@ -542,12 +583,12 @@ def docrag_exploration_triples(
|
|||
|
||||
Creates:
|
||||
- Entity declaration with tg:Exploration type
|
||||
- wasGeneratedBy link to question
|
||||
- wasDerivedFrom link to grounding
|
||||
- Chunk count and optional chunk references
|
||||
|
||||
Args:
|
||||
exploration_uri: URI of the exploration entity
|
||||
question_uri: URI of the parent question
|
||||
grounding_uri: URI of the parent grounding entity
|
||||
chunk_count: Number of chunks retrieved
|
||||
chunk_ids: Optional list of chunk URIs/IDs
|
||||
|
||||
|
|
@ -558,7 +599,7 @@ def docrag_exploration_triples(
|
|||
_triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)),
|
||||
_triple(exploration_uri, RDFS_LABEL, _literal("Exploration")),
|
||||
_triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)),
|
||||
_triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)),
|
||||
_triple(exploration_uri, TG_CHUNK_COUNT, _literal(chunk_count)),
|
||||
]
|
||||
|
||||
|
|
@ -573,22 +614,20 @@ def docrag_exploration_triples(
|
|||
def docrag_synthesis_triples(
|
||||
synthesis_uri: str,
|
||||
exploration_uri: str,
|
||||
answer_text: str = "",
|
||||
document_id: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a document RAG synthesis entity (final answer).
|
||||
|
||||
Creates:
|
||||
- Entity declaration with tg:Synthesis type
|
||||
- Entity declaration with tg:Synthesis and tg:Answer types
|
||||
- wasDerivedFrom link to exploration (skips focus step)
|
||||
- Either document reference or inline content
|
||||
- Document reference to librarian
|
||||
|
||||
Args:
|
||||
synthesis_uri: URI of the synthesis entity
|
||||
exploration_uri: URI of the parent exploration entity
|
||||
answer_text: The synthesized answer text (used if no document_id)
|
||||
document_id: Optional librarian document ID (preferred over inline content)
|
||||
document_id: Librarian document ID for the answer content
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -596,13 +635,12 @@ def docrag_synthesis_triples(
|
|||
triples = [
|
||||
_triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)),
|
||||
_triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)),
|
||||
_triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")),
|
||||
_triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
if document_id:
|
||||
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
|
||||
elif answer_text:
|
||||
triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text)))
|
||||
|
||||
return triples
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ def agent_uri(component_name: str) -> str:
|
|||
#
|
||||
# Terminology:
|
||||
# Question - What was asked, the anchor for everything
|
||||
# Grounding - Decomposing the question into concepts
|
||||
# Exploration - Casting wide, what do we know about this space
|
||||
# Focus - Closing down, what's actually relevant here
|
||||
# Synthesis - Weaving the relevant pieces into an answer
|
||||
|
|
@ -87,6 +88,19 @@ def question_uri(session_id: str = None) -> str:
|
|||
return f"urn:trustgraph:question:{session_id}"
|
||||
|
||||
|
||||
def grounding_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a grounding entity (concept decomposition of query).
|
||||
|
||||
Args:
|
||||
session_id: The session UUID (same as question_uri).
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:prov:grounding:{uuid}
|
||||
"""
|
||||
return f"urn:trustgraph:prov:grounding:{session_id}"
|
||||
|
||||
|
||||
def exploration_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for an exploration entity (edges retrieved from subgraph).
|
||||
|
|
@ -173,6 +187,34 @@ def agent_iteration_uri(session_id: str, iteration_num: int) -> str:
|
|||
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}"
|
||||
|
||||
|
||||
def agent_thought_uri(session_id: str, iteration_num: int) -> str:
|
||||
"""
|
||||
Generate URI for an agent thought sub-entity.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
iteration_num: 1-based iteration number.
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:agent:{uuid}/i{num}/thought
|
||||
"""
|
||||
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
|
||||
|
||||
|
||||
def agent_observation_uri(session_id: str, iteration_num: int) -> str:
|
||||
"""
|
||||
Generate URI for an agent observation sub-entity.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
iteration_num: 1-based iteration number.
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:agent:{uuid}/i{num}/observation
|
||||
"""
|
||||
return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
|
||||
|
||||
|
||||
def agent_final_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for an agent final answer.
|
||||
|
|
@ -205,6 +247,19 @@ def docrag_question_uri(session_id: str = None) -> str:
|
|||
return f"urn:trustgraph:docrag:{session_id}"
|
||||
|
||||
|
||||
def docrag_grounding_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG grounding entity (concept decomposition).
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:docrag:{uuid}/grounding
|
||||
"""
|
||||
return f"urn:trustgraph:docrag:{session_id}/grounding"
|
||||
|
||||
|
||||
def docrag_exploration_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG exploration entity (chunks retrieved).
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ from . namespaces import (
|
|||
TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL,
|
||||
TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
TG_CONCEPT, TG_ENTITY, TG_GROUNDING,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -80,6 +82,11 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_PAGE_TYPE, "Page"),
|
||||
_label_triple(TG_CHUNK_TYPE, "Chunk"),
|
||||
_label_triple(TG_SUBGRAPH_TYPE, "Subgraph"),
|
||||
_label_triple(TG_GROUNDING, "Grounding"),
|
||||
_label_triple(TG_ANSWER_TYPE, "Answer"),
|
||||
_label_triple(TG_REFLECTION_TYPE, "Reflection"),
|
||||
_label_triple(TG_THOUGHT_TYPE, "Thought"),
|
||||
_label_triple(TG_OBSERVATION_TYPE, "Observation"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
@ -100,6 +107,8 @@ TG_PREDICATE_LABELS = [
|
|||
_label_triple(TG_SOURCE_TEXT, "source text"),
|
||||
_label_triple(TG_SOURCE_CHAR_OFFSET, "source character offset"),
|
||||
_label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"),
|
||||
_label_triple(TG_CONCEPT, "concept"),
|
||||
_label_triple(TG_ENTITY, "entity"),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GraphRagQuery:
|
|||
triple_limit: int = 0
|
||||
max_subgraph_size: int = 0
|
||||
max_path_length: int = 0
|
||||
edge_limit: int = 0
|
||||
streaming: bool = False
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -202,16 +202,17 @@ def question_explainable(
|
|||
|
||||
elif isinstance(entity, Analysis):
|
||||
print(f"\n [iteration] {prov_id}", file=sys.stderr)
|
||||
if entity.thought:
|
||||
thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought
|
||||
print(f" Thought: {thought_short}", file=sys.stderr)
|
||||
if entity.action:
|
||||
print(f" Action: {entity.action}", file=sys.stderr)
|
||||
if entity.thought_uri:
|
||||
print(f" Thought: {entity.thought_uri}", file=sys.stderr)
|
||||
if entity.observation_uri:
|
||||
print(f" Observation: {entity.observation_uri}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Conclusion):
|
||||
print(f"\n [conclusion] {prov_id}", file=sys.stderr)
|
||||
if entity.answer:
|
||||
print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr)
|
||||
if entity.document_uri:
|
||||
print(f" Document: {entity.document_uri}", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from trustgraph.api import (
|
|||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Grounding,
|
||||
Exploration,
|
||||
Synthesis,
|
||||
)
|
||||
|
|
@ -68,6 +69,12 @@ def question_explainable(
|
|||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Grounding):
|
||||
print(f"\n [grounding] {prov_id}", file=sys.stderr)
|
||||
if entity.concepts:
|
||||
for concept in entity.concepts:
|
||||
print(f" Concept: {concept}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Exploration):
|
||||
print(f"\n [exploration] {prov_id}", file=sys.stderr)
|
||||
if entity.chunk_count:
|
||||
|
|
@ -75,8 +82,8 @@ def question_explainable(
|
|||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
if entity.content:
|
||||
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
|
||||
if entity.document_uri:
|
||||
print(f" Document: {entity.document_uri}", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from trustgraph.api import (
|
|||
RAGChunk,
|
||||
ProvenanceEvent,
|
||||
Question,
|
||||
Grounding,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
|
|
@ -31,11 +32,13 @@ default_max_path_length = 2
|
|||
# Provenance predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
TG_QUERY = TG + "query"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
TG_EDGE_COUNT = TG + "edgeCount"
|
||||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_CONTENT = TG + "content"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_CONTAINS = TG + "contains"
|
||||
PROV = "http://www.w3.org/ns/prov#"
|
||||
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
|
||||
|
|
@ -47,6 +50,8 @@ def _get_event_type(prov_id):
|
|||
"""Extract event type from provenance_id"""
|
||||
if "question" in prov_id:
|
||||
return "question"
|
||||
elif "grounding" in prov_id:
|
||||
return "grounding"
|
||||
elif "exploration" in prov_id:
|
||||
return "exploration"
|
||||
elif "focus" in prov_id:
|
||||
|
|
@ -68,8 +73,16 @@ def _format_provenance_details(event_type, triples):
|
|||
elif p == PROV_STARTED_AT_TIME:
|
||||
lines.append(f" Time: {o}")
|
||||
|
||||
elif event_type == "grounding":
|
||||
# Show extracted concepts
|
||||
concepts = [o for s, p, o in triples if p == TG_CONCEPT]
|
||||
if concepts:
|
||||
lines.append(f" Concepts: {len(concepts)}")
|
||||
for concept in concepts:
|
||||
lines.append(f" - {concept}")
|
||||
|
||||
elif event_type == "exploration":
|
||||
# Show edge count
|
||||
# Show edge count (seed entities resolved separately with labels)
|
||||
for s, p, o in triples:
|
||||
if p == TG_EDGE_COUNT:
|
||||
lines.append(f" Edges explored: {o}")
|
||||
|
|
@ -85,10 +98,10 @@ def _format_provenance_details(event_type, triples):
|
|||
lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
|
||||
|
||||
elif event_type == "synthesis":
|
||||
# Show content length (not full content - it's already streamed)
|
||||
# Show document reference (content already streamed)
|
||||
for s, p, o in triples:
|
||||
if p == TG_CONTENT:
|
||||
lines.append(f" Synthesis length: {len(o)} chars")
|
||||
if p == TG_DOCUMENT:
|
||||
lines.append(f" Document: {o}")
|
||||
|
||||
return lines
|
||||
|
||||
|
|
@ -542,6 +555,18 @@ async def _question_explainable(
|
|||
for line in details:
|
||||
print(line, file=sys.stderr)
|
||||
|
||||
# For exploration events, resolve entity labels
|
||||
if event_type == "exploration":
|
||||
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
|
||||
if entity_iris:
|
||||
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
|
||||
for iri in entity_iris:
|
||||
label = await _query_label(
|
||||
ws_url, flow_id, iri, user, collection,
|
||||
label_cache, debug=debug
|
||||
)
|
||||
print(f" - {label}", file=sys.stderr)
|
||||
|
||||
# For focus events, query each edge selection for details
|
||||
if event_type == "focus":
|
||||
for s, p, o in triples:
|
||||
|
|
@ -660,10 +685,22 @@ def _question_explainable_api(
|
|||
if entity.timestamp:
|
||||
print(f" Time: {entity.timestamp}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Grounding):
|
||||
print(f"\n [grounding] {prov_id}", file=sys.stderr)
|
||||
if entity.concepts:
|
||||
print(f" Concepts: {len(entity.concepts)}", file=sys.stderr)
|
||||
for concept in entity.concepts:
|
||||
print(f" - {concept}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Exploration):
|
||||
print(f"\n [exploration] {prov_id}", file=sys.stderr)
|
||||
if entity.edge_count:
|
||||
print(f" Edges explored: {entity.edge_count}", file=sys.stderr)
|
||||
if entity.entities:
|
||||
print(f" Seed entities: {len(entity.entities)}", file=sys.stderr)
|
||||
for ent in entity.entities:
|
||||
label = explain_client.resolve_label(ent, user, collection)
|
||||
print(f" - {label}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Focus):
|
||||
print(f"\n [focus] {prov_id}", file=sys.stderr)
|
||||
|
|
@ -691,8 +728,8 @@ def _question_explainable_api(
|
|||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
if entity.content:
|
||||
print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr)
|
||||
if entity.document_uri:
|
||||
print(f" Document: {entity.document_uri}", file=sys.stderr)
|
||||
|
||||
else:
|
||||
if debug:
|
||||
|
|
@ -848,7 +885,7 @@ def main():
|
|||
parser.add_argument(
|
||||
'-x', '--explainable',
|
||||
action='store_true',
|
||||
help='Show provenance events: Question, Exploration, Focus, Synthesis (implies streaming)'
|
||||
help='Show provenance events: Question, Grounding, Exploration, Focus, Synthesis (implies streaming)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ from ... schema import librarian_request_queue, librarian_response_queue
|
|||
from trustgraph.provenance import (
|
||||
agent_session_uri,
|
||||
agent_iteration_uri,
|
||||
agent_thought_uri,
|
||||
agent_observation_uri,
|
||||
agent_final_uri,
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
|
|
@ -624,11 +626,13 @@ class Processor(AgentService):
|
|||
|
||||
# Emit final answer provenance triples
|
||||
final_uri = agent_final_uri(session_id)
|
||||
# Parent is last iteration, or session if no iterations
|
||||
# No iterations: link to question; otherwise: link to last iteration
|
||||
if iteration_num > 1:
|
||||
parent_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
final_question_uri = None
|
||||
final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
else:
|
||||
parent_uri = session_uri
|
||||
final_question_uri = session_uri
|
||||
final_previous_uri = None
|
||||
|
||||
# Save answer to librarian
|
||||
answer_doc_id = None
|
||||
|
|
@ -648,8 +652,9 @@ class Processor(AgentService):
|
|||
|
||||
final_triples = set_graph(
|
||||
agent_final_triples(
|
||||
final_uri, parent_uri,
|
||||
answer="" if answer_doc_id else f,
|
||||
final_uri,
|
||||
question_uri=final_question_uri,
|
||||
previous_uri=final_previous_uri,
|
||||
document_id=answer_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
@ -707,11 +712,13 @@ class Processor(AgentService):
|
|||
|
||||
# Emit iteration provenance triples
|
||||
iteration_uri = agent_iteration_uri(session_id, iteration_num)
|
||||
# Parent is previous iteration, or session if this is first iteration
|
||||
# First iteration links to question, subsequent to previous
|
||||
if iteration_num > 1:
|
||||
parent_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
iter_question_uri = None
|
||||
iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
else:
|
||||
parent_uri = session_uri
|
||||
iter_question_uri = session_uri
|
||||
iter_previous_uri = None
|
||||
|
||||
# Save thought to librarian
|
||||
thought_doc_id = None
|
||||
|
|
@ -745,15 +752,19 @@ class Processor(AgentService):
|
|||
logger.warning(f"Failed to save observation to librarian: {e}")
|
||||
observation_doc_id = None
|
||||
|
||||
thought_entity_uri = agent_thought_uri(session_id, iteration_num)
|
||||
observation_entity_uri = agent_observation_uri(session_id, iteration_num)
|
||||
|
||||
iter_triples = set_graph(
|
||||
agent_iteration_triples(
|
||||
iteration_uri,
|
||||
parent_uri,
|
||||
thought="" if thought_doc_id else act.thought,
|
||||
question_uri=iter_question_uri,
|
||||
previous_uri=iter_previous_uri,
|
||||
action=act.name,
|
||||
arguments=act.arguments,
|
||||
observation="" if observation_doc_id else act.observation,
|
||||
thought_uri=thought_entity_uri if thought_doc_id else None,
|
||||
thought_document_id=thought_doc_id,
|
||||
observation_uri=observation_entity_uri if observation_doc_id else None,
|
||||
observation_document_id=observation_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@ from datetime import datetime
|
|||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_synthesis_uri,
|
||||
docrag_question_triples,
|
||||
grounding_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_synthesis_triples,
|
||||
set_graph,
|
||||
|
|
@ -33,39 +35,79 @@ class Query:
|
|||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
async def get_vector(self, query):
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
||||
# Fallback to raw query if no concepts extracted
|
||||
if not concepts:
|
||||
concepts = [query]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Extracted concepts: {concepts}")
|
||||
|
||||
return concepts
|
||||
|
||||
async def get_vectors(self, concepts):
|
||||
"""Compute embeddings for a list of concepts."""
|
||||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed([query])
|
||||
qembeds = await self.rag.embeddings_client.embed(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Embeddings computed")
|
||||
|
||||
# Return the vector set for the first (only) text
|
||||
return qembeds[0] if qembeds else []
|
||||
return qembeds
|
||||
|
||||
async def get_docs(self, query):
|
||||
async def get_docs(self, concepts):
|
||||
"""
|
||||
Get documents (chunks) matching the query.
|
||||
Get documents (chunks) matching the extracted concepts.
|
||||
|
||||
Returns:
|
||||
tuple: (docs, chunk_ids) where:
|
||||
- docs: list of document content strings
|
||||
- chunk_ids: list of chunk IDs that were successfully fetched
|
||||
"""
|
||||
vectors = await self.get_vector(query)
|
||||
vectors = await self.get_vectors(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting chunks from embeddings store...")
|
||||
|
||||
# Get chunk matches from embeddings store
|
||||
chunk_matches = await self.rag.doc_embeddings_client.query(
|
||||
vector=vectors, limit=self.doc_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
# Query chunk matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.doc_limit // len(vectors)
|
||||
)
|
||||
|
||||
async def query_concept(vec):
|
||||
return await self.rag.doc_embeddings_client.query(
|
||||
vector=vec, limit=per_concept_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[query_concept(v) for v in vectors]
|
||||
)
|
||||
|
||||
# Deduplicate chunk matches by chunk_id
|
||||
seen = set()
|
||||
chunk_matches = []
|
||||
for matches in results:
|
||||
for match in matches:
|
||||
if match.chunk_id and match.chunk_id not in seen:
|
||||
seen.add(match.chunk_id)
|
||||
chunk_matches.append(match)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
|
||||
|
||||
|
|
@ -133,6 +175,7 @@ class DocumentRag:
|
|||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
q_uri = docrag_question_uri(session_id)
|
||||
gnd_uri = docrag_grounding_uri(session_id)
|
||||
exp_uri = docrag_exploration_uri(session_id)
|
||||
syn_uri = docrag_synthesis_uri(session_id)
|
||||
|
||||
|
|
@ -151,12 +194,23 @@ class DocumentRag:
|
|||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
docs, chunk_ids = await q.get_docs(query)
|
||||
# Extract concepts from query (grounding step)
|
||||
concepts = await q.extract_concepts(query)
|
||||
|
||||
# Emit grounding explainability after concept extraction
|
||||
if explain_callback:
|
||||
gnd_triples = set_graph(
|
||||
grounding_triples(gnd_uri, q_uri, concepts),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
docs, chunk_ids = await q.get_docs(concepts)
|
||||
|
||||
# Emit exploration explainability after chunks retrieved
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
docrag_exploration_triples(exp_uri, q_uri, len(chunk_ids), chunk_ids),
|
||||
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
|
@ -196,9 +250,8 @@ class DocumentRag:
|
|||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian if callback provided
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
# Generate document ID as URN matching query-time provenance format
|
||||
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
|
|
@ -206,13 +259,11 @@ class DocumentRag:
|
|||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None # Fall back to inline content
|
||||
synthesis_doc_id = None
|
||||
|
||||
# Generate triples with document reference or inline content
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(
|
||||
syn_uri, exp_uri,
|
||||
answer_text="" if synthesis_doc_id else answer_text,
|
||||
document_id=synthesis_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
|
|||
|
|
@ -8,20 +8,23 @@ import uuid
|
|||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
from ... schema import IRI, LITERAL
|
||||
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
question_uri,
|
||||
grounding_uri as make_grounding_uri,
|
||||
exploration_uri as make_exploration_uri,
|
||||
focus_uri as make_focus_uri,
|
||||
synthesis_uri as make_synthesis_uri,
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
GRAPH_RETRIEVAL, GRAPH_SOURCE,
|
||||
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
|
||||
)
|
||||
|
||||
# Module logger
|
||||
|
|
@ -47,6 +50,8 @@ def edge_id(s, p, o):
|
|||
edge_str = f"{s}|{p}|{o}"
|
||||
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
|
||||
|
||||
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
|
|
@ -105,42 +110,88 @@ class Query:
|
|||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
|
||||
async def get_vector(self, query):
|
||||
async def extract_concepts(self, query):
    """Extract key concepts from query for independent embedding.

    Sends the query to the "extract-concepts" prompt and treats each
    non-blank line of a string response (or each string item of a list
    response) as one concept.  Concepts are whitespace-stripped and
    de-duplicated while preserving first-seen order, so a repeated
    concept is not embedded and queried more than once.

    Args:
        query: The user query text.

    Returns:
        list: Concept strings; falls back to ``[query]`` when the
        prompt yields no usable concept.
    """
    response = await self.rag.prompt_client.prompt(
        "extract-concepts",
        variables={"query": query}
    )

    # Accept both a newline-separated string and an already-split list,
    # matching how the other prompt responses in this module are parsed
    if isinstance(response, str):
        candidates = response.splitlines()
    elif isinstance(response, (list, tuple)):
        candidates = [item for item in response if isinstance(item, str)]
    else:
        candidates = []

    # dict.fromkeys de-duplicates while keeping first-seen order
    stripped = (candidate.strip() for candidate in candidates)
    concepts = list(dict.fromkeys(c for c in stripped if c))

    if self.verbose:
        logger.debug(f"Extracted concepts: {concepts}")

    # Fall back to raw query if extraction returns nothing
    return concepts if concepts else [query]
|
||||
|
||||
async def get_vectors(self, concepts):
|
||||
"""Embed multiple concepts concurrently."""
|
||||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed([query])
|
||||
qembeds = await self.rag.embeddings_client.embed(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
# Return the vector set for the first (only) text
|
||||
return qembeds[0] if qembeds else []
|
||||
return qembeds
|
||||
|
||||
async def get_entities(self, query):
|
||||
"""
|
||||
Extract concepts from query, embed them, and retrieve matching entities.
|
||||
|
||||
vectors = await self.get_vector(query)
|
||||
Returns:
|
||||
tuple: (entities, concepts) where entities is a list of entity URI
|
||||
strings and concepts is the list of concept strings extracted
|
||||
from the query.
|
||||
"""
|
||||
|
||||
concepts = await self.extract_concepts(query)
|
||||
|
||||
vectors = await self.get_vectors(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting entities...")
|
||||
|
||||
entity_matches = await self.rag.graph_embeddings_client.query(
|
||||
vector=vectors, limit=self.entity_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
# Query entity matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.entity_limit // len(vectors)
|
||||
)
|
||||
|
||||
entities = [
|
||||
term_to_string(e.entity)
|
||||
for e in entity_matches
|
||||
entity_tasks = [
|
||||
self.rag.graph_embeddings_client.query(
|
||||
vector=v, limit=per_concept_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
for v in vectors
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*entity_tasks, return_exceptions=True)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen = set()
|
||||
entities = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception) or not result:
|
||||
continue
|
||||
for e in result:
|
||||
entity = term_to_string(e.entity)
|
||||
if entity not in seen:
|
||||
seen.add(entity)
|
||||
entities.append(entity)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Entities:")
|
||||
for ent in entities:
|
||||
logger.debug(f" {ent}")
|
||||
|
||||
return entities
|
||||
return entities, concepts
|
||||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
|
|
@ -156,6 +207,7 @@ class Query:
|
|||
res = await self.rag.triples_client.query(
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
user=self.user, collection=self.collection,
|
||||
g="",
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
|
|
@ -177,19 +229,19 @@ class Query:
|
|||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
batch_size=20, g="",
|
||||
)
|
||||
])
|
||||
|
||||
|
|
@ -262,8 +314,16 @@ class Query:
|
|||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
"""
|
||||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
entities = await self.get_entities(query)
|
||||
Returns:
|
||||
tuple: (subgraph, entities, concepts) where subgraph is a list of
|
||||
(s, p, o) tuples, entities is the seed entity list, and concepts
|
||||
is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
|
@ -271,7 +331,7 @@ class Query:
|
|||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph)
|
||||
return list(subgraph), entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
|
|
@ -286,11 +346,13 @@ class Query:
|
|||
Get subgraph with labels resolved for display.
|
||||
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map) where:
|
||||
tuple: (labeled_edges, uri_map, entities, concepts) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
"""
|
||||
subgraph = await self.get_subgraph(query)
|
||||
subgraph, entities, concepts = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
|
|
@ -338,8 +400,125 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
return labeled_edges, uri_map
|
||||
|
||||
return labeled_edges, uri_map, entities, concepts
|
||||
|
||||
async def trace_source_documents(self, edge_uris):
    """
    Trace selected edges back to their source documents via provenance.

    Follows the chain: edge → subgraph (via tg:contains) → chunk →
    page → document (via prov:wasDerivedFrom), all in urn:graph:source.
    A node with no prov:wasDerivedFrom parent is treated as a root
    document.  Lookups that raise are logged and skipped rather than
    being mistaken for "no parent".

    Args:
        edge_uris: List of (s, p, o) URI string tuples

    Returns:
        List of (s, p, o) string tuples: the metadata triples found for
        each root document, excluding rdf:type and prov:wasDerivedFrom.
        Empty when no edge can be traced to a document.
    """
    if not edge_uris:
        return []

    # Step 1: Find subgraphs containing these edges via tg:contains
    # NOTE(review): every edge component is wrapped as an IRI term; an
    # edge whose object is a literal would presumably not match — confirm.
    subgraph_tasks = []
    for s, p, o in edge_uris:
        quoted = Term(
            type=TRIPLE,
            triple=SchemaTriple(
                s=Term(type=IRI, iri=s),
                p=Term(type=IRI, iri=p),
                o=Term(type=IRI, iri=o),
            )
        )
        subgraph_tasks.append(
            self.rag.triples_client.query(
                s=None, p=TG_CONTAINS, o=quoted, limit=1,
                user=self.user, collection=self.collection,
                g=GRAPH_SOURCE,
            )
        )

    subgraph_results = await asyncio.gather(
        *subgraph_tasks, return_exceptions=True
    )

    # Collect unique subgraph URIs
    subgraph_uris = set()
    for result in subgraph_results:
        if isinstance(result, Exception):
            logger.warning(f"Subgraph lookup failed: {result}")
            continue
        if not result:
            continue
        for triple in result:
            subgraph_uris.add(str(triple.s))

    if not subgraph_uris:
        return []

    # Step 2: Walk prov:wasDerivedFrom chain to find documents
    # Each level: query ?entity prov:wasDerivedFrom ?parent
    current_uris = subgraph_uris
    doc_uris = set()

    for _ in range(4):  # Max depth: subgraph → chunk → page → doc
        if not current_uris:
            break

        # Freeze the iteration order so each result is paired with the
        # URI that produced it
        frontier = list(current_uris)

        derivation_results = await asyncio.gather(
            *[
                self.rag.triples_client.query(
                    s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
                    user=self.user, collection=self.collection,
                    g=GRAPH_SOURCE,
                )
                for uri in frontier
            ],
            return_exceptions=True
        )

        next_uris = set()
        for uri, result in zip(frontier, derivation_results):
            if isinstance(result, Exception):
                # A failed lookup says nothing about whether this node
                # has a parent, so do not classify it as a document
                logger.warning(
                    f"Provenance lookup failed for {uri}: {result}"
                )
                continue
            if not result:
                # URIs with no parent are root documents
                doc_uris.add(uri)
                continue
            for triple in result:
                next_uris.add(str(triple.o))

        current_uris = next_uris - doc_uris

    if not doc_uris:
        return []

    # Step 3: Get all document metadata properties
    # Skip structural predicates that aren't useful context
    skip_predicates = {
        PROV_WAS_DERIVED_FROM,
        "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
    }

    # Sorted so the returned edges come out in a reproducible order
    ordered_docs = sorted(doc_uris)

    # NOTE(review): unlike the provenance lookups above, this query
    # passes no g= filter — confirm that is intended.
    metadata_results = await asyncio.gather(
        *[
            self.rag.triples_client.query(
                s=uri, p=None, o=None, limit=50,
                user=self.user, collection=self.collection,
            )
            for uri in ordered_docs
        ],
        return_exceptions=True
    )

    doc_edges = []
    for uri, result in zip(ordered_docs, metadata_results):
        if isinstance(result, Exception):
            logger.warning(
                f"Document metadata lookup failed for {uri}: {result}"
            )
            continue
        if not result:
            continue
        for triple in result:
            p = str(triple.p)
            if p in skip_predicates:
                continue
            doc_edges.append((
                str(triple.s), p, str(triple.o)
            ))

    return doc_edges
|
||||
|
||||
class GraphRag:
|
||||
"""
|
||||
CRITICAL SECURITY:
|
||||
|
|
@ -371,7 +550,8 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, streaming = False, chunk_callback = None,
|
||||
max_path_length = 2, edge_limit = 25, streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -399,6 +579,7 @@ class GraphRag:
|
|||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
q_uri = question_uri(session_id)
|
||||
gnd_uri = make_grounding_uri(session_id)
|
||||
exp_uri = make_exploration_uri(session_id)
|
||||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
|
@ -421,12 +602,23 @@ class GraphRag:
|
|||
max_path_length = max_path_length,
|
||||
)
|
||||
|
||||
kg, uri_map = await q.get_labelgraph(query)
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
|
||||
# Emit grounding explain after concept extraction
|
||||
if explain_callback:
|
||||
gnd_triples = set_graph(
|
||||
grounding_triples(gnd_uri, q_uri, concepts),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
# Emit exploration explain after graph retrieval completes
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
exploration_triples(exp_uri, q_uri, len(kg)),
|
||||
exploration_triples(
|
||||
exp_uri, gnd_uri, len(kg),
|
||||
entities=seed_entities,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
|
@ -453,9 +645,9 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
|
||||
selection_response = await self.prompt_client.prompt(
|
||||
"kg-edge-selection",
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_response = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
|
|
@ -463,52 +655,44 @@ class GraphRag:
|
|||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge selection response: {selection_response}")
|
||||
logger.debug(f"Edge scoring response: {scoring_response}")
|
||||
|
||||
# Parse response to get selected edge IDs and reasoning
|
||||
# Response can be a string (JSONL) or a list (JSON array)
|
||||
selected_ids = set()
|
||||
selected_edges_with_reasoning = [] # For explain
|
||||
# Parse scoring response to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
if isinstance(selection_response, list):
|
||||
# JSON array response
|
||||
for obj in selection_response:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
selected_ids.add(obj["id"])
|
||||
# Capture original URI edge (not labels) and reasoning for explain
|
||||
eid = obj["id"]
|
||||
if eid in uri_map:
|
||||
# Use original URIs for provenance tracing
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": obj.get("reasoning", ""),
|
||||
})
|
||||
elif isinstance(selection_response, str):
|
||||
# JSONL string response
|
||||
for line in selection_response.strip().split('\n'):
|
||||
def parse_scored_edge(obj):
    """Record one scored edge from the kg-edge-scoring response.

    Appends ``{"id": ..., "score": ...}`` to the enclosing
    ``scored_edges`` list when ``obj`` is a dict carrying both "id" and
    "score".  Anything else is ignored.

    The score is parsed as a float so fractional values (7.5, "0.9")
    keep their ordering instead of being truncated or reset; a score
    that is not numeric, or is NaN, counts as 0.
    """
    if not isinstance(obj, dict) or "id" not in obj or "score" not in obj:
        return

    try:
        score = float(obj["score"])
    except (ValueError, TypeError, OverflowError):
        score = 0

    # NaN compares unequal to itself and would make the later sort
    # order undefined
    if score != score:
        score = 0

    scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
if isinstance(scoring_response, list):
|
||||
for obj in scoring_response:
|
||||
parse_scored_edge(obj)
|
||||
elif isinstance(scoring_response, str):
|
||||
for line in scoring_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
if "id" in obj:
|
||||
selected_ids.add(obj["id"])
|
||||
# Capture original URI edge (not labels) and reasoning for explain
|
||||
eid = obj["id"]
|
||||
if eid in uri_map:
|
||||
# Use original URIs for provenance tracing
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": obj.get("reasoning", ""),
|
||||
})
|
||||
parse_scored_edge(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse edge selection line: {line}")
|
||||
continue
|
||||
logger.warning(
|
||||
f"Failed to parse edge scoring line: {line}"
|
||||
)
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
|
|
@ -516,6 +700,82 @@ class GraphRag:
|
|||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
reasoning_task = self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_response, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_response, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_response}"
|
||||
)
|
||||
reasoning_response = ""
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning response: {reasoning_response}")
|
||||
|
||||
# Parse reasoning response and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
def parse_reasoning(obj):
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
if isinstance(reasoning_response, list):
|
||||
for obj in reasoning_response:
|
||||
parse_reasoning(obj)
|
||||
elif isinstance(reasoning_response, str):
|
||||
for line in reasoning_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parse_reasoning(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge reasoning line: {line}"
|
||||
)
|
||||
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
|
|
@ -534,6 +794,18 @@ class GraphRag:
|
|||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
})
|
||||
|
||||
synthesis_variables = {
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts,
|
||||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
|
@ -544,10 +816,7 @@ class GraphRag:
|
|||
|
||||
await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts
|
||||
},
|
||||
variables=synthesis_variables,
|
||||
streaming=True,
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
|
|
@ -556,10 +825,7 @@ class GraphRag:
|
|||
else:
|
||||
resp = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts
|
||||
}
|
||||
variables=synthesis_variables,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -570,9 +836,8 @@ class GraphRag:
|
|||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian if callback provided
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
# Generate document ID as URN matching query-time provenance format
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
|
|
@ -580,13 +845,11 @@ class GraphRag:
|
|||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None # Fall back to inline content
|
||||
synthesis_doc_id = None
|
||||
|
||||
# Generate triples with document reference or inline content
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
answer_text="" if synthesis_doc_id else answer_text,
|
||||
document_id=synthesis_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -48,6 +49,7 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -55,6 +57,7 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
# Each user/collection combination MUST have isolated data access
|
||||
|
|
@ -292,6 +295,11 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
edge_limit = self.default_edge_limit
|
||||
|
||||
# Callback to save answer content to librarian
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await self.save_answer_content(
|
||||
|
|
@ -322,6 +330,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
explain_callback = send_explainability,
|
||||
|
|
@ -335,6 +344,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue