From 20bb645b9ac6f2c324175356d2978720b2bb16a5 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Sat, 14 Mar 2026 11:54:10 +0000 Subject: [PATCH] Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding, consistent PROV-O MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GraphRAG: - Split retrieval into 4 prompt stages: extract-concepts, kg-edge-scoring, kg-edge-reasoning, kg-synthesis (was single-stage) - Add concept extraction (grounding) for per-concept embedding - Filter main query to default graph, ignoring provenance/explainability edges - Add source document edges to knowledge graph DocumentRAG: - Add grounding step with concept extraction, matching GraphRAG's pattern: Question → Grounding → Exploration → Synthesis - Per-concept embedding and chunk retrieval with deduplication Cross-pipeline: - Make PROV-O derivation links consistent: wasGeneratedBy for first entity from Activity, wasDerivedFrom for entity-to-entity chains - Update CLIs (tg-invoke-agent, tg-invoke-graph-rag, tg-invoke-document-rag) for new explainability structure - Fix all affected unit and integration tests --- .../integration/test_graph_rag_integration.py | 31 +- .../test_graph_rag_streaming_integration.py | 12 +- .../test_provenance/test_agent_provenance.py | 202 +++++---- .../test_provenance/test_explainability.py | 184 ++++---- tests/unit/test_provenance/test_triples.py | 129 +++--- .../unit/test_retrieval/test_document_rag.py | 284 +++++++----- tests/unit/test_retrieval/test_graph_rag.py | 376 +++++++--------- trustgraph-base/trustgraph/api/__init__.py | 2 + .../trustgraph/api/explainability.py | 362 ++++++--------- .../trustgraph/base/triples_client.py | 14 +- .../messaging/translators/retrieval.py | 2 + .../trustgraph/provenance/__init__.py | 32 +- .../trustgraph/provenance/agent.py | 111 +++-- .../trustgraph/provenance/namespaces.py | 19 +- .../trustgraph/provenance/triples.py | 94 ++-- trustgraph-base/trustgraph/provenance/uris.py | 55 +++ .../trustgraph/provenance/vocabulary.py | 9 + .../trustgraph/schema/services/retrieval.py | 1 + trustgraph-cli/trustgraph/cli/invoke_agent.py | 11 +- .../trustgraph/cli/invoke_document_rag.py | 11 +- .../trustgraph/cli/invoke_graph_rag.py | 53 ++- .../trustgraph/agent/react/service.py | 33 +- .../retrieval/document_rag/document_rag.py | 87 +++- .../retrieval/graph_rag/graph_rag.py | 421 ++++++++++++++---- .../trustgraph/retrieval/graph_rag/rag.py | 10 + 25 files changed, 1537 insertions(+), 1008 deletions(-) diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 6ff14d69..5e3279e3 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -86,13 +86,18 @@ class TestGraphRagIntegration: """Mock prompt client that generates realistic responses for two-step process""" client = AsyncMock() - # Mock responses for the two-step process: - # 1. kg-edge-selection returns JSONL with edge IDs - # 2. kg-synthesis returns the final answer + # Mock responses for the multi-step process: + # 1. extract-concepts extracts key concepts from the query + # 2. kg-edge-scoring scores edges for relevance + # 3. kg-edge-reasoning provides reasoning for selected edges + # 4. kg-synthesis returns the final answer async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": - # Return empty selection (no edges selected) - valid JSONL - return "" + if prompt_name == "extract-concepts": + return "" # Falls back to raw query + elif prompt_name == "kg-edge-scoring": + return "" # No edges scored + elif prompt_name == "kg-edge-reasoning": + return "" # No reasoning elif prompt_name == "kg-synthesis": return ( "Machine learning is a subset of artificial intelligence that enables computers " @@ -160,16 +165,16 @@ class TestGraphRagIntegration: # 3. Should query triples to build knowledge subgraph assert mock_triples_client.query_stream.call_count > 0 - # 4. Should call prompt twice (edge selection + synthesis) - assert mock_prompt_client.prompt.call_count == 2 + # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis) + assert mock_prompt_client.prompt.call_count == 4 # Verify final response assert response is not None assert isinstance(response, str) assert "machine learning" in response.lower() - # Verify provenance was emitted in real-time (4 events: question, exploration, focus, synthesis) - assert len(provenance_events) == 4 + # Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis) + assert len(provenance_events) == 5 for triples, prov_id in provenance_events: assert isinstance(triples, list) assert prov_id.startswith("urn:trustgraph:") @@ -243,10 +248,10 @@ class TestGraphRagIntegration: ) # Assert - # Should still call prompt client (twice: edge selection + synthesis) + # Should still call prompt client assert response is not None - # Provenance should still be emitted (4 events) - assert len(provenance_events) == 4 + # Provenance should still be emitted (5 events) + assert len(provenance_events) == 5 @pytest.mark.asyncio async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client): diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index d936de11..b66c5289 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -60,8 +60,12 @@ class TestGraphRagStreaming: full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): - if prompt_id == "kg-edge-selection": - # Edge selection returns JSONL with IDs - simulate selecting first edge + if prompt_id == "extract-concepts": + return "" # Falls back to raw query + elif prompt_id == "kg-edge-scoring": + # Edge scoring returns JSONL with IDs and scores + return '{"id": "abc12345", "score": 0.9}\n' + elif prompt_id == "kg-edge-reasoning": return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' elif prompt_id == "kg-synthesis": if streaming and chunk_callback: @@ -132,8 +136,8 @@ class TestGraphRagStreaming: # Verify content is reasonable assert "machine" in response.lower() or "learning" in response.lower() - # Verify provenance was emitted in real-time (4 events) - assert len(provenance_events) == 4 + # Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis) + assert len(provenance_events) == 5 for triples, prov_id in provenance_events: assert prov_id.startswith("urn:trustgraph:") diff --git a/tests/unit/test_provenance/test_agent_provenance.py b/tests/unit/test_provenance/test_agent_provenance.py index 9377fe19..4efe24c7 100644 --- a/tests/unit/test_provenance/test_agent_provenance.py +++ b/tests/unit/test_provenance/test_agent_provenance.py @@ -15,10 +15,11 @@ from trustgraph.provenance.agent import ( from trustgraph.provenance.namespaces import ( RDF_TYPE, RDFS_LABEL, - PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME, - TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, + PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, + TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, - TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_AGENT_QUESTION, ) @@ -110,84 +111,107 @@ class TestAgentSessionTriples: class TestAgentIterationTriples: ITER_URI = "urn:trustgraph:agent:test-session/i1" - PARENT_URI = "urn:trustgraph:agent:test-session" + SESSION_URI = "urn:trustgraph:agent:test-session" + PREV_URI = "urn:trustgraph:agent:test-session/i0" def test_iteration_types(self): triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, - thought="thinking", action="search", observation="found it", + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", ) assert has_type(triples, self.ITER_URI, PROV_ENTITY) assert has_type(triples, self.ITER_URI, TG_ANALYSIS) - def test_iteration_derived_from_parent(self): + def test_first_iteration_generated_by_question(self): + """First iteration uses wasGeneratedBy to link to question activity.""" triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, + self.ITER_URI, question_uri=self.SESSION_URI, + action="search", + ) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI) + assert gen is not None + assert gen.o.iri == self.SESSION_URI + # Should NOT have wasDerivedFrom + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) + assert derived is None + + def test_subsequent_iteration_derived_from_previous(self): + """Subsequent iterations use wasDerivedFrom to link to previous iteration.""" + triples = agent_iteration_triples( + self.ITER_URI, previous_uri=self.PREV_URI, action="search", ) derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI) assert derived is not None - assert derived.o.iri == self.PARENT_URI + assert derived.o.iri == self.PREV_URI + # Should NOT have wasGeneratedBy + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI) + assert gen is None def test_iteration_label_includes_action(self): triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, + self.ITER_URI, question_uri=self.SESSION_URI, action="graph-rag-query", ) label = find_triple(triples, RDFS_LABEL, self.ITER_URI) assert label is not None assert "graph-rag-query" in label.o.value - def test_iteration_thought_inline(self): + def test_iteration_thought_sub_entity(self): + """Thought is a sub-entity with Reflection and Thought types.""" + thought_uri = "urn:trustgraph:agent:test-session/i1/thought" + thought_doc = "urn:doc:thought-1" triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, - thought="I need to search for info", + self.ITER_URI, question_uri=self.SESSION_URI, action="search", + thought_uri=thought_uri, + thought_document_id=thought_doc, ) - thought = find_triple(triples, TG_THOUGHT, self.ITER_URI) - assert thought is not None - assert thought.o.value == "I need to search for info" + # Iteration links to thought sub-entity + thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI) + assert thought_link is not None + assert thought_link.o.iri == thought_uri + # Thought has correct types + assert has_type(triples, thought_uri, TG_REFLECTION_TYPE) + assert has_type(triples, thought_uri, TG_THOUGHT_TYPE) + # Thought was generated by iteration + gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri) + assert gen is not None + assert gen.o.iri == self.ITER_URI + # Thought has document reference + doc = find_triple(triples, TG_DOCUMENT, thought_uri) + assert doc is not None + assert doc.o.iri == thought_doc - def test_iteration_thought_document_preferred(self): - """When thought_document_id is provided, inline thought is not stored.""" + def test_iteration_observation_sub_entity(self): + """Observation is a sub-entity with Reflection and Observation types.""" + obs_uri = "urn:trustgraph:agent:test-session/i1/observation" + obs_doc = "urn:doc:obs-1" triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, - thought="inline thought", + self.ITER_URI, question_uri=self.SESSION_URI, action="search", - thought_document_id="urn:doc:thought-1", + observation_uri=obs_uri, + observation_document_id=obs_doc, ) - thought_doc = find_triple(triples, TG_THOUGHT_DOCUMENT, self.ITER_URI) - assert thought_doc is not None - assert thought_doc.o.iri == "urn:doc:thought-1" - thought_inline = find_triple(triples, TG_THOUGHT, self.ITER_URI) - assert thought_inline is None - - def test_iteration_observation_inline(self): - triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, - action="search", - observation="Found 3 results", - ) - obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI) - assert obs is not None - assert obs.o.value == "Found 3 results" - - def test_iteration_observation_document_preferred(self): - triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, - action="search", - observation="inline obs", - observation_document_id="urn:doc:obs-1", - ) - obs_doc = find_triple(triples, TG_OBSERVATION_DOCUMENT, self.ITER_URI) - assert obs_doc is not None - assert obs_doc.o.iri == "urn:doc:obs-1" - obs_inline = find_triple(triples, TG_OBSERVATION, self.ITER_URI) - assert obs_inline is None + # Iteration links to observation sub-entity + obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI) + assert obs_link is not None + assert obs_link.o.iri == obs_uri + # Observation has correct types + assert has_type(triples, obs_uri, TG_REFLECTION_TYPE) + assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE) + # Observation was generated by iteration + gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri) + assert gen is not None + assert gen.o.iri == self.ITER_URI + # Observation has document reference + doc = find_triple(triples, TG_DOCUMENT, obs_uri) + assert doc is not None + assert doc.o.iri == obs_doc def test_iteration_action_recorded(self): triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, + self.ITER_URI, question_uri=self.SESSION_URI, action="graph-rag-query", ) action = find_triple(triples, TG_ACTION, self.ITER_URI) @@ -197,7 +221,7 @@ class TestAgentIterationTriples: def test_iteration_arguments_json_encoded(self): args = {"query": "test query", "limit": 10} triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, + self.ITER_URI, question_uri=self.SESSION_URI, action="search", arguments=args, ) @@ -208,7 +232,7 @@ class TestAgentIterationTriples: def test_iteration_default_arguments_empty_dict(self): triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, + self.ITER_URI, question_uri=self.SESSION_URI, action="search", ) arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI) @@ -219,7 +243,7 @@ class TestAgentIterationTriples: def test_iteration_no_thought_or_observation(self): """Minimal iteration with just action — no thought or observation triples.""" triples = agent_iteration_triples( - self.ITER_URI, self.PARENT_URI, + self.ITER_URI, question_uri=self.SESSION_URI, action="noop", ) thought = find_triple(triples, TG_THOUGHT, self.ITER_URI) @@ -228,19 +252,19 @@ class TestAgentIterationTriples: assert obs is None def test_iteration_chaining(self): - """Second iteration derives from first iteration, not session.""" + """First iteration uses wasGeneratedBy, second uses wasDerivedFrom.""" iter1_uri = "urn:trustgraph:agent:sess/i1" iter2_uri = "urn:trustgraph:agent:sess/i2" triples1 = agent_iteration_triples( - iter1_uri, self.PARENT_URI, action="step1", + iter1_uri, question_uri=self.SESSION_URI, action="step1", ) triples2 = agent_iteration_triples( - iter2_uri, iter1_uri, action="step2", + iter2_uri, previous_uri=iter1_uri, action="step2", ) - derived1 = find_triple(triples1, PROV_WAS_DERIVED_FROM, iter1_uri) - assert derived1.o.iri == self.PARENT_URI + gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri) + assert gen1.o.iri == self.SESSION_URI derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri) assert derived2.o.iri == iter1_uri @@ -253,42 +277,50 @@ class TestAgentIterationTriples: class TestAgentFinalTriples: FINAL_URI = "urn:trustgraph:agent:test-session/final" - PARENT_URI = "urn:trustgraph:agent:test-session/i3" + PREV_URI = "urn:trustgraph:agent:test-session/i3" + SESSION_URI = "urn:trustgraph:agent:test-session" def test_final_types(self): triples = agent_final_triples( - self.FINAL_URI, self.PARENT_URI, answer="42" + self.FINAL_URI, previous_uri=self.PREV_URI, ) assert has_type(triples, self.FINAL_URI, PROV_ENTITY) assert has_type(triples, self.FINAL_URI, TG_CONCLUSION) + assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE) - def test_final_derived_from_parent(self): + def test_final_derived_from_previous(self): + """Conclusion with iterations uses wasDerivedFrom.""" triples = agent_final_triples( - self.FINAL_URI, self.PARENT_URI, answer="42" + self.FINAL_URI, previous_uri=self.PREV_URI, ) derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) assert derived is not None - assert derived.o.iri == self.PARENT_URI + assert derived.o.iri == self.PREV_URI + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI) + assert gen is None + + def test_final_generated_by_question_when_no_iterations(self): + """When agent answers immediately, final uses wasGeneratedBy.""" + triples = agent_final_triples( + self.FINAL_URI, question_uri=self.SESSION_URI, + ) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI) + assert gen is not None + assert gen.o.iri == self.SESSION_URI + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) + assert derived is None def test_final_label(self): triples = agent_final_triples( - self.FINAL_URI, self.PARENT_URI, answer="42" + self.FINAL_URI, previous_uri=self.PREV_URI, ) label = find_triple(triples, RDFS_LABEL, self.FINAL_URI) assert label is not None assert label.o.value == "Conclusion" - def test_final_inline_answer(self): - triples = agent_final_triples( - self.FINAL_URI, self.PARENT_URI, answer="The answer is 42" - ) - answer = find_triple(triples, TG_ANSWER, self.FINAL_URI) - assert answer is not None - assert answer.o.value == "The answer is 42" - def test_final_document_reference(self): triples = agent_final_triples( - self.FINAL_URI, self.PARENT_URI, + self.FINAL_URI, previous_uri=self.PREV_URI, document_id="urn:trustgraph:agent:sess/answer", ) doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) @@ -296,29 +328,9 @@ class TestAgentFinalTriples: assert doc.o.type == IRI assert doc.o.iri == "urn:trustgraph:agent:sess/answer" - def test_final_document_takes_precedence(self): + def test_final_no_document(self): triples = agent_final_triples( - self.FINAL_URI, self.PARENT_URI, - answer="inline", - document_id="urn:doc:123", + self.FINAL_URI, previous_uri=self.PREV_URI, ) doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) - assert doc is not None - answer = find_triple(triples, TG_ANSWER, self.FINAL_URI) - assert answer is None - - def test_final_no_answer_or_document(self): - triples = agent_final_triples(self.FINAL_URI, self.PARENT_URI) - answer = find_triple(triples, TG_ANSWER, self.FINAL_URI) - doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) - assert answer is None assert doc is None - - def test_final_derives_from_session_when_no_iterations(self): - """When agent answers immediately, final derives from session.""" - session_uri = "urn:trustgraph:agent:test-session" - triples = agent_final_triples( - self.FINAL_URI, session_uri, answer="direct answer" - ) - derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI) - assert derived.o.iri == session_uri diff --git a/tests/unit/test_provenance/test_explainability.py b/tests/unit/test_provenance/test_explainability.py index 1f27cc61..a0a0d566 100644 --- a/tests/unit/test_provenance/test_explainability.py +++ b/tests/unit/test_provenance/test_explainability.py @@ -10,9 +10,11 @@ from trustgraph.api.explainability import ( EdgeSelection, ExplainEntity, Question, + Grounding, Exploration, Focus, Synthesis, + Reflection, Analysis, Conclusion, parse_edge_selection_triples, @@ -20,11 +22,11 @@ from trustgraph.api.explainability import ( wire_triples_to_tuples, ExplainabilityClient, TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, - TG_CONTENT, TG_DOCUMENT, TG_CHUNK_COUNT, - TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, - TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, - TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY, + TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANALYSIS, TG_CONCLUSION, + TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY, RDF_TYPE, RDFS_LABEL, @@ -71,6 +73,18 @@ class TestExplainEntityFromTriples: assert isinstance(entity, Question) assert entity.question_type == "agent" + def test_grounding(self): + triples = [ + ("urn:gnd:1", RDF_TYPE, TG_GROUNDING), + ("urn:gnd:1", TG_CONCEPT, "machine learning"), + ("urn:gnd:1", TG_CONCEPT, "neural networks"), + ] + entity = ExplainEntity.from_triples("urn:gnd:1", triples) + assert isinstance(entity, Grounding) + assert len(entity.concepts) == 2 + assert "machine learning" in entity.concepts + assert "neural networks" in entity.concepts + def test_exploration(self): triples = [ ("urn:exp:1", RDF_TYPE, TG_EXPLORATION), @@ -89,6 +103,17 @@ class TestExplainEntityFromTriples: assert isinstance(entity, Exploration) assert entity.chunk_count == 5 + def test_exploration_with_entities(self): + triples = [ + ("urn:exp:3", RDF_TYPE, TG_EXPLORATION), + ("urn:exp:3", TG_EDGE_COUNT, "10"), + ("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"), + ("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"), + ] + entity = ExplainEntity.from_triples("urn:exp:3", triples) + assert isinstance(entity, Exploration) + assert len(entity.entities) == 2 + def test_exploration_invalid_count(self): triples = [ ("urn:exp:3", RDF_TYPE, TG_EXPLORATION), @@ -110,69 +135,76 @@ class TestExplainEntityFromTriples: assert "urn:edge:1" in entity.selected_edge_uris assert "urn:edge:2" in entity.selected_edge_uris - def test_synthesis_with_content(self): + def test_synthesis_with_document(self): triples = [ ("urn:syn:1", RDF_TYPE, TG_SYNTHESIS), - ("urn:syn:1", TG_CONTENT, "The answer is 42"), + ("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"), ] entity = ExplainEntity.from_triples("urn:syn:1", triples) assert isinstance(entity, Synthesis) - assert entity.content == "The answer is 42" - assert entity.document_uri == "" + assert entity.document_uri == "urn:doc:answer-1" - def test_synthesis_with_document(self): + def test_synthesis_no_document(self): triples = [ ("urn:syn:2", RDF_TYPE, TG_SYNTHESIS), - ("urn:syn:2", TG_DOCUMENT, "urn:doc:answer-1"), ] entity = ExplainEntity.from_triples("urn:syn:2", triples) assert isinstance(entity, Synthesis) - assert entity.document_uri == "urn:doc:answer-1" + assert entity.document_uri == "" + + def test_reflection_thought(self): + triples = [ + ("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE), + ("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE), + ("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"), + ] + entity = ExplainEntity.from_triples("urn:ref:1", triples) + assert isinstance(entity, Reflection) + assert entity.reflection_type == "thought" + assert entity.document_uri == "urn:doc:thought-1" + + def test_reflection_observation(self): + triples = [ + ("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE), + ("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE), + ("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"), + ] + entity = ExplainEntity.from_triples("urn:ref:2", triples) + assert isinstance(entity, Reflection) + assert entity.reflection_type == "observation" + assert entity.document_uri == "urn:doc:obs-1" def test_analysis(self): triples = [ ("urn:ana:1", RDF_TYPE, TG_ANALYSIS), - ("urn:ana:1", TG_THOUGHT, "I should search"), ("urn:ana:1", TG_ACTION, "graph-rag-query"), ("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'), - ("urn:ana:1", TG_OBSERVATION, "Found results"), + ("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"), + ("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"), ] entity = ExplainEntity.from_triples("urn:ana:1", triples) assert isinstance(entity, Analysis) - assert entity.thought == "I should search" assert entity.action == "graph-rag-query" assert entity.arguments == '{"query": "test"}' - assert entity.observation == "Found results" - - def test_analysis_with_document_refs(self): - triples = [ - ("urn:ana:2", RDF_TYPE, TG_ANALYSIS), - ("urn:ana:2", TG_ACTION, "search"), - ("urn:ana:2", TG_THOUGHT_DOCUMENT, "urn:doc:thought-1"), - ("urn:ana:2", TG_OBSERVATION_DOCUMENT, "urn:doc:obs-1"), - ] - entity = ExplainEntity.from_triples("urn:ana:2", triples) - assert isinstance(entity, Analysis) - assert entity.thought_document_uri == "urn:doc:thought-1" - assert entity.observation_document_uri == "urn:doc:obs-1" - - def test_conclusion_with_answer(self): - triples = [ - ("urn:conc:1", RDF_TYPE, TG_CONCLUSION), - ("urn:conc:1", TG_ANSWER, "The final answer"), - ] - entity = ExplainEntity.from_triples("urn:conc:1", triples) - assert isinstance(entity, Conclusion) - assert entity.answer == "The final answer" + assert entity.thought_uri == "urn:ref:thought-1" + assert entity.observation_uri == "urn:ref:obs-1" def test_conclusion_with_document(self): + triples = [ + ("urn:conc:1", RDF_TYPE, TG_CONCLUSION), + ("urn:conc:1", TG_DOCUMENT, "urn:doc:final"), + ] + entity = ExplainEntity.from_triples("urn:conc:1", triples) + assert isinstance(entity, Conclusion) + assert entity.document_uri == "urn:doc:final" + + def test_conclusion_no_document(self): triples = [ ("urn:conc:2", RDF_TYPE, TG_CONCLUSION), - ("urn:conc:2", TG_DOCUMENT, "urn:doc:final"), ] entity = ExplainEntity.from_triples("urn:conc:2", triples) assert isinstance(entity, Conclusion) - assert entity.document_uri == "urn:doc:final" + assert entity.document_uri == "" def test_unknown_type(self): triples = [ @@ -457,25 +489,7 @@ class TestExplainabilityClientResolveLabel: class TestExplainabilityClientContentFetching: - def test_fetch_synthesis_inline_content(self): - mock_flow = MagicMock() - client = ExplainabilityClient(mock_flow, retry_delay=0.0) - - synthesis = Synthesis(uri="urn:syn:1", content="inline answer") - result = client.fetch_synthesis_content(synthesis, api=None) - assert result == "inline answer" - - def test_fetch_synthesis_truncated_content(self): - mock_flow = MagicMock() - client = ExplainabilityClient(mock_flow, retry_delay=0.0) - - long_content = "x" * 20000 - synthesis = Synthesis(uri="urn:syn:1", content=long_content) - result = client.fetch_synthesis_content(synthesis, api=None, max_content=100) - assert len(result) < 20000 - assert result.endswith("... [truncated]") - - def test_fetch_synthesis_from_librarian(self): + def test_fetch_document_content_from_librarian(self): mock_flow = MagicMock() mock_api = MagicMock() mock_library = MagicMock() @@ -483,66 +497,32 @@ class TestExplainabilityClientContentFetching: mock_library.get_document_content.return_value = b"librarian content" client = ExplainabilityClient(mock_flow, retry_delay=0.0) - synthesis = Synthesis( - uri="urn:syn:1", - document_uri="urn:document:abc123" + result = client.fetch_document_content( + "urn:document:abc123", api=mock_api ) - result = client.fetch_synthesis_content(synthesis, api=mock_api) assert result == "librarian content" - def test_fetch_synthesis_no_content_or_document(self): - mock_flow = MagicMock() - client = ExplainabilityClient(mock_flow, retry_delay=0.0) - - synthesis = Synthesis(uri="urn:syn:1") - result = client.fetch_synthesis_content(synthesis, api=None) - assert result == "" - - def test_fetch_conclusion_inline(self): - mock_flow = MagicMock() - client = ExplainabilityClient(mock_flow, retry_delay=0.0) - - conclusion = Conclusion(uri="urn:conc:1", answer="42") - result = client.fetch_conclusion_content(conclusion, api=None) - assert result == "42" - - def test_fetch_analysis_content_from_librarian(self): + def test_fetch_document_content_truncated(self): mock_flow = MagicMock() mock_api = MagicMock() mock_library = MagicMock() mock_api.library.return_value = mock_library - mock_library.get_document_content.side_effect = [ - b"thought content", - b"observation content", - ] + mock_library.get_document_content.return_value = b"x" * 20000 client = ExplainabilityClient(mock_flow, retry_delay=0.0) - analysis = Analysis( - uri="urn:ana:1", - action="search", - thought_document_uri="urn:doc:thought", - observation_document_uri="urn:doc:obs", + result = client.fetch_document_content( + "urn:doc:1", api=mock_api, max_content=100 ) - client.fetch_analysis_content(analysis, api=mock_api) - assert analysis.thought == "thought content" - assert analysis.observation == "observation content" + assert len(result) < 20000 + assert result.endswith("... [truncated]") - def test_fetch_analysis_skips_when_inline_exists(self): + def test_fetch_document_content_empty_uri(self): mock_flow = MagicMock() mock_api = MagicMock() client = ExplainabilityClient(mock_flow, retry_delay=0.0) - analysis = Analysis( - uri="urn:ana:1", - action="search", - thought="already have thought", - observation="already have observation", - thought_document_uri="urn:doc:thought", - observation_document_uri="urn:doc:obs", - ) - client.fetch_analysis_content(analysis, api=mock_api) - # Should not call librarian since inline content exists - mock_api.library.assert_not_called() + result = client.fetch_document_content("", api=mock_api) + assert result == "" class TestExplainabilityClientDetectSessionType: diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py index 91074097..9aff7e4b 100644 --- a/tests/unit/test_provenance/test_triples.py +++ b/tests/unit/test_provenance/test_triples.py @@ -13,6 +13,7 @@ from trustgraph.provenance.triples import ( derived_entity_triples, subgraph_provenance_triples, question_triples, + grounding_triples, exploration_triples, focus_triples, synthesis_triples, @@ -32,10 +33,12 @@ from trustgraph.provenance.namespaces import ( TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION, TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS, TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, - TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, - TG_CONTENT, TG_DOCUMENT, + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_DOCUMENT, TG_CHUNK_COUNT, TG_SELECTED_CHUNK, - TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_ANSWER_TYPE, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, GRAPH_SOURCE, GRAPH_RETRIEVAL, ) @@ -530,36 +533,77 @@ class TestQuestionTriples: assert len(triples) == 6 -class TestExplorationTriples: +class TestGroundingTriples: - EXP_URI = "urn:trustgraph:prov:exploration:test-session" + GND_URI = "urn:trustgraph:prov:grounding:test-session" Q_URI = "urn:trustgraph:question:test-session" - def test_exploration_types(self): - triples = exploration_triples(self.EXP_URI, self.Q_URI, 15) - assert has_type(triples, self.EXP_URI, PROV_ENTITY) - assert has_type(triples, self.EXP_URI, TG_EXPLORATION) + def test_grounding_types(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"]) + assert has_type(triples, self.GND_URI, PROV_ENTITY) + assert has_type(triples, self.GND_URI, TG_GROUNDING) - def test_exploration_generated_by_question(self): - triples = exploration_triples(self.EXP_URI, self.Q_URI, 15) - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI) + def test_grounding_generated_by_question(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"]) + gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI) assert gen is not None assert gen.o.iri == self.Q_URI + def test_grounding_concepts(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"]) + concepts = find_triples(triples, TG_CONCEPT, self.GND_URI) + assert len(concepts) == 3 + values = {t.o.value for t in concepts} + assert values == {"AI", "ML", "robots"} + + def test_grounding_empty_concepts(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, []) + concepts = find_triples(triples, TG_CONCEPT, self.GND_URI) + assert len(concepts) == 0 + + def test_grounding_label(self): + triples = grounding_triples(self.GND_URI, self.Q_URI, []) + label = find_triple(triples, RDFS_LABEL, self.GND_URI) + assert label is not None + assert label.o.value == "Grounding" + + +class TestExplorationTriples: + + EXP_URI = "urn:trustgraph:prov:exploration:test-session" + GND_URI = "urn:trustgraph:prov:grounding:test-session" + + def test_exploration_types(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 15) + assert has_type(triples, self.EXP_URI, PROV_ENTITY) + assert has_type(triples, self.EXP_URI, TG_EXPLORATION) + + def test_exploration_derived_from_grounding(self): + triples = exploration_triples(self.EXP_URI, self.GND_URI, 15) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI) + assert derived is not None + assert derived.o.iri == self.GND_URI + def test_exploration_edge_count(self): - triples = exploration_triples(self.EXP_URI, self.Q_URI, 15) + triples = exploration_triples(self.EXP_URI, self.GND_URI, 15) ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI) assert ec is not None assert ec.o.value == "15" def test_exploration_zero_edges(self): - triples = exploration_triples(self.EXP_URI, self.Q_URI, 0) + triples = exploration_triples(self.EXP_URI, self.GND_URI, 0) ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI) assert ec is not None assert ec.o.value == "0" + def test_exploration_with_entities(self): + entities = ["urn:e:machine-learning", "urn:e:neural-networks"] + triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities) + ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI) + assert len(ent_triples) == 2 + def test_exploration_triple_count(self): - triples = exploration_triples(self.EXP_URI, self.Q_URI, 10) + triples = exploration_triples(self.EXP_URI, self.GND_URI, 10) assert len(triples) == 5 @@ -652,6 +696,7 @@ class TestSynthesisTriples: triples = synthesis_triples(self.SYN_URI, self.FOC_URI) assert has_type(triples, self.SYN_URI, PROV_ENTITY) assert has_type(triples, self.SYN_URI, TG_SYNTHESIS) + assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE) def test_synthesis_derived_from_focus(self): triples = synthesis_triples(self.SYN_URI, self.FOC_URI) @@ -659,12 +704,6 @@ class TestSynthesisTriples: assert derived is not None assert derived.o.iri == self.FOC_URI - def test_synthesis_with_inline_content(self): - triples = synthesis_triples(self.SYN_URI, self.FOC_URI, answer_text="The answer is 42") - content = find_triple(triples, TG_CONTENT, self.SYN_URI) - assert content is not None - assert content.o.value == "The answer is 42" - def test_synthesis_with_document_reference(self): triples = synthesis_triples( self.SYN_URI, self.FOC_URI, @@ -675,23 +714,9 @@ class TestSynthesisTriples: assert doc.o.type == IRI assert doc.o.iri == "urn:trustgraph:question:abc/answer" - def test_synthesis_document_takes_precedence(self): - """When both document_id and answer_text are provided, document_id wins.""" - triples = synthesis_triples( - self.SYN_URI, self.FOC_URI, - answer_text="inline", - document_id="urn:doc:123", - ) - doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) - assert doc is not None - content = find_triple(triples, TG_CONTENT, self.SYN_URI) - assert content is None - - def test_synthesis_no_content_or_document(self): + def test_synthesis_no_document(self): triples = synthesis_triples(self.SYN_URI, self.FOC_URI) - content = find_triple(triples, TG_CONTENT, self.SYN_URI) doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) - assert content is None assert doc is None @@ -723,31 +748,31 @@ class TestDocRagQuestionTriples: class TestDocRagExplorationTriples: EXP_URI = "urn:trustgraph:docrag:test/exploration" - Q_URI = "urn:trustgraph:docrag:test" + GND_URI = "urn:trustgraph:docrag:test/grounding" def test_docrag_exploration_types(self): - triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5) + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5) assert has_type(triples, self.EXP_URI, PROV_ENTITY) assert has_type(triples, self.EXP_URI, TG_EXPLORATION) - def test_docrag_exploration_generated_by(self): - triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 5) - gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.EXP_URI) - assert gen.o.iri == self.Q_URI + def test_docrag_exploration_derived_from_grounding(self): + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI) + assert derived.o.iri == self.GND_URI def test_docrag_exploration_chunk_count(self): - triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 7) + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7) cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI) assert cc.o.value == "7" def test_docrag_exploration_without_chunk_ids(self): - triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3) + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3) chunks = find_triples(triples, TG_SELECTED_CHUNK) assert len(chunks) == 0 def test_docrag_exploration_with_chunk_ids(self): chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"] - triples = docrag_exploration_triples(self.EXP_URI, self.Q_URI, 3, chunk_ids) + triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids) chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI) assert len(chunks) == 3 chunk_uris = {t.o.iri for t in chunks} @@ -770,10 +795,9 @@ class TestDocRagSynthesisTriples: derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI) assert derived.o.iri == self.EXP_URI - def test_docrag_synthesis_with_inline(self): - triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI, answer_text="answer") - content = find_triple(triples, TG_CONTENT, self.SYN_URI) - assert content.o.value == "answer" + def test_docrag_synthesis_has_answer_type(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE) def test_docrag_synthesis_with_document(self): triples = docrag_synthesis_triples( @@ -781,5 +805,8 @@ class TestDocRagSynthesisTriples: ) doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) assert doc.o.iri == "urn:doc:ans" - content = find_triple(triples, TG_CONTENT, self.SYN_URI) - assert content is None + + def test_docrag_synthesis_no_document(self): + triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI) + doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI) + assert doc is None diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 1f4c5f12..27508ba4 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -125,19 +125,15 @@ class TestQuery: assert query.doc_limit == 50 @pytest.mark.asyncio - async def test_get_vector_method(self): - """Test Query.get_vector method calls embeddings client correctly""" - # Create mock DocumentRag with embeddings client + async def test_extract_concepts(self): + """Test Query.extract_concepts extracts concepts from query""" mock_rag = MagicMock() - mock_embeddings_client = AsyncMock() - mock_rag.embeddings_client = mock_embeddings_client + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client - # Mock the embed method to return test vectors in batch format - # New format: [[[vectors_for_text1]]] - returns first text's vector set - expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - mock_embeddings_client.embed.return_value = [expected_vectors] + # Mock the prompt response with concept lines + mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns" - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -145,20 +141,62 @@ class TestQuery: verbose=False ) - # Call get_vector - test_query = "What documents are relevant?" - result = await query.get_vector(test_query) + result = await query.extract_concepts("What is machine learning?") - # Verify embeddings client was called correctly (now expects list) - mock_embeddings_client.embed.assert_called_once_with([test_query]) + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": "What is machine learning?"} + ) + assert result == ["machine learning", "artificial intelligence", "data patterns"] - # Verify result matches expected vectors (extracted from batch) + @pytest.mark.asyncio + async def test_extract_concepts_fallback_to_raw_query(self): + """Test Query.extract_concepts falls back to raw query when no concepts extracted""" + mock_rag = MagicMock() + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + # Mock empty response + mock_prompt_client.prompt.return_value = "" + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + result = await query.extract_concepts("What is ML?") + + assert result == ["What is ML?"] + + @pytest.mark.asyncio + async def test_get_vectors_method(self): + """Test Query.get_vectors method calls embeddings client correctly""" + mock_rag = MagicMock() + mock_embeddings_client = AsyncMock() + mock_rag.embeddings_client = mock_embeddings_client + + # Mock the embed method - returns vectors for each concept + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + mock_embeddings_client.embed.return_value = expected_vectors + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + concepts = ["machine learning", "data patterns"] + result = await query.get_vectors(concepts) + + mock_embeddings_client.embed.assert_called_once_with(concepts) assert result == expected_vectors @pytest.mark.asyncio async def test_get_docs_method(self): """Test Query.get_docs method retrieves documents correctly""" - # Create mock DocumentRag with clients mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() @@ -170,10 +208,8 @@ class TestQuery: return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") mock_rag.fetch_chunk = mock_fetch - # Mock the embedding and document query responses - # New batch format: [[[vectors]]] - get_vector extracts [0] - test_vectors = [[0.1, 0.2, 0.3]] - mock_embeddings_client.embed.return_value = [test_vectors] + # Mock embeddings - one vector per concept + mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]] # Mock document embeddings returns ChunkMatch objects mock_match1 = MagicMock() @@ -184,7 +220,6 @@ class TestQuery: mock_match2.score = 0.85 mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -193,16 +228,16 @@ class TestQuery: doc_limit=15 ) - # Call get_docs - test_query = "Find relevant documents" - result = await query.get_docs(test_query) + # Call get_docs with concepts list + concepts = ["test concept"] + result = await query.get_docs(concepts) - # Verify embeddings client was called (now expects list) - mock_embeddings_client.embed.assert_called_once_with([test_query]) + # Verify embeddings client was called with concepts + mock_embeddings_client.embed.assert_called_once_with(concepts) - # Verify doc embeddings client was called correctly (with extracted vector) + # Verify doc embeddings client was called mock_doc_embeddings_client.query.assert_called_once_with( - vector=test_vectors, + vector=[0.1, 0.2, 0.3], limit=15, user="test_user", collection="test_collection" @@ -218,14 +253,17 @@ class TestQuery: @pytest.mark.asyncio async def test_document_rag_query_method(self, mock_fetch_chunk): """Test DocumentRag.query method orchestrates full document RAG pipeline""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock embeddings and document embeddings responses - # New batch format: [[[vectors]]] - get_vector extracts [0] + # Mock concept extraction + mock_prompt_client.prompt.return_value = "test concept" + + # Mock embeddings - one vector per concept test_vectors = [[0.1, 0.2, 0.3]] + mock_embeddings_client.embed.return_value = test_vectors + mock_match1 = MagicMock() mock_match1.chunk_id = "doc/c3" mock_match1.score = 0.9 @@ -234,11 +272,9 @@ class TestQuery: mock_match2.score = 0.8 expected_response = "This is the document RAG response" - mock_embeddings_client.embed.return_value = [test_vectors] mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_prompt_client.document_prompt.return_value = expected_response - # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -247,7 +283,6 @@ class TestQuery: verbose=False ) - # Call DocumentRag.query result = await document_rag.query( query="test query", user="test_user", @@ -255,12 +290,18 @@ class TestQuery: doc_limit=10 ) - # Verify embeddings client was called (now expects list) - mock_embeddings_client.embed.assert_called_once_with(["test query"]) + # Verify concept extraction was called + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": "test query"} + ) - # Verify doc embeddings client was called (with extracted vector) + # Verify embeddings called with extracted concepts + mock_embeddings_client.embed.assert_called_once_with(["test concept"]) + + # Verify doc embeddings client was called mock_doc_embeddings_client.query.assert_called_once_with( - vector=test_vectors, + vector=[0.1, 0.2, 0.3], limit=10, user="test_user", collection="test_collection" @@ -270,23 +311,23 @@ class TestQuery: mock_prompt_client.document_prompt.assert_called_once() call_args = mock_prompt_client.document_prompt.call_args assert call_args.kwargs["query"] == "test query" - # Documents should be fetched content, not chunk_ids docs = call_args.kwargs["documents"] assert "Relevant document content" in docs assert "Another document" in docs - # Verify result assert result == expected_response @pytest.mark.asyncio async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): """Test DocumentRag.query method with default parameters""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock responses (batch format) + # Mock concept extraction fallback (empty → raw query) + mock_prompt_client.prompt.return_value = "" + + # Mock responses mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_match = MagicMock() mock_match.chunk_id = "doc/c5" @@ -294,7 +335,6 @@ class TestQuery: mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Default response" - # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -302,10 +342,9 @@ class TestQuery: fetch_chunk=mock_fetch_chunk ) - # Call DocumentRag.query with minimal parameters result = await document_rag.query("simple query") - # Verify default parameters were used (vector extracted from batch) + # Verify default parameters were used mock_doc_embeddings_client.query.assert_called_once_with( vector=[[0.1, 0.2]], limit=20, # Default doc_limit @@ -318,7 +357,6 @@ class TestQuery: @pytest.mark.asyncio async def test_get_docs_with_verbose_output(self): """Test Query.get_docs method with verbose logging""" - # Create mock DocumentRag with clients mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() @@ -330,14 +368,13 @@ class TestQuery: return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") mock_rag.fetch_chunk = mock_fetch - # Mock responses (batch format) + # Mock responses - one vector per concept mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] mock_match = MagicMock() mock_match.chunk_id = "doc/c6" mock_match.score = 0.88 mock_doc_embeddings_client.query.return_value = [mock_match] - # Initialize Query with verbose=True query = Query( rag=mock_rag, user="test_user", @@ -346,14 +383,12 @@ class TestQuery: doc_limit=5 ) - # Call get_docs - result = await query.get_docs("verbose test") + # Call get_docs with concepts + result = await query.get_docs(["verbose test"]) - # Verify calls were made (now expects list) mock_embeddings_client.embed.assert_called_once_with(["verbose test"]) mock_doc_embeddings_client.query.assert_called_once() - # Verify result is tuple of (docs, chunk_ids) with fetched content docs, chunk_ids = result assert "Verbose test doc" in docs assert "doc/c6" in chunk_ids @@ -361,12 +396,14 @@ class TestQuery: @pytest.mark.asyncio async def test_document_rag_query_with_verbose(self, mock_fetch_chunk): """Test DocumentRag.query method with verbose logging enabled""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock responses (batch format) + # Mock concept extraction + mock_prompt_client.prompt.return_value = "verbose query test" + + # Mock responses mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] mock_match = MagicMock() mock_match.chunk_id = "doc/c7" @@ -374,7 +411,6 @@ class TestQuery: mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Verbose RAG response" - # Initialize DocumentRag with verbose=True document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -383,14 +419,11 @@ class TestQuery: verbose=True ) - # Call DocumentRag.query result = await document_rag.query("verbose query test") - # Verify all clients were called (now expects list) - mock_embeddings_client.embed.assert_called_once_with(["verbose query test"]) + mock_embeddings_client.embed.assert_called_once() mock_doc_embeddings_client.query.assert_called_once() - # Verify prompt client was called with fetched content call_args = mock_prompt_client.document_prompt.call_args assert call_args.kwargs["query"] == "verbose query test" assert "Verbose doc content" in call_args.kwargs["documents"] @@ -400,23 +433,20 @@ class TestQuery: @pytest.mark.asyncio async def test_get_docs_with_empty_results(self): """Test Query.get_docs method when no documents are found""" - # Create mock DocumentRag with clients mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client - # Mock fetch_chunk (won't be called if no chunk_ids) async def mock_fetch(chunk_id, user): return f"Content for {chunk_id}" mock_rag.fetch_chunk = mock_fetch - # Mock responses - empty chunk_id list (batch format) + # Mock responses - empty results mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] - mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found + mock_doc_embeddings_client.query.return_value = [] - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -424,30 +454,27 @@ class TestQuery: verbose=False ) - # Call get_docs - result = await query.get_docs("query with no results") + result = await query.get_docs(["query with no results"]) - # Verify calls were made (now expects list) mock_embeddings_client.embed.assert_called_once_with(["query with no results"]) mock_doc_embeddings_client.query.assert_called_once() - # Verify empty result is returned (tuple of empty lists) assert result == ([], []) @pytest.mark.asyncio async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk): """Test DocumentRag.query method when no documents are retrieved""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock responses - no chunk_ids found (batch format) + # Mock concept extraction + mock_prompt_client.prompt.return_value = "query with no matching docs" + mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] - mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list + mock_doc_embeddings_client.query.return_value = [] mock_prompt_client.document_prompt.return_value = "No documents found response" - # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -456,10 +483,8 @@ class TestQuery: verbose=False ) - # Call DocumentRag.query result = await document_rag.query("query with no matching docs") - # Verify prompt client was called with empty document list mock_prompt_client.document_prompt.assert_called_once_with( query="query with no matching docs", documents=[] @@ -468,18 +493,15 @@ class TestQuery: assert result == "No documents found response" @pytest.mark.asyncio - async def test_get_vector_with_verbose(self): - """Test Query.get_vector method with verbose logging""" - # Create mock DocumentRag with embeddings client + async def test_get_vectors_with_verbose(self): + """Test Query.get_vectors method with verbose logging""" mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - # Mock the embed method (batch format) expected_vectors = [[0.9, 1.0, 1.1]] - mock_embeddings_client.embed.return_value = [expected_vectors] + mock_embeddings_client.embed.return_value = expected_vectors - # Initialize Query with verbose=True query = Query( rag=mock_rag, user="test_user", @@ -487,40 +509,40 @@ class TestQuery: verbose=True ) - # Call get_vector - result = await query.get_vector("verbose vector test") + result = await query.get_vectors(["verbose vector test"]) - # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"]) - - # Verify result (extracted from batch) assert result == expected_vectors @pytest.mark.asyncio async def test_document_rag_integration_flow(self, mock_fetch_chunk): """Test complete DocumentRag integration with realistic data flow""" - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock() - # Mock realistic responses (batch format) query_text = "What is machine learning?" - query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] - retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"] final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." - mock_embeddings_client.embed.return_value = [query_vectors] - mock_matches = [] - for chunk_id in retrieved_chunk_ids: - mock_match = MagicMock() - mock_match.chunk_id = chunk_id - mock_match.score = 0.9 - mock_matches.append(mock_match) - mock_doc_embeddings_client.query.return_value = mock_matches + # Mock concept extraction + mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence" + + # Mock embeddings - one vector per concept + query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]] + mock_embeddings_client.embed.return_value = query_vectors + + # Each concept query returns some matches + mock_matches_1 = [ + MagicMock(chunk_id="doc/ml1", score=0.9), + MagicMock(chunk_id="doc/ml2", score=0.85), + ] + mock_matches_2 = [ + MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate + MagicMock(chunk_id="doc/ml3", score=0.82), + ] + mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2] mock_prompt_client.document_prompt.return_value = final_response - # Initialize DocumentRag document_rag = DocumentRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -529,7 +551,6 @@ class TestQuery: verbose=False ) - # Execute full pipeline result = await document_rag.query( query=query_text, user="research_user", @@ -537,26 +558,69 @@ class TestQuery: doc_limit=25 ) - # Verify complete pipeline execution (now expects list) - mock_embeddings_client.embed.assert_called_once_with([query_text]) - - mock_doc_embeddings_client.query.assert_called_once_with( - vector=query_vectors, - limit=25, - user="research_user", - collection="ml_knowledge" + # Verify concept extraction + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": query_text} ) + # Verify embeddings called with concepts + mock_embeddings_client.embed.assert_called_once_with( + ["machine learning", "artificial intelligence"] + ) + + # Verify two per-concept queries were made (25 // 2 = 12 per concept) + assert mock_doc_embeddings_client.query.call_count == 2 + # Verify prompt client was called with fetched document content mock_prompt_client.document_prompt.assert_called_once() call_args = mock_prompt_client.document_prompt.call_args assert call_args.kwargs["query"] == query_text - # Verify documents were fetched from chunk_ids + # Verify documents were fetched and deduplicated docs = call_args.kwargs["documents"] assert "Machine learning is a subset of artificial intelligence..." in docs assert "ML algorithms learn patterns from data to make predictions..." in docs assert "Common ML techniques include supervised and unsupervised learning..." in docs + assert len(docs) == 3 # doc/ml2 deduplicated - # Verify final result assert result == final_response + + @pytest.mark.asyncio + async def test_get_docs_deduplicates_across_concepts(self): + """Test that get_docs deduplicates chunks across multiple concepts""" + mock_rag = MagicMock() + mock_embeddings_client = AsyncMock() + mock_doc_embeddings_client = AsyncMock() + mock_rag.embeddings_client = mock_embeddings_client + mock_rag.doc_embeddings_client = mock_doc_embeddings_client + + async def mock_fetch(chunk_id, user): + return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") + mock_rag.fetch_chunk = mock_fetch + + # Two concepts → two vectors + mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # Both queries return overlapping chunks + match_a = MagicMock(chunk_id="doc/c1", score=0.9) + match_b = MagicMock(chunk_id="doc/c2", score=0.8) + match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate + mock_doc_embeddings_client.query.side_effect = [ + [match_a, match_b], + [match_c], + ] + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False, + doc_limit=10 + ) + + docs, chunk_ids = await query.get_docs(["concept A", "concept B"]) + + assert len(chunk_ids) == 2 # doc/c1 only counted once + assert "doc/c1" in chunk_ids + assert "doc/c2" in chunk_ids diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 195c8172..597d3366 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -19,7 +19,7 @@ class TestGraphRag: mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() - + # Initialize GraphRag graph_rag = GraphRag( prompt_client=mock_prompt_client, @@ -27,7 +27,7 @@ class TestGraphRag: graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client ) - + # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == mock_embeddings_client @@ -45,7 +45,7 @@ class TestGraphRag: mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() - + # Initialize GraphRag with verbose=True graph_rag = GraphRag( prompt_client=mock_prompt_client, @@ -54,7 +54,7 @@ class TestGraphRag: triples_client=mock_triples_client, verbose=True ) - + # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == mock_embeddings_client @@ -73,7 +73,7 @@ class TestQuery: """Test Query initialization with default parameters""" # Create mock GraphRag mock_rag = MagicMock() - + # Initialize Query with defaults query = Query( rag=mock_rag, @@ -81,7 +81,7 @@ class TestQuery: collection="test_collection", verbose=False ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "test_user" @@ -96,7 +96,7 @@ class TestQuery: """Test Query initialization with custom parameters""" # Create mock GraphRag mock_rag = MagicMock() - + # Initialize Query with custom parameters query = Query( rag=mock_rag, @@ -108,7 +108,7 @@ class TestQuery: max_subgraph_size=2000, max_path_length=3 ) - + # Verify initialization assert query.rag == mock_rag assert query.user == "custom_user" @@ -120,18 +120,16 @@ class TestQuery: assert query.max_path_length == 3 @pytest.mark.asyncio - async def test_get_vector_method(self): - """Test Query.get_vector method calls embeddings client correctly""" - # Create mock GraphRag with embeddings client + async def test_get_vectors_method(self): + """Test Query.get_vectors method calls embeddings client correctly""" mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - - # Mock the embed method to return test vectors (batch format) - expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - mock_embeddings_client.embed.return_value = [expected_vectors] - # Initialize Query + # Mock embed to return vectors for a list of concepts + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + mock_embeddings_client.embed.return_value = expected_vectors + query = Query( rag=mock_rag, user="test_user", @@ -139,29 +137,22 @@ class TestQuery: verbose=False ) - # Call get_vector - test_query = "What is the capital of France?" - result = await query.get_vector(test_query) + concepts = ["machine learning", "neural networks"] + result = await query.get_vectors(concepts) - # Verify embeddings client was called correctly (now expects list) - mock_embeddings_client.embed.assert_called_once_with([test_query]) - - # Verify result matches expected vectors (extracted from batch) + mock_embeddings_client.embed.assert_called_once_with(concepts) assert result == expected_vectors @pytest.mark.asyncio - async def test_get_vector_method_with_verbose(self): - """Test Query.get_vector method with verbose output""" - # Create mock GraphRag with embeddings client + async def test_get_vectors_method_with_verbose(self): + """Test Query.get_vectors method with verbose output""" mock_rag = MagicMock() mock_embeddings_client = AsyncMock() mock_rag.embeddings_client = mock_embeddings_client - - # Mock the embed method (batch format) - expected_vectors = [[0.7, 0.8, 0.9]] - mock_embeddings_client.embed.return_value = [expected_vectors] - # Initialize Query with verbose=True + expected_vectors = [[0.7, 0.8, 0.9]] + mock_embeddings_client.embed.return_value = expected_vectors + query = Query( rag=mock_rag, user="test_user", @@ -169,48 +160,87 @@ class TestQuery: verbose=True ) - # Call get_vector - test_query = "Test query for embeddings" - result = await query.get_vector(test_query) + result = await query.get_vectors(["test concept"]) - # Verify embeddings client was called correctly (now expects list) - mock_embeddings_client.embed.assert_called_once_with([test_query]) - - # Verify result matches expected vectors (extracted from batch) + mock_embeddings_client.embed.assert_called_once_with(["test concept"]) assert result == expected_vectors @pytest.mark.asyncio - async def test_get_entities_method(self): - """Test Query.get_entities method retrieves entities correctly""" - # Create mock GraphRag with clients + async def test_extract_concepts(self): + """Test Query.extract_concepts parses LLM response into concept list""" mock_rag = MagicMock() + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n" + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + result = await query.extract_concepts("What is machine learning?") + + mock_prompt_client.prompt.assert_called_once_with( + "extract-concepts", + variables={"query": "What is machine learning?"} + ) + assert result == ["machine learning", "neural networks"] + + @pytest.mark.asyncio + async def test_extract_concepts_fallback_to_raw_query(self): + """Test extract_concepts falls back to raw query when LLM returns empty""" + mock_rag = MagicMock() + mock_prompt_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client + + mock_prompt_client.prompt.return_value = "" + + query = Query( + rag=mock_rag, + user="test_user", + collection="test_collection", + verbose=False + ) + + result = await query.extract_concepts("test query") + assert result == ["test query"] + + @pytest.mark.asyncio + async def test_get_entities_method(self): + """Test Query.get_entities extracts concepts, embeds, and retrieves entities""" + mock_rag = MagicMock() + mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() + mock_rag.prompt_client = mock_prompt_client mock_rag.embeddings_client = mock_embeddings_client mock_rag.graph_embeddings_client = mock_graph_embeddings_client - - # Mock the embedding and entity query responses (batch format) - test_vectors = [[0.1, 0.2, 0.3]] - mock_embeddings_client.embed.return_value = [test_vectors] - # Mock EntityMatch objects with entity as Term-like object + # extract_concepts returns empty -> falls back to [query] + mock_prompt_client.prompt.return_value = "" + + # embed returns one vector set for the single concept + test_vectors = [[0.1, 0.2, 0.3]] + mock_embeddings_client.embed.return_value = test_vectors + + # Mock entity matches mock_entity1 = MagicMock() - mock_entity1.type = "i" # IRI type + mock_entity1.type = "i" mock_entity1.iri = "entity1" mock_match1 = MagicMock() mock_match1.entity = mock_entity1 - mock_match1.score = 0.95 mock_entity2 = MagicMock() - mock_entity2.type = "i" # IRI type + mock_entity2.type = "i" mock_entity2.iri = "entity2" mock_match2 = MagicMock() mock_match2.entity = mock_entity2 - mock_match2.score = 0.85 mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2] - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -219,35 +249,23 @@ class TestQuery: entity_limit=25 ) - # Call get_entities - test_query = "Find related entities" - result = await query.get_entities(test_query) + entities, concepts = await query.get_entities("Find related entities") - # Verify embeddings client was called (now expects list) - mock_embeddings_client.embed.assert_called_once_with([test_query]) + # Verify embeddings client was called with the fallback concept + mock_embeddings_client.embed.assert_called_once_with(["Find related entities"]) - # Verify graph embeddings client was called correctly (with extracted vector) - mock_graph_embeddings_client.query.assert_called_once_with( - vector=test_vectors, - limit=25, - user="test_user", - collection="test_collection" - ) - - # Verify result is list of entity strings - assert result == ["entity1", "entity2"] + # Verify result + assert entities == ["entity1", "entity2"] + assert concepts == ["Find related entities"] @pytest.mark.asyncio async def test_maybe_label_with_cached_label(self): """Test Query.maybe_label method with cached label""" - # Create mock GraphRag with label cache mock_rag = MagicMock() - # Create mock LRUCacheWithTTL mock_cache = MagicMock() mock_cache.get.return_value = "Entity One Label" mock_rag.label_cache = mock_cache - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -255,32 +273,25 @@ class TestQuery: verbose=False ) - # Call maybe_label with cached entity result = await query.maybe_label("entity1") - # Verify cached label is returned assert result == "Entity One Label" - # Verify cache was checked with proper key format (user:collection:entity) mock_cache.get.assert_called_once_with("test_user:test_collection:entity1") @pytest.mark.asyncio async def test_maybe_label_with_label_lookup(self): """Test Query.maybe_label method with database label lookup""" - # Create mock GraphRag with triples client mock_rag = MagicMock() - # Create mock LRUCacheWithTTL that returns None (cache miss) mock_cache = MagicMock() mock_cache.get.return_value = None mock_rag.label_cache = mock_cache mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - # Mock triple result with label mock_triple = MagicMock() mock_triple.o = "Human Readable Label" mock_triples_client.query.return_value = [mock_triple] - # Initialize Query query = Query( rag=mock_rag, user="test_user", @@ -288,20 +299,18 @@ class TestQuery: verbose=False ) - # Call maybe_label result = await query.maybe_label("http://example.com/entity") - # Verify triples client was called correctly mock_triples_client.query.assert_called_once_with( s="http://example.com/entity", p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, user="test_user", - collection="test_collection" + collection="test_collection", + g="" ) - # Verify result and cache update with proper key assert result == "Human Readable Label" cache_key = "test_user:test_collection:http://example.com/entity" mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label") @@ -309,40 +318,34 @@ class TestQuery: @pytest.mark.asyncio async def test_maybe_label_with_no_label_found(self): """Test Query.maybe_label method when no label is found""" - # Create mock GraphRag with triples client mock_rag = MagicMock() - # Create mock LRUCacheWithTTL that returns None (cache miss) mock_cache = MagicMock() mock_cache.get.return_value = None mock_rag.label_cache = mock_cache mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Mock empty result (no label found) + mock_triples_client.query.return_value = [] - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call maybe_label + result = await query.maybe_label("unlabeled_entity") - - # Verify triples client was called + mock_triples_client.query.assert_called_once_with( s="unlabeled_entity", p="http://www.w3.org/2000/01/rdf-schema#label", o=None, limit=1, user="test_user", - collection="test_collection" + collection="test_collection", + g="" ) - - # Verify result is entity itself and cache is updated + assert result == "unlabeled_entity" cache_key = "test_user:test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") @@ -350,29 +353,25 @@ class TestQuery: @pytest.mark.asyncio async def test_follow_edges_basic_functionality(self): """Test Query.follow_edges method basic triple discovery""" - # Create mock GraphRag with triples client mock_rag = MagicMock() mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Mock triple results for different query patterns + mock_triple1 = MagicMock() mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1" - + mock_triple2 = MagicMock() mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2" - + mock_triple3 = MagicMock() mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" - - # Setup query_stream responses for s=ent, p=ent, o=ent patterns + mock_triples_client.query_stream.side_effect = [ - [mock_triple1], # s=ent, p=None, o=None - [mock_triple2], # s=None, p=ent, o=None - [mock_triple3], # s=None, p=None, o=ent + [mock_triple1], # s=ent + [mock_triple2], # p=ent + [mock_triple3], # o=ent ] - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", @@ -380,29 +379,25 @@ class TestQuery: verbose=False, triple_limit=10 ) - - # Call follow_edges + subgraph = set() await query.follow_edges("entity1", subgraph, path_length=1) - - # Verify all three query patterns were called + assert mock_triples_client.query_stream.call_count == 3 - # Verify query_stream calls mock_triples_client.query_stream.assert_any_call( s="entity1", p=None, o=None, limit=10, - user="test_user", collection="test_collection", batch_size=20 + user="test_user", collection="test_collection", batch_size=20, g="" ) mock_triples_client.query_stream.assert_any_call( s=None, p="entity1", o=None, limit=10, - user="test_user", collection="test_collection", batch_size=20 + user="test_user", collection="test_collection", batch_size=20, g="" ) mock_triples_client.query_stream.assert_any_call( s=None, p=None, o="entity1", limit=10, - user="test_user", collection="test_collection", batch_size=20 + user="test_user", collection="test_collection", batch_size=20, g="" ) - - # Verify subgraph contains discovered triples + expected_subgraph = { ("entity1", "predicate1", "object1"), ("subject2", "entity1", "object2"), @@ -413,38 +408,30 @@ class TestQuery: @pytest.mark.asyncio async def test_follow_edges_with_path_length_zero(self): """Test Query.follow_edges method with path_length=0""" - # Create mock GraphRag mock_rag = MagicMock() mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Initialize Query + query = Query( rag=mock_rag, user="test_user", collection="test_collection", verbose=False ) - - # Call follow_edges with path_length=0 + subgraph = set() await query.follow_edges("entity1", subgraph, path_length=0) - # Verify no queries were made mock_triples_client.query_stream.assert_not_called() - - # Verify subgraph remains empty assert subgraph == set() @pytest.mark.asyncio async def test_follow_edges_with_max_subgraph_size_limit(self): """Test Query.follow_edges method respects max_subgraph_size""" - # Create mock GraphRag mock_rag = MagicMock() mock_triples_client = AsyncMock() mock_rag.triples_client = mock_triples_client - - # Initialize Query with small max_subgraph_size + query = Query( rag=mock_rag, user="test_user", @@ -452,23 +439,17 @@ class TestQuery: verbose=False, max_subgraph_size=2 ) - - # Pre-populate subgraph to exceed limit + subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")} - - # Call follow_edges + await query.follow_edges("entity1", subgraph, path_length=1) - # Verify no queries were made due to size limit mock_triples_client.query_stream.assert_not_called() - - # Verify subgraph unchanged assert len(subgraph) == 3 @pytest.mark.asyncio async def test_get_subgraph_method(self): - """Test Query.get_subgraph method orchestrates entity and edge discovery""" - # Create mock Query that patches get_entities and follow_edges_batch + """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple""" mock_rag = MagicMock() query = Query( @@ -479,130 +460,119 @@ class TestQuery: max_path_length=1 ) - # Mock get_entities to return test entities - query.get_entities = AsyncMock(return_value=["entity1", "entity2"]) + # Mock get_entities to return (entities, concepts) tuple + query.get_entities = AsyncMock( + return_value=(["entity1", "entity2"], ["concept1"]) + ) - # Mock follow_edges_batch to return test triples query.follow_edges_batch = AsyncMock(return_value={ ("entity1", "predicate1", "object1"), ("entity2", "predicate2", "object2") }) - # Call get_subgraph - result = await query.get_subgraph("test query") + subgraph, entities, concepts = await query.get_subgraph("test query") - # Verify get_entities was called query.get_entities.assert_called_once_with("test query") - - # Verify follow_edges_batch was called with entities and max_path_length query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) - # Verify result is list format and contains expected triples - assert isinstance(result, list) - assert len(result) == 2 - assert ("entity1", "predicate1", "object1") in result - assert ("entity2", "predicate2", "object2") in result + assert isinstance(subgraph, list) + assert len(subgraph) == 2 + assert ("entity1", "predicate1", "object1") in subgraph + assert ("entity2", "predicate2", "object2") in subgraph + assert entities == ["entity1", "entity2"] + assert concepts == ["concept1"] @pytest.mark.asyncio async def test_get_labelgraph_method(self): - """Test Query.get_labelgraph method converts entities to labels""" - # Create mock Query + """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)""" mock_rag = MagicMock() - + query = Query( rag=mock_rag, user="test_user", - collection="test_collection", + collection="test_collection", verbose=False, max_subgraph_size=100 ) - - # Mock get_subgraph to return test triples + test_subgraph = [ ("entity1", "predicate1", "object1"), - ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered + ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), ("entity3", "predicate3", "object3") ] - query.get_subgraph = AsyncMock(return_value=test_subgraph) - - # Mock maybe_label to return human-readable labels + test_entities = ["entity1", "entity3"] + test_concepts = ["concept1"] + query.get_subgraph = AsyncMock( + return_value=(test_subgraph, test_entities, test_concepts) + ) + async def mock_maybe_label(entity): label_map = { "entity1": "Human Entity One", - "predicate1": "Human Predicate One", + "predicate1": "Human Predicate One", "object1": "Human Object One", "entity3": "Human Entity Three", "predicate3": "Human Predicate Three", "object3": "Human Object Three" } return label_map.get(entity, entity) - - query.maybe_label = AsyncMock(side_effect=mock_maybe_label) - - # Call get_labelgraph - labeled_edges, uri_map = await query.get_labelgraph("test query") - # Verify get_subgraph was called + query.maybe_label = AsyncMock(side_effect=mock_maybe_label) + + labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query") + query.get_subgraph.assert_called_once_with("test query") - # Verify label triples are filtered out - assert len(labeled_edges) == 2 # Label triple should be excluded + # Label triples filtered out + assert len(labeled_edges) == 2 - # Verify maybe_label was called for non-label triples - expected_calls = [ - (("entity1",), {}), (("predicate1",), {}), (("object1",), {}), - (("entity3",), {}), (("predicate3",), {}), (("object3",), {}) - ] + # maybe_label called for non-label triples assert query.maybe_label.call_count == 6 - # Verify result contains human-readable labels expected_edges = [ ("Human Entity One", "Human Predicate One", "Human Object One"), ("Human Entity Three", "Human Predicate Three", "Human Object Three") ] assert labeled_edges == expected_edges - # Verify uri_map maps labeled edges back to original URIs assert len(uri_map) == 2 + assert entities == test_entities + assert concepts == test_concepts @pytest.mark.asyncio async def test_graph_rag_query_method(self): - """Test GraphRag.query method orchestrates full RAG pipeline with real-time provenance""" + """Test GraphRag.query method orchestrates full RAG pipeline with provenance""" import json from trustgraph.retrieval.graph_rag.graph_rag import edge_id - # Create mock clients mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() mock_triples_client = AsyncMock() - # Mock prompt client responses for two-step process expected_response = "This is the RAG response" test_labelgraph = [("Subject", "Predicate", "Object")] - - # Compute the edge ID for the test edge test_edge_id = edge_id("Subject", "Predicate", "Object") - - # Create uri_map for the test edge (maps labeled edge ID to original URIs) test_uri_map = { test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") } + test_entities = ["http://example.org/subject"] + test_concepts = ["test concept"] - # Mock edge selection response (JSONL format) - edge_selection_response = json.dumps({"id": test_edge_id, "reasoning": "relevant"}) - - # Configure prompt mock to return different responses based on prompt name + # Mock prompt responses for the multi-step process async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": - return edge_selection_response + if prompt_name == "extract-concepts": + return "" # Falls back to raw query + elif prompt_name == "kg-edge-scoring": + return json.dumps({"id": test_edge_id, "score": 0.9}) + elif prompt_name == "kg-edge-reasoning": + return json.dumps({"id": test_edge_id, "reasoning": "relevant"}) elif prompt_name == "kg-synthesis": return expected_response return "" mock_prompt_client.prompt = mock_prompt - # Initialize GraphRag graph_rag = GraphRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, @@ -611,27 +581,20 @@ class TestQuery: verbose=False ) - # We need to patch the Query class's get_labelgraph method - original_query_init = Query.__init__ + # Patch Query.get_labelgraph to return test data original_get_labelgraph = Query.get_labelgraph - def mock_query_init(self, *args, **kwargs): - original_query_init(self, *args, **kwargs) - async def mock_get_labelgraph(self, query_text): - return test_labelgraph, test_uri_map + return test_labelgraph, test_uri_map, test_entities, test_concepts - Query.__init__ = mock_query_init Query.get_labelgraph = mock_get_labelgraph - # Collect provenance emitted via callback provenance_events = [] async def collect_provenance(triples, prov_id): provenance_events.append((triples, prov_id)) try: - # Call GraphRag.query with provenance callback response = await graph_rag.query( query="test query", user="test_user", @@ -641,25 +604,22 @@ class TestQuery: explain_callback=collect_provenance ) - # Verify response text assert response == expected_response - # Verify provenance was emitted incrementally (4 events: question, exploration, focus, synthesis) - assert len(provenance_events) == 4 + # 5 events: question, grounding, exploration, focus, synthesis + assert len(provenance_events) == 5 - # Verify each event has triples and a URN for triples, prov_id in provenance_events: assert isinstance(triples, list) assert len(triples) > 0 assert prov_id.startswith("urn:trustgraph:") - # Verify order: question, exploration, focus, synthesis + # Verify order assert "question" in provenance_events[0][1] - assert "exploration" in provenance_events[1][1] - assert "focus" in provenance_events[2][1] - assert "synthesis" in provenance_events[3][1] + assert "grounding" in provenance_events[1][1] + assert "exploration" in provenance_events[2][1] + assert "focus" in provenance_events[3][1] + assert "synthesis" in provenance_events[4][1] finally: - # Restore original methods - Query.__init__ = original_query_init - Query.get_labelgraph = original_get_labelgraph \ No newline at end of file + Query.get_labelgraph = original_get_labelgraph diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index e71e192c..dc1405ac 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -75,9 +75,11 @@ from .explainability import ( ExplainabilityClient, ExplainEntity, Question, + Grounding, Exploration, Focus, Synthesis, + Reflection, Analysis, Conclusion, EdgeSelection, diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index b7ebca0e..26fb77fd 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -18,25 +18,28 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" -TG_CONTENT = TG + "content" TG_DOCUMENT = TG + "document" +TG_CONCEPT = TG + "concept" +TG_ENTITY = TG + "entity" TG_CHUNK_COUNT = TG + "chunkCount" TG_SELECTED_CHUNK = TG + "selectedChunk" TG_THOUGHT = TG + "thought" TG_ACTION = TG + "action" TG_ARGUMENTS = TG + "arguments" TG_OBSERVATION = TG + "observation" -TG_ANSWER = TG + "answer" -TG_THOUGHT_DOCUMENT = TG + "thoughtDocument" -TG_OBSERVATION_DOCUMENT = TG + "observationDocument" # Entity types TG_QUESTION = TG + "Question" +TG_GROUNDING = TG + "Grounding" TG_EXPLORATION = TG + "Exploration" TG_FOCUS = TG + "Focus" TG_SYNTHESIS = TG + "Synthesis" TG_ANALYSIS = TG + "Analysis" TG_CONCLUSION = TG + "Conclusion" +TG_ANSWER_TYPE = TG + "Answer" +TG_REFLECTION_TYPE = TG + "Reflection" +TG_THOUGHT_TYPE = TG + "Thought" +TG_OBSERVATION_TYPE = TG + "Observation" TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" TG_AGENT_QUESTION = TG + "AgentQuestion" @@ -73,12 +76,16 @@ class ExplainEntity: if TG_GRAPH_RAG_QUESTION in types or TG_DOC_RAG_QUESTION in types or TG_AGENT_QUESTION in types: return Question.from_triples(uri, triples, types) + elif TG_GROUNDING in types: + return Grounding.from_triples(uri, triples) elif TG_EXPLORATION in types: return Exploration.from_triples(uri, triples) elif TG_FOCUS in types: return Focus.from_triples(uri, triples) elif TG_SYNTHESIS in types: return Synthesis.from_triples(uri, triples) + elif TG_REFLECTION_TYPE in types: + return Reflection.from_triples(uri, triples) elif TG_ANALYSIS in types: return Analysis.from_triples(uri, triples) elif TG_CONCLUSION in types: @@ -124,16 +131,38 @@ class Question(ExplainEntity): ) +@dataclass +class Grounding(ExplainEntity): + """Grounding entity - concept decomposition of the query.""" + concepts: List[str] = field(default_factory=list) + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Grounding": + concepts = [] + + for s, p, o in triples: + if p == TG_CONCEPT: + concepts.append(o) + + return cls( + uri=uri, + entity_type="grounding", + concepts=concepts + ) + + @dataclass class Exploration(ExplainEntity): """Exploration entity - edges/chunks retrieved from the knowledge store.""" edge_count: int = 0 chunk_count: int = 0 + entities: List[str] = field(default_factory=list) @classmethod def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Exploration": edge_count = 0 chunk_count = 0 + entities = [] for s, p, o in triples: if p == TG_EDGE_COUNT: @@ -146,12 +175,15 @@ class Exploration(ExplainEntity): chunk_count = int(o) except (ValueError, TypeError): pass + elif p == TG_ENTITY: + entities.append(o) return cls( uri=uri, entity_type="exploration", edge_count=edge_count, - chunk_count=chunk_count + chunk_count=chunk_count, + entities=entities ) @@ -180,94 +212,104 @@ class Focus(ExplainEntity): @dataclass class Synthesis(ExplainEntity): """Synthesis entity - the final answer.""" - content: str = "" document_uri: str = "" # Reference to librarian document @classmethod def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Synthesis": - content = "" document_uri = "" for s, p, o in triples: - if p == TG_CONTENT: - content = o - elif p == TG_DOCUMENT: + if p == TG_DOCUMENT: document_uri = o return cls( uri=uri, entity_type="synthesis", - content=content, document_uri=document_uri ) +@dataclass +class Reflection(ExplainEntity): + """Reflection entity - intermediate commentary (Thought or Observation).""" + document_uri: str = "" # Reference to content in librarian + reflection_type: str = "" # "thought" or "observation" + + @classmethod + def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Reflection": + document_uri = "" + reflection_type = "" + + types = [o for s, p, o in triples if p == RDF_TYPE] + + if TG_THOUGHT_TYPE in types: + reflection_type = "thought" + elif TG_OBSERVATION_TYPE in types: + reflection_type = "observation" + + for s, p, o in triples: + if p == TG_DOCUMENT: + document_uri = o + + return cls( + uri=uri, + entity_type="reflection", + document_uri=document_uri, + reflection_type=reflection_type + ) + + @dataclass class Analysis(ExplainEntity): """Analysis entity - one think/act/observe cycle (Agent only).""" - thought: str = "" action: str = "" arguments: str = "" # JSON string - observation: str = "" - thought_document_uri: str = "" # Reference to thought in librarian - observation_document_uri: str = "" # Reference to observation in librarian + thought_uri: str = "" # URI of thought sub-entity + observation_uri: str = "" # URI of observation sub-entity @classmethod def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Analysis": - thought = "" action = "" arguments = "" - observation = "" - thought_document_uri = "" - observation_document_uri = "" + thought_uri = "" + observation_uri = "" for s, p, o in triples: - if p == TG_THOUGHT: - thought = o - elif p == TG_ACTION: + if p == TG_ACTION: action = o elif p == TG_ARGUMENTS: arguments = o + elif p == TG_THOUGHT: + thought_uri = o elif p == TG_OBSERVATION: - observation = o - elif p == TG_THOUGHT_DOCUMENT: - thought_document_uri = o - elif p == TG_OBSERVATION_DOCUMENT: - observation_document_uri = o + observation_uri = o return cls( uri=uri, entity_type="analysis", - thought=thought, action=action, arguments=arguments, - observation=observation, - thought_document_uri=thought_document_uri, - observation_document_uri=observation_document_uri + thought_uri=thought_uri, + observation_uri=observation_uri ) @dataclass class Conclusion(ExplainEntity): """Conclusion entity - final answer (Agent only).""" - answer: str = "" document_uri: str = "" # Reference to librarian document @classmethod def from_triples(cls, uri: str, triples: List[Tuple[str, str, Any]]) -> "Conclusion": - answer = "" document_uri = "" for s, p, o in triples: - if p == TG_ANSWER: - answer = o - elif p == TG_DOCUMENT: + if p == TG_DOCUMENT: document_uri = o return cls( uri=uri, entity_type="conclusion", - answer=answer, document_uri=document_uri ) @@ -543,42 +585,29 @@ class ExplainabilityClient: o_label = self.resolve_label(edge.get("o", ""), user, collection) return (s_label, p_label, o_label) - def fetch_synthesis_content( + def fetch_document_content( self, - synthesis: Synthesis, + document_uri: str, api: Any, user: Optional[str] = None, max_content: int = 10000 ) -> str: """ - Fetch the content for a Synthesis entity. - - If synthesis has inline content, returns that. - If synthesis has a document_uri, fetches from librarian with retry. + Fetch content from the librarian by document URI. Args: - synthesis: The Synthesis entity + document_uri: The document URI in the librarian api: TrustGraph Api instance for librarian access user: User identifier for librarian max_content: Maximum content length to return Returns: - The synthesis content as a string + The document content as a string """ - # If inline content exists, use it - if synthesis.content: - if len(synthesis.content) > max_content: - return synthesis.content[:max_content] + "... [truncated]" - return synthesis.content - - # Otherwise fetch from librarian - if not synthesis.document_uri: + if not document_uri: return "" - # Extract document ID from URI (e.g., "urn:document:abc123" -> "abc123") - doc_id = synthesis.document_uri - if doc_id.startswith("urn:document:"): - doc_id = doc_id[len("urn:document:"):] + doc_id = document_uri # Retry fetching from librarian for eventual consistency for attempt in range(self.max_retries): @@ -603,129 +632,6 @@ class ExplainabilityClient: return "" - def fetch_conclusion_content( - self, - conclusion: Conclusion, - api: Any, - user: Optional[str] = None, - max_content: int = 10000 - ) -> str: - """ - Fetch the content for a Conclusion entity (Agent final answer). - - If conclusion has inline answer, returns that. - If conclusion has a document_uri, fetches from librarian with retry. - - Args: - conclusion: The Conclusion entity - api: TrustGraph Api instance for librarian access - user: User identifier for librarian - max_content: Maximum content length to return - - Returns: - The conclusion answer as a string - """ - # If inline answer exists, use it - if conclusion.answer: - if len(conclusion.answer) > max_content: - return conclusion.answer[:max_content] + "... [truncated]" - return conclusion.answer - - # Otherwise fetch from librarian - if not conclusion.document_uri: - return "" - - # Use document URI directly (it's already a full URN) - doc_id = conclusion.document_uri - - # Retry fetching from librarian for eventual consistency - for attempt in range(self.max_retries): - try: - library = api.library() - content_bytes = library.get_document_content(user=user, id=doc_id) - - # Decode as text - try: - content = content_bytes.decode('utf-8') - if len(content) > max_content: - return content[:max_content] + "... [truncated]" - return content - except UnicodeDecodeError: - return f"[Binary: {len(content_bytes)} bytes]" - - except Exception as e: - if attempt < self.max_retries - 1: - time.sleep(self.retry_delay) - continue - return f"[Error fetching content: {e}]" - - return "" - - def fetch_analysis_content( - self, - analysis: Analysis, - api: Any, - user: Optional[str] = None, - max_content: int = 10000 - ) -> None: - """ - Fetch thought and observation content for an Analysis entity. - - If analysis has inline content, uses that. - If analysis has document URIs, fetches from librarian with retry. - Modifies the analysis object in place. - - Args: - analysis: The Analysis entity (modified in place) - api: TrustGraph Api instance for librarian access - user: User identifier for librarian - max_content: Maximum content length to return - """ - # Fetch thought if needed - if not analysis.thought and analysis.thought_document_uri: - doc_id = analysis.thought_document_uri - for attempt in range(self.max_retries): - try: - library = api.library() - content_bytes = library.get_document_content(user=user, id=doc_id) - try: - content = content_bytes.decode('utf-8') - if len(content) > max_content: - analysis.thought = content[:max_content] + "... [truncated]" - else: - analysis.thought = content - break - except UnicodeDecodeError: - analysis.thought = f"[Binary: {len(content_bytes)} bytes]" - break - except Exception as e: - if attempt < self.max_retries - 1: - time.sleep(self.retry_delay) - continue - analysis.thought = f"[Error fetching thought: {e}]" - - # Fetch observation if needed - if not analysis.observation and analysis.observation_document_uri: - doc_id = analysis.observation_document_uri - for attempt in range(self.max_retries): - try: - library = api.library() - content_bytes = library.get_document_content(user=user, id=doc_id) - try: - content = content_bytes.decode('utf-8') - if len(content) > max_content: - analysis.observation = content[:max_content] + "... [truncated]" - else: - analysis.observation = content - break - except UnicodeDecodeError: - analysis.observation = f"[Binary: {len(content_bytes)} bytes]" - break - except Exception as e: - if attempt < self.max_retries - 1: - time.sleep(self.retry_delay) - continue - analysis.observation = f"[Error fetching observation: {e}]" def fetch_graphrag_trace( self, @@ -739,7 +645,7 @@ class ExplainabilityClient: """ Fetch the complete GraphRAG trace starting from a question URI. - Follows the provenance chain: Question -> Exploration -> Focus -> Synthesis + Follows the provenance chain: Question -> Grounding -> Exploration -> Focus -> Synthesis Args: question_uri: The question entity URI @@ -750,13 +656,14 @@ class ExplainabilityClient: max_content: Maximum content length for synthesis Returns: - Dict with question, exploration, focus, synthesis entities + Dict with question, grounding, exploration, focus, synthesis entities """ if graph is None: graph = "urn:graph:retrieval" trace = { "question": None, + "grounding": None, "exploration": None, "focus": None, "synthesis": None, @@ -768,8 +675,8 @@ class ExplainabilityClient: return trace trace["question"] = question - # Find exploration: ?exploration prov:wasGeneratedBy question_uri - exploration_triples = self.flow.triples_query( + # Find grounding: ?grounding prov:wasGeneratedBy question_uri + grounding_triples = self.flow.triples_query( p=PROV_WAS_GENERATED_BY, o=question_uri, g=graph, @@ -778,6 +685,30 @@ class ExplainabilityClient: limit=10 ) + if grounding_triples: + grounding_uris = [ + extract_term_value(t.get("s", {})) + for t in grounding_triples + ] + for gnd_uri in grounding_uris: + grounding = self.fetch_entity(gnd_uri, graph, user, collection) + if isinstance(grounding, Grounding): + trace["grounding"] = grounding + break + + if not trace["grounding"]: + return trace + + # Find exploration: ?exploration prov:wasDerivedFrom grounding_uri + exploration_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=trace["grounding"].uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + if exploration_triples: exploration_uris = [ extract_term_value(t.get("s", {})) @@ -834,11 +765,6 @@ class ExplainabilityClient: for synth_uri in synthesis_uris: synthesis = self.fetch_entity(synth_uri, graph, user, collection) if isinstance(synthesis, Synthesis): - # Fetch content if needed - if api and not synthesis.content and synthesis.document_uri: - synthesis.content = self.fetch_synthesis_content( - synthesis, api, user, max_content - ) trace["synthesis"] = synthesis break @@ -928,11 +854,6 @@ class ExplainabilityClient: for synth_uri in synthesis_uris: synthesis = self.fetch_entity(synth_uri, graph, user, collection) if isinstance(synthesis, Synthesis): - # Fetch content if needed - if api and not synthesis.content and synthesis.document_uri: - synthesis.content = self.fetch_synthesis_content( - synthesis, api, user, max_content - ) trace["synthesis"] = synthesis break @@ -978,20 +899,43 @@ class ExplainabilityClient: return trace trace["question"] = question - # Follow the chain of wasDerivedFrom + # Follow the chain: wasGeneratedBy for first hop, wasDerivedFrom after current_uri = session_uri + is_first = True max_iterations = 50 # Safety limit for _ in range(max_iterations): - # Find entity derived from current - derived_triples = self.flow.triples_query( - p=PROV_WAS_DERIVED_FROM, - o=current_uri, - g=graph, - user=user, - collection=collection, - limit=10 - ) + # First hop uses wasGeneratedBy (entity←activity), + # subsequent hops use wasDerivedFrom (entity←entity) + if is_first: + derived_triples = self.flow.triples_query( + p=PROV_WAS_GENERATED_BY, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + # Fall back to wasDerivedFrom for backwards compatibility + if not derived_triples: + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) + is_first = False + else: + derived_triples = self.flow.triples_query( + p=PROV_WAS_DERIVED_FROM, + o=current_uri, + g=graph, + user=user, + collection=collection, + limit=10 + ) if not derived_triples: break @@ -1003,19 +947,9 @@ class ExplainabilityClient: entity = self.fetch_entity(derived_uri, graph, user, collection) if isinstance(entity, Analysis): - # Fetch thought/observation content from librarian if needed - if api: - self.fetch_analysis_content( - entity, api, user=user, max_content=max_content - ) trace["iterations"].append(entity) current_uri = derived_uri elif isinstance(entity, Conclusion): - # Fetch answer content from librarian if needed - if api and not entity.answer and entity.document_uri: - entity.answer = self.fetch_conclusion_content( - entity, api, user=user, max_content=max_content - ) trace["conclusion"] = entity break else: diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 84e95ebe..e661f46d 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,6 +1,6 @@ from . request_response_spec import RequestResponse, RequestResponseSpec -from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL +from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL, TRIPLE from .. knowledge import Uri, Literal @@ -22,9 +22,11 @@ def to_value(x): def from_value(x): - """Convert Uri, Literal, or string to schema Term.""" + """Convert Uri, Literal, string, or Term to schema Term.""" if x is None: return None + if isinstance(x, Term): + return x if isinstance(x, Uri): return Term(type=IRI, iri=str(x)) elif isinstance(x, Literal): @@ -41,7 +43,7 @@ def from_value(x): class TriplesClient(RequestResponse): async def query(self, s=None, p=None, o=None, limit=20, user="trustgraph", collection="default", - timeout=30): + timeout=30, g=None): resp = await self.request( TriplesQueryRequest( @@ -51,6 +53,7 @@ class TriplesClient(RequestResponse): limit = limit, user = user, collection = collection, + g = g, ), timeout=timeout ) @@ -68,7 +71,7 @@ class TriplesClient(RequestResponse): async def query_stream(self, s=None, p=None, o=None, limit=20, user="trustgraph", collection="default", batch_size=20, timeout=30, - batch_callback=None): + batch_callback=None, g=None): """ Streaming triple query - calls callback for each batch as it arrives. @@ -80,6 +83,8 @@ class TriplesClient(RequestResponse): batch_size: Triples per batch timeout: Request timeout in seconds batch_callback: Async callback(batch, is_final) called for each batch + g: Graph filter. ""=default graph only, None=all graphs, + or a specific graph IRI. Returns: List[Triple]: All triples (flattened) if no callback provided @@ -112,6 +117,7 @@ class TriplesClient(RequestResponse): collection=collection, streaming=True, batch_size=batch_size, + g=g, ), timeout=timeout, recipient=recipient, diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 7326b722..b7ff818c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator): triple_limit=int(data.get("triple-limit", 30)), max_subgraph_size=int(data.get("max-subgraph-size", 1000)), max_path_length=int(data.get("max-path-length", 2)), + edge_limit=int(data.get("edge-limit", 25)), streaming=data.get("streaming", False) ) @@ -96,6 +97,7 @@ class GraphRagRequestTranslator(MessageTranslator): "triple-limit": obj.triple_limit, "max-subgraph-size": obj.max_subgraph_size, "max-path-length": obj.max_path_length, + "edge-limit": obj.edge_limit, "streaming": getattr(obj, "streaming", False) } diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index df3c2034..18ecb0e8 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -42,15 +42,19 @@ from . uris import ( agent_uri, # Query-time provenance URIs (GraphRAG) question_uri, + grounding_uri, exploration_uri, focus_uri, synthesis_uri, # Agent provenance URIs agent_session_uri, agent_iteration_uri, + agent_thought_uri, + agent_observation_uri, agent_final_uri, # Document RAG provenance URIs docrag_question_uri, + docrag_grounding_uri, docrag_exploration_uri, docrag_synthesis_uri, ) @@ -74,18 +78,19 @@ from . namespaces import ( # Extraction provenance entity types TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, # Query-time provenance predicates (GraphRAG) - TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_CONTENT, + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, # Explainability entity types - TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANALYSIS, TG_CONCLUSION, + # Unifying types + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, # Question subtypes (to distinguish retrieval mechanism) TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION, # Agent provenance predicates - TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, - # Agent document references - TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, + TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, # Document reference predicate TG_DOCUMENT, # Named graphs @@ -99,6 +104,7 @@ from . triples import ( subgraph_provenance_triples, # Query-time provenance triple builders (GraphRAG) question_triples, + grounding_triples, exploration_triples, focus_triples, synthesis_triples, @@ -139,15 +145,19 @@ __all__ = [ "agent_uri", # Query-time provenance URIs "question_uri", + "grounding_uri", "exploration_uri", "focus_uri", "synthesis_uri", # Agent provenance URIs "agent_session_uri", "agent_iteration_uri", + "agent_thought_uri", + "agent_observation_uri", "agent_final_uri", # Document RAG provenance URIs "docrag_question_uri", + "docrag_grounding_uri", "docrag_exploration_uri", "docrag_synthesis_uri", # Namespaces @@ -164,18 +174,19 @@ __all__ = [ # Extraction provenance entity types "TG_DOCUMENT_TYPE", "TG_PAGE_TYPE", "TG_CHUNK_TYPE", "TG_SUBGRAPH_TYPE", # Query-time provenance predicates (GraphRAG) - "TG_QUERY", "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_CONTENT", + "TG_QUERY", "TG_CONCEPT", "TG_ENTITY", + "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", # Query-time provenance predicates (DocumentRAG) "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", # Explainability entity types - "TG_QUESTION", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", + "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", "TG_ANALYSIS", "TG_CONCLUSION", + # Unifying types + "TG_ANSWER_TYPE", "TG_REFLECTION_TYPE", "TG_THOUGHT_TYPE", "TG_OBSERVATION_TYPE", # Question subtypes "TG_GRAPH_RAG_QUESTION", "TG_DOC_RAG_QUESTION", "TG_AGENT_QUESTION", # Agent provenance predicates - "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_ANSWER", - # Agent document references - "TG_THOUGHT_DOCUMENT", "TG_OBSERVATION_DOCUMENT", + "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", # Document reference predicate "TG_DOCUMENT", # Named graphs @@ -186,6 +197,7 @@ __all__ = [ "subgraph_provenance_triples", # Query-time provenance triple builders (GraphRAG) "question_triples", + "grounding_triples", "exploration_triples", "focus_triples", "synthesis_triples", diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index e0ee9841..f1aeab0d 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -15,10 +15,11 @@ from .. schema import Triple, Term, IRI, LITERAL from . namespaces import ( RDF_TYPE, RDFS_LABEL, - PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, PROV_STARTED_AT_TIME, - TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_ANSWER, + PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM, + PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME, + TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, - TG_THOUGHT_DOCUMENT, TG_OBSERVATION_DOCUMENT, + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_AGENT_QUESTION, ) @@ -73,12 +74,13 @@ def agent_session_triples( def agent_iteration_triples( iteration_uri: str, - parent_uri: str, - thought: str = "", + question_uri: Optional[str] = None, + previous_uri: Optional[str] = None, action: str = "", arguments: Dict[str, Any] = None, - observation: str = "", + thought_uri: Optional[str] = None, thought_document_id: Optional[str] = None, + observation_uri: Optional[str] = None, observation_document_id: Optional[str] = None, ) -> List[Triple]: """ @@ -86,19 +88,22 @@ def agent_iteration_triples( Creates: - Entity declaration with tg:Analysis type - - wasDerivedFrom link to parent (previous iteration or session) - - Thought, action, arguments, and observation data - - Document references for thought/observation when stored in librarian + - wasGeneratedBy link to question (if first iteration) + - wasDerivedFrom link to previous iteration (if not first) + - Action and arguments metadata + - Thought sub-entity (tg:Reflection, tg:Thought) with librarian document + - Observation sub-entity (tg:Reflection, tg:Observation) with librarian document Args: iteration_uri: URI of this iteration (from agent_iteration_uri) - parent_uri: URI of the parent (previous iteration or session) - thought: The agent's reasoning/thought (used if thought_document_id not provided) + question_uri: URI of the question activity (for first iteration) + previous_uri: URI of the previous iteration (for subsequent iterations) action: The tool/action name arguments: Arguments passed to the tool (will be JSON-encoded) - observation: The result/observation from the tool (used if observation_document_id not provided) - thought_document_id: Optional document URI for thought in librarian (preferred) - observation_document_id: Optional document URI for observation in librarian (preferred) + thought_uri: URI for the thought sub-entity + thought_document_id: Document URI for thought in librarian + observation_uri: URI for the observation sub-entity + observation_document_id: Document URI for observation in librarian Returns: List of Triple objects @@ -110,45 +115,70 @@ def agent_iteration_triples( _triple(iteration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(iteration_uri, RDF_TYPE, _iri(TG_ANALYSIS)), _triple(iteration_uri, RDFS_LABEL, _literal(f"Analysis: {action}")), - _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), _triple(iteration_uri, TG_ACTION, _literal(action)), _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), ] - # Thought: use document reference or inline - if thought_document_id: - triples.append(_triple(iteration_uri, TG_THOUGHT_DOCUMENT, _iri(thought_document_id))) - elif thought: - triples.append(_triple(iteration_uri, TG_THOUGHT, _literal(thought))) + if question_uri: + triples.append( + _triple(iteration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)) + ) + elif previous_uri: + triples.append( + _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri)) + ) - # Observation: use document reference or inline - if observation_document_id: - triples.append(_triple(iteration_uri, TG_OBSERVATION_DOCUMENT, _iri(observation_document_id))) - elif observation: - triples.append(_triple(iteration_uri, TG_OBSERVATION, _literal(observation))) + # Thought sub-entity + if thought_uri: + triples.extend([ + _triple(iteration_uri, TG_THOUGHT, _iri(thought_uri)), + _triple(thought_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)), + _triple(thought_uri, RDF_TYPE, _iri(TG_THOUGHT_TYPE)), + _triple(thought_uri, RDFS_LABEL, _literal("Thought")), + _triple(thought_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)), + ]) + if thought_document_id: + triples.append( + _triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id)) + ) + + # Observation sub-entity + if observation_uri: + triples.extend([ + _triple(iteration_uri, TG_OBSERVATION, _iri(observation_uri)), + _triple(observation_uri, RDF_TYPE, _iri(TG_REFLECTION_TYPE)), + _triple(observation_uri, RDF_TYPE, _iri(TG_OBSERVATION_TYPE)), + _triple(observation_uri, RDFS_LABEL, _literal("Observation")), + _triple(observation_uri, PROV_WAS_GENERATED_BY, _iri(iteration_uri)), + ]) + if observation_document_id: + triples.append( + _triple(observation_uri, TG_DOCUMENT, _iri(observation_document_id)) + ) return triples def agent_final_triples( final_uri: str, - parent_uri: str, - answer: str = "", + question_uri: Optional[str] = None, + previous_uri: Optional[str] = None, document_id: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent final answer (Conclusion). Creates: - - Entity declaration with tg:Conclusion type - - wasDerivedFrom link to parent (last iteration or session) - - Either document reference (if document_id provided) or inline answer + - Entity declaration with tg:Conclusion and tg:Answer types + - wasGeneratedBy link to question (if no iterations) + - wasDerivedFrom link to last iteration (if iterations exist) + - Document reference to librarian Args: final_uri: URI of the final answer (from agent_final_uri) - parent_uri: URI of the parent (last iteration or session if no iterations) - answer: The final answer text (used if document_id not provided) - document_id: Optional document URI in librarian (preferred) + question_uri: URI of the question activity (if no iterations) + previous_uri: URI of the last iteration (if iterations exist) + document_id: Librarian document ID for the answer content Returns: List of Triple objects @@ -156,15 +186,20 @@ def agent_final_triples( triples = [ _triple(final_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(final_uri, RDF_TYPE, _iri(TG_CONCLUSION)), + _triple(final_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), _triple(final_uri, RDFS_LABEL, _literal("Conclusion")), - _triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(parent_uri)), ] + if question_uri: + triples.append( + _triple(final_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)) + ) + elif previous_uri: + triples.append( + _triple(final_uri, PROV_WAS_DERIVED_FROM, _iri(previous_uri)) + ) + if document_id: - # Store reference to document in librarian (as IRI) triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) - elif answer: - # Fallback: store inline answer - triples.append(_triple(final_uri, TG_ANSWER, _literal(answer))) return triples diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 4c1ab7bf..066e893b 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -60,11 +60,12 @@ TG_SOURCE_CHAR_LENGTH = TG + "sourceCharLength" # Query-time provenance predicates (GraphRAG) TG_QUERY = TG + "query" +TG_CONCEPT = TG + "concept" +TG_ENTITY = TG + "entity" TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" -TG_CONTENT = TG + "content" TG_DOCUMENT = TG + "document" # Reference to document in librarian # Query-time provenance predicates (DocumentRAG) @@ -79,27 +80,29 @@ TG_SUBGRAPH_TYPE = TG + "Subgraph" # Explainability entity types (shared) TG_QUESTION = TG + "Question" +TG_GROUNDING = TG + "Grounding" TG_EXPLORATION = TG + "Exploration" TG_FOCUS = TG + "Focus" TG_SYNTHESIS = TG + "Synthesis" TG_ANALYSIS = TG + "Analysis" TG_CONCLUSION = TG + "Conclusion" +# Unifying types for answer and intermediate commentary +TG_ANSWER_TYPE = TG + "Answer" # Final answer (Synthesis, Conclusion) +TG_REFLECTION_TYPE = TG + "Reflection" # Intermediate commentary (Thought, Observation) +TG_THOUGHT_TYPE = TG + "Thought" # Agent reasoning +TG_OBSERVATION_TYPE = TG + "Observation" # Agent tool result + # Question subtypes (to distinguish retrieval mechanism) TG_GRAPH_RAG_QUESTION = TG + "GraphRagQuestion" TG_DOC_RAG_QUESTION = TG + "DocRagQuestion" TG_AGENT_QUESTION = TG + "AgentQuestion" # Agent provenance predicates -TG_THOUGHT = TG + "thought" +TG_THOUGHT = TG + "thought" # Links iteration to thought sub-entity TG_ACTION = TG + "action" TG_ARGUMENTS = TG + "arguments" -TG_OBSERVATION = TG + "observation" -TG_ANSWER = TG + "answer" - -# Agent document references (for librarian storage) -TG_THOUGHT_DOCUMENT = TG + "thoughtDocument" -TG_OBSERVATION_DOCUMENT = TG + "observationDocument" +TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity # Named graph URIs for RDF datasets # These separate different types of data while keeping them in the same collection diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 459783d1..60d8d8f6 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -20,12 +20,15 @@ from . namespaces import ( # Extraction provenance entity types TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, # Query-time provenance predicates (GraphRAG) - TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_CONTENT, + TG_QUERY, TG_CONCEPT, TG_ENTITY, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_DOCUMENT, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, # Explainability entity types - TG_QUESTION, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + # Unifying types + TG_ANSWER_TYPE, # Question subtypes TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, ) @@ -347,35 +350,78 @@ def question_triples( ] +def grounding_triples( + grounding_uri: str, + question_uri: str, + concepts: List[str], +) -> List[Triple]: + """ + Build triples for a grounding entity (concept decomposition of query). + + Creates: + - Entity declaration for grounding + - wasGeneratedBy link to question + - Concept literals for each extracted concept + + Args: + grounding_uri: URI of the grounding entity (from grounding_uri) + question_uri: URI of the parent question + concepts: List of concept strings extracted from the query + + Returns: + List of Triple objects + """ + triples = [ + _triple(grounding_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(grounding_uri, RDF_TYPE, _iri(TG_GROUNDING)), + _triple(grounding_uri, RDFS_LABEL, _literal("Grounding")), + _triple(grounding_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), + ] + + for concept in concepts: + triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept))) + + return triples + + def exploration_triples( exploration_uri: str, - question_uri: str, + grounding_uri: str, edge_count: int, + entities: Optional[List[str]] = None, ) -> List[Triple]: """ Build triples for an exploration entity (all edges retrieved from subgraph). Creates: - Entity declaration for exploration - - wasGeneratedBy link to question + - wasDerivedFrom link to grounding - Edge count metadata + - Entity IRIs for each seed entity Args: exploration_uri: URI of the exploration entity (from exploration_uri) - question_uri: URI of the parent question + grounding_uri: URI of the parent grounding entity edge_count: Number of edges retrieved + entities: Optional list of seed entity URIs Returns: List of Triple objects """ - return [ + triples = [ _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)), _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")), - _triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), + _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)), _triple(exploration_uri, TG_EDGE_COUNT, _literal(edge_count)), ] + if entities: + for entity in entities: + triples.append(_triple(exploration_uri, TG_ENTITY, _iri(entity))) + + return triples + def _quoted_triple(s: str, p: str, o: str) -> Term: """Create a quoted triple term (RDF-star) from string values.""" @@ -454,22 +500,20 @@ def focus_triples( def synthesis_triples( synthesis_uri: str, focus_uri: str, - answer_text: str = "", document_id: Optional[str] = None, ) -> List[Triple]: """ - Build triples for a synthesis entity (final answer text). + Build triples for a synthesis entity (final answer). Creates: - - Entity declaration for synthesis + - Entity declaration for synthesis with tg:Answer type - wasDerivedFrom link to focus - - Either document reference (if document_id provided) or inline content + - Document reference to librarian Args: synthesis_uri: URI of the synthesis entity (from synthesis_uri) focus_uri: URI of the parent focus entity - answer_text: The synthesized answer text (used if no document_id) - document_id: Optional librarian document ID (preferred over inline content) + document_id: Librarian document ID for the answer content Returns: List of Triple objects @@ -477,16 +521,13 @@ def synthesis_triples( triples = [ _triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)), + _triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), _triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")), _triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(focus_uri)), ] if document_id: - # Store reference to document in librarian (as IRI) triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) - elif answer_text: - # Fallback: store inline content - triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text))) return triples @@ -533,7 +574,7 @@ def docrag_question_triples( def docrag_exploration_triples( exploration_uri: str, - question_uri: str, + grounding_uri: str, chunk_count: int, chunk_ids: Optional[List[str]] = None, ) -> List[Triple]: @@ -542,12 +583,12 @@ def docrag_exploration_triples( Creates: - Entity declaration with tg:Exploration type - - wasGeneratedBy link to question + - wasDerivedFrom link to grounding - Chunk count and optional chunk references Args: exploration_uri: URI of the exploration entity - question_uri: URI of the parent question + grounding_uri: URI of the parent grounding entity chunk_count: Number of chunks retrieved chunk_ids: Optional list of chunk URIs/IDs @@ -558,7 +599,7 @@ def docrag_exploration_triples( _triple(exploration_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(exploration_uri, RDF_TYPE, _iri(TG_EXPLORATION)), _triple(exploration_uri, RDFS_LABEL, _literal("Exploration")), - _triple(exploration_uri, PROV_WAS_GENERATED_BY, _iri(question_uri)), + _triple(exploration_uri, PROV_WAS_DERIVED_FROM, _iri(grounding_uri)), _triple(exploration_uri, TG_CHUNK_COUNT, _literal(chunk_count)), ] @@ -573,22 +614,20 @@ def docrag_exploration_triples( def docrag_synthesis_triples( synthesis_uri: str, exploration_uri: str, - answer_text: str = "", document_id: Optional[str] = None, ) -> List[Triple]: """ Build triples for a document RAG synthesis entity (final answer). Creates: - - Entity declaration with tg:Synthesis type + - Entity declaration with tg:Synthesis and tg:Answer types - wasDerivedFrom link to exploration (skips focus step) - - Either document reference or inline content + - Document reference to librarian Args: synthesis_uri: URI of the synthesis entity exploration_uri: URI of the parent exploration entity - answer_text: The synthesized answer text (used if no document_id) - document_id: Optional librarian document ID (preferred over inline content) + document_id: Librarian document ID for the answer content Returns: List of Triple objects @@ -596,13 +635,12 @@ def docrag_synthesis_triples( triples = [ _triple(synthesis_uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(synthesis_uri, RDF_TYPE, _iri(TG_SYNTHESIS)), + _triple(synthesis_uri, RDF_TYPE, _iri(TG_ANSWER_TYPE)), _triple(synthesis_uri, RDFS_LABEL, _literal("Synthesis")), _triple(synthesis_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), ] if document_id: triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) - elif answer_text: - triples.append(_triple(synthesis_uri, TG_CONTENT, _literal(answer_text))) return triples diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index 0f8d3136..670143df 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -68,6 +68,7 @@ def agent_uri(component_name: str) -> str: # # Terminology: # Question - What was asked, the anchor for everything +# Grounding - Decomposing the question into concepts # Exploration - Casting wide, what do we know about this space # Focus - Closing down, what's actually relevant here # Synthesis - Weaving the relevant pieces into an answer @@ -87,6 +88,19 @@ def question_uri(session_id: str = None) -> str: return f"urn:trustgraph:question:{session_id}" +def grounding_uri(session_id: str) -> str: + """ + Generate URI for a grounding entity (concept decomposition of query). + + Args: + session_id: The session UUID (same as question_uri). + + Returns: + URN in format: urn:trustgraph:prov:grounding:{uuid} + """ + return f"urn:trustgraph:prov:grounding:{session_id}" + + def exploration_uri(session_id: str) -> str: """ Generate URI for an exploration entity (edges retrieved from subgraph). @@ -173,6 +187,34 @@ def agent_iteration_uri(session_id: str, iteration_num: int) -> str: return f"urn:trustgraph:agent:{session_id}/i{iteration_num}" +def agent_thought_uri(session_id: str, iteration_num: int) -> str: + """ + Generate URI for an agent thought sub-entity. + + Args: + session_id: The session UUID. + iteration_num: 1-based iteration number. + + Returns: + URN in format: urn:trustgraph:agent:{uuid}/i{num}/thought + """ + return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" + + +def agent_observation_uri(session_id: str, iteration_num: int) -> str: + """ + Generate URI for an agent observation sub-entity. + + Args: + session_id: The session UUID. + iteration_num: 1-based iteration number. + + Returns: + URN in format: urn:trustgraph:agent:{uuid}/i{num}/observation + """ + return f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" + + def agent_final_uri(session_id: str) -> str: """ Generate URI for an agent final answer. @@ -205,6 +247,19 @@ def docrag_question_uri(session_id: str = None) -> str: return f"urn:trustgraph:docrag:{session_id}" +def docrag_grounding_uri(session_id: str) -> str: + """ + Generate URI for a document RAG grounding entity (concept decomposition). + + Args: + session_id: The session UUID. + + Returns: + URN in format: urn:trustgraph:docrag:{uuid}/grounding + """ + return f"urn:trustgraph:docrag:{session_id}/grounding" + + def docrag_exploration_uri(session_id: str) -> str: """ Generate URI for a document RAG exploration entity (chunks retrieved). diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py index 4ad2e59b..018e2bfe 100644 --- a/trustgraph-base/trustgraph/provenance/vocabulary.py +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -25,6 +25,8 @@ from . namespaces import ( TG_LLM_MODEL, TG_ONTOLOGY, TG_EMBEDDING_MODEL, TG_SOURCE_TEXT, TG_SOURCE_CHAR_OFFSET, TG_SOURCE_CHAR_LENGTH, TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE, + TG_CONCEPT, TG_ENTITY, TG_GROUNDING, + TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, ) @@ -80,6 +82,11 @@ TG_CLASS_LABELS = [ _label_triple(TG_PAGE_TYPE, "Page"), _label_triple(TG_CHUNK_TYPE, "Chunk"), _label_triple(TG_SUBGRAPH_TYPE, "Subgraph"), + _label_triple(TG_GROUNDING, "Grounding"), + _label_triple(TG_ANSWER_TYPE, "Answer"), + _label_triple(TG_REFLECTION_TYPE, "Reflection"), + _label_triple(TG_THOUGHT_TYPE, "Thought"), + _label_triple(TG_OBSERVATION_TYPE, "Observation"), ] # TrustGraph predicate labels @@ -100,6 +107,8 @@ TG_PREDICATE_LABELS = [ _label_triple(TG_SOURCE_TEXT, "source text"), _label_triple(TG_SOURCE_CHAR_OFFSET, "source character offset"), _label_triple(TG_SOURCE_CHAR_LENGTH, "source character length"), + _label_triple(TG_CONCEPT, "concept"), + _label_triple(TG_ENTITY, "entity"), ] diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 0d0b79b8..b3a9d58d 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -15,6 +15,7 @@ class GraphRagQuery: triple_limit: int = 0 max_subgraph_size: int = 0 max_path_length: int = 0 + edge_limit: int = 0 streaming: bool = False @dataclass diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 03b71e2a..dedb2f34 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -202,16 +202,17 @@ def question_explainable( elif isinstance(entity, Analysis): print(f"\n [iteration] {prov_id}", file=sys.stderr) - if entity.thought: - thought_short = entity.thought[:80] + "..." if len(entity.thought) > 80 else entity.thought - print(f" Thought: {thought_short}", file=sys.stderr) if entity.action: print(f" Action: {entity.action}", file=sys.stderr) + if entity.thought_uri: + print(f" Thought: {entity.thought_uri}", file=sys.stderr) + if entity.observation_uri: + print(f" Observation: {entity.observation_uri}", file=sys.stderr) elif isinstance(entity, Conclusion): print(f"\n [conclusion] {prov_id}", file=sys.stderr) - if entity.answer: - print(f" Answer length: {len(entity.answer)} chars", file=sys.stderr) + if entity.document_uri: + print(f" Document: {entity.document_uri}", file=sys.stderr) else: if debug: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 4ed7bca9..381b6924 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -11,6 +11,7 @@ from trustgraph.api import ( RAGChunk, ProvenanceEvent, Question, + Grounding, Exploration, Synthesis, ) @@ -68,6 +69,12 @@ def question_explainable( if entity.timestamp: print(f" Time: {entity.timestamp}", file=sys.stderr) + elif isinstance(entity, Grounding): + print(f"\n [grounding] {prov_id}", file=sys.stderr) + if entity.concepts: + for concept in entity.concepts: + print(f" Concept: {concept}", file=sys.stderr) + elif isinstance(entity, Exploration): print(f"\n [exploration] {prov_id}", file=sys.stderr) if entity.chunk_count: @@ -75,8 +82,8 @@ def question_explainable( elif isinstance(entity, Synthesis): print(f"\n [synthesis] {prov_id}", file=sys.stderr) - if entity.content: - print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr) + if entity.document_uri: + print(f" Document: {entity.document_uri}", file=sys.stderr) else: if debug: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 295df0b9..870576c3 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -14,6 +14,7 @@ from trustgraph.api import ( RAGChunk, ProvenanceEvent, Question, + Grounding, Exploration, Focus, Synthesis, @@ -31,11 +32,13 @@ default_max_path_length = 2 # Provenance predicates TG = "https://trustgraph.ai/ns/" TG_QUERY = TG + "query" +TG_CONCEPT = TG + "concept" +TG_ENTITY = TG + "entity" TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" -TG_CONTENT = TG + "content" +TG_DOCUMENT = TG + "document" TG_CONTAINS = TG + "contains" PROV = "http://www.w3.org/ns/prov#" PROV_STARTED_AT_TIME = PROV + "startedAtTime" @@ -47,6 +50,8 @@ def _get_event_type(prov_id): """Extract event type from provenance_id""" if "question" in prov_id: return "question" + elif "grounding" in prov_id: + return "grounding" elif "exploration" in prov_id: return "exploration" elif "focus" in prov_id: @@ -68,8 +73,16 @@ def _format_provenance_details(event_type, triples): elif p == PROV_STARTED_AT_TIME: lines.append(f" Time: {o}") + elif event_type == "grounding": + # Show extracted concepts + concepts = [o for s, p, o in triples if p == TG_CONCEPT] + if concepts: + lines.append(f" Concepts: {len(concepts)}") + for concept in concepts: + lines.append(f" - {concept}") + elif event_type == "exploration": - # Show edge count + # Show edge count (seed entities resolved separately with labels) for s, p, o in triples: if p == TG_EDGE_COUNT: lines.append(f" Edges explored: {o}") @@ -85,10 +98,10 @@ def _format_provenance_details(event_type, triples): lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") elif event_type == "synthesis": - # Show content length (not full content - it's already streamed) + # Show document reference (content already streamed) for s, p, o in triples: - if p == TG_CONTENT: - lines.append(f" Synthesis length: {len(o)} chars") + if p == TG_DOCUMENT: + lines.append(f" Document: {o}") return lines @@ -542,6 +555,18 @@ async def _question_explainable( for line in details: print(line, file=sys.stderr) + # For exploration events, resolve entity labels + if event_type == "exploration": + entity_iris = [o for s, p, o in triples if p == TG_ENTITY] + if entity_iris: + print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) + for iri in entity_iris: + label = await _query_label( + ws_url, flow_id, iri, user, collection, + label_cache, debug=debug + ) + print(f" - {label}", file=sys.stderr) + # For focus events, query each edge selection for details if event_type == "focus": for s, p, o in triples: @@ -660,10 +685,22 @@ def _question_explainable_api( if entity.timestamp: print(f" Time: {entity.timestamp}", file=sys.stderr) + elif isinstance(entity, Grounding): + print(f"\n [grounding] {prov_id}", file=sys.stderr) + if entity.concepts: + print(f" Concepts: {len(entity.concepts)}", file=sys.stderr) + for concept in entity.concepts: + print(f" - {concept}", file=sys.stderr) + elif isinstance(entity, Exploration): print(f"\n [exploration] {prov_id}", file=sys.stderr) if entity.edge_count: print(f" Edges explored: {entity.edge_count}", file=sys.stderr) + if entity.entities: + print(f" Seed entities: {len(entity.entities)}", file=sys.stderr) + for ent in entity.entities: + label = explain_client.resolve_label(ent, user, collection) + print(f" - {label}", file=sys.stderr) elif isinstance(entity, Focus): print(f"\n [focus] {prov_id}", file=sys.stderr) @@ -691,8 +728,8 @@ def _question_explainable_api( elif isinstance(entity, Synthesis): print(f"\n [synthesis] {prov_id}", file=sys.stderr) - if entity.content: - print(f" Synthesis length: {len(entity.content)} chars", file=sys.stderr) + if entity.document_uri: + print(f" Document: {entity.document_uri}", file=sys.stderr) else: if debug: @@ -848,7 +885,7 @@ def main(): parser.add_argument( '-x', '--explainable', action='store_true', - help='Show provenance events: Question, Exploration, Focus, Synthesis (implies streaming)' + help='Show provenance events: Question, Grounding, Exploration, Focus, Synthesis (implies streaming)' ) parser.add_argument( diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 306c081e..9a02e5c6 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -31,6 +31,8 @@ from ... schema import librarian_request_queue, librarian_response_queue from trustgraph.provenance import ( agent_session_uri, agent_iteration_uri, + agent_thought_uri, + agent_observation_uri, agent_final_uri, agent_session_triples, agent_iteration_triples, @@ -624,11 +626,13 @@ class Processor(AgentService): # Emit final answer provenance triples final_uri = agent_final_uri(session_id) - # Parent is last iteration, or session if no iterations + # No iterations: link to question; otherwise: link to last iteration if iteration_num > 1: - parent_uri = agent_iteration_uri(session_id, iteration_num - 1) + final_question_uri = None + final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) else: - parent_uri = session_uri + final_question_uri = session_uri + final_previous_uri = None # Save answer to librarian answer_doc_id = None @@ -648,8 +652,9 @@ class Processor(AgentService): final_triples = set_graph( agent_final_triples( - final_uri, parent_uri, - answer="" if answer_doc_id else f, + final_uri, + question_uri=final_question_uri, + previous_uri=final_previous_uri, document_id=answer_doc_id, ), GRAPH_RETRIEVAL @@ -707,11 +712,13 @@ class Processor(AgentService): # Emit iteration provenance triples iteration_uri = agent_iteration_uri(session_id, iteration_num) - # Parent is previous iteration, or session if this is first iteration + # First iteration links to question, subsequent to previous if iteration_num > 1: - parent_uri = agent_iteration_uri(session_id, iteration_num - 1) + iter_question_uri = None + iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1) else: - parent_uri = session_uri + iter_question_uri = session_uri + iter_previous_uri = None # Save thought to librarian thought_doc_id = None @@ -745,15 +752,19 @@ class Processor(AgentService): logger.warning(f"Failed to save observation to librarian: {e}") observation_doc_id = None + thought_entity_uri = agent_thought_uri(session_id, iteration_num) + observation_entity_uri = agent_observation_uri(session_id, iteration_num) + iter_triples = set_graph( agent_iteration_triples( iteration_uri, - parent_uri, - thought="" if thought_doc_id else act.thought, + question_uri=iter_question_uri, + previous_uri=iter_previous_uri, action=act.name, arguments=act.arguments, - observation="" if observation_doc_id else act.observation, + thought_uri=thought_entity_uri if thought_doc_id else None, thought_document_id=thought_doc_id, + observation_uri=observation_entity_uri if observation_doc_id else None, observation_document_id=observation_doc_id, ), GRAPH_RETRIEVAL diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 78c97024..730a7226 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -7,9 +7,11 @@ from datetime import datetime # Provenance imports from trustgraph.provenance import ( docrag_question_uri, + docrag_grounding_uri, docrag_exploration_uri, docrag_synthesis_uri, docrag_question_triples, + grounding_triples, docrag_exploration_triples, docrag_synthesis_triples, set_graph, @@ -33,39 +35,79 @@ class Query: self.verbose = verbose self.doc_limit = doc_limit - async def get_vector(self, query): + async def extract_concepts(self, query): + """Extract key concepts from query for independent embedding.""" + response = await self.rag.prompt_client.prompt( + "extract-concepts", + variables={"query": query} + ) + concepts = [] + if isinstance(response, str): + for line in response.strip().split('\n'): + line = line.strip() + if line: + concepts.append(line) + + # Fallback to raw query if no concepts extracted + if not concepts: + concepts = [query] + + if self.verbose: + logger.debug(f"Extracted concepts: {concepts}") + + return concepts + + async def get_vectors(self, concepts): + """Compute embeddings for a list of concepts.""" if self.verbose: logger.debug("Computing embeddings...") - qembeds = await self.rag.embeddings_client.embed([query]) + qembeds = await self.rag.embeddings_client.embed(concepts) if self.verbose: logger.debug("Embeddings computed") - # Return the vector set for the first (only) text - return qembeds[0] if qembeds else [] + return qembeds - async def get_docs(self, query): + async def get_docs(self, concepts): """ - Get documents (chunks) matching the query. + Get documents (chunks) matching the extracted concepts. Returns: tuple: (docs, chunk_ids) where: - docs: list of document content strings - chunk_ids: list of chunk IDs that were successfully fetched """ - vectors = await self.get_vector(query) + vectors = await self.get_vectors(concepts) if self.verbose: logger.debug("Getting chunks from embeddings store...") - # Get chunk matches from embeddings store - chunk_matches = await self.rag.doc_embeddings_client.query( - vector=vectors, limit=self.doc_limit, - user=self.user, collection=self.collection, + # Query chunk matches for each concept concurrently + per_concept_limit = max( + 1, self.doc_limit // len(vectors) ) + async def query_concept(vec): + return await self.rag.doc_embeddings_client.query( + vector=vec, limit=per_concept_limit, + user=self.user, collection=self.collection, + ) + + results = await asyncio.gather( + *[query_concept(v) for v in vectors] + ) + + # Deduplicate chunk matches by chunk_id + seen = set() + chunk_matches = [] + for matches in results: + for match in matches: + if match.chunk_id and match.chunk_id not in seen: + seen.add(match.chunk_id) + chunk_matches.append(match) + if self.verbose: logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...") @@ -133,6 +175,7 @@ class DocumentRag: # Generate explainability URIs upfront session_id = str(uuid.uuid4()) q_uri = docrag_question_uri(session_id) + gnd_uri = docrag_grounding_uri(session_id) exp_uri = docrag_exploration_uri(session_id) syn_uri = docrag_synthesis_uri(session_id) @@ -151,12 +194,23 @@ class DocumentRag: doc_limit=doc_limit ) - docs, chunk_ids = await q.get_docs(query) + # Extract concepts from query (grounding step) + concepts = await q.extract_concepts(query) + + # Emit grounding explainability after concept extraction + if explain_callback: + gnd_triples = set_graph( + grounding_triples(gnd_uri, q_uri, concepts), + GRAPH_RETRIEVAL + ) + await explain_callback(gnd_triples, gnd_uri) + + docs, chunk_ids = await q.get_docs(concepts) # Emit exploration explainability after chunks retrieved if explain_callback: exp_triples = set_graph( - docrag_exploration_triples(exp_uri, q_uri, len(chunk_ids), chunk_ids), + docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids), GRAPH_RETRIEVAL ) await explain_callback(exp_triples, exp_uri) @@ -196,9 +250,8 @@ class DocumentRag: synthesis_doc_id = None answer_text = resp if resp else "" - # Save answer to librarian if callback provided + # Save answer to librarian if save_answer_callback and answer_text: - # Generate document ID as URN matching query-time provenance format synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer" try: await save_answer_callback(synthesis_doc_id, answer_text) @@ -206,13 +259,11 @@ class DocumentRag: logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") except Exception as e: logger.warning(f"Failed to save answer to librarian: {e}") - synthesis_doc_id = None # Fall back to inline content + synthesis_doc_id = None - # Generate triples with document reference or inline content syn_triples = set_graph( docrag_synthesis_triples( syn_uri, exp_uri, - answer_text="" if synthesis_doc_id else answer_text, document_id=synthesis_doc_id, ), GRAPH_RETRIEVAL diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 53d4cbf7..22d4fc1b 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -8,20 +8,23 @@ import uuid from collections import OrderedDict from datetime import datetime -from ... schema import IRI, LITERAL +from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE # Provenance imports from trustgraph.provenance import ( question_uri, + grounding_uri as make_grounding_uri, exploration_uri as make_exploration_uri, focus_uri as make_focus_uri, synthesis_uri as make_synthesis_uri, question_triples, + grounding_triples, exploration_triples, focus_triples, synthesis_triples, set_graph, - GRAPH_RETRIEVAL, + GRAPH_RETRIEVAL, GRAPH_SOURCE, + TG_CONTAINS, PROV_WAS_DERIVED_FROM, ) # Module logger @@ -47,6 +50,8 @@ def edge_id(s, p, o): edge_str = f"{s}|{p}|{o}" return hashlib.sha256(edge_str.encode()).hexdigest()[:8] + + class LRUCacheWithTTL: """LRU cache with TTL for label caching @@ -105,42 +110,88 @@ class Query: self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length - async def get_vector(self, query): + async def extract_concepts(self, query): + """Extract key concepts from query for independent embedding.""" + response = await self.rag.prompt_client.prompt( + "extract-concepts", + variables={"query": query} + ) + concepts = [] + if isinstance(response, str): + for line in response.strip().split('\n'): + line = line.strip() + if line: + concepts.append(line) + + if self.verbose: + logger.debug(f"Extracted concepts: {concepts}") + + # Fall back to raw query if extraction returns nothing + return concepts if concepts else [query] + + async def get_vectors(self, concepts): + """Embed multiple concepts concurrently.""" if self.verbose: logger.debug("Computing embeddings...") - qembeds = await self.rag.embeddings_client.embed([query]) + qembeds = await self.rag.embeddings_client.embed(concepts) if self.verbose: logger.debug("Done.") - # Return the vector set for the first (only) text - return qembeds[0] if qembeds else [] + return qembeds async def get_entities(self, query): + """ + Extract concepts from query, embed them, and retrieve matching entities. - vectors = await self.get_vector(query) + Returns: + tuple: (entities, concepts) where entities is a list of entity URI + strings and concepts is the list of concept strings extracted + from the query. + """ + + concepts = await self.extract_concepts(query) + + vectors = await self.get_vectors(concepts) if self.verbose: logger.debug("Getting entities...") - entity_matches = await self.rag.graph_embeddings_client.query( - vector=vectors, limit=self.entity_limit, - user=self.user, collection=self.collection, + # Query entity matches for each concept concurrently + per_concept_limit = max( + 1, self.entity_limit // len(vectors) ) - entities = [ - term_to_string(e.entity) - for e in entity_matches + entity_tasks = [ + self.rag.graph_embeddings_client.query( + vector=v, limit=per_concept_limit, + user=self.user, collection=self.collection, + ) + for v in vectors ] + results = await asyncio.gather(*entity_tasks, return_exceptions=True) + + # Deduplicate while preserving order + seen = set() + entities = [] + for result in results: + if isinstance(result, Exception) or not result: + continue + for e in result: + entity = term_to_string(e.entity) + if entity not in seen: + seen.add(entity) + entities.append(entity) + if self.verbose: logger.debug("Entities:") for ent in entities: logger.debug(f" {ent}") - return entities + return entities, concepts async def maybe_label(self, e): @@ -156,6 +207,7 @@ class Query: res = await self.rag.triples_client.query( s=e, p=LABEL, o=None, limit=1, user=self.user, collection=self.collection, + g="", ) if len(res) == 0: @@ -177,19 +229,19 @@ class Query: s=entity, p=None, o=None, limit=limit_per_entity, user=self.user, collection=self.collection, - batch_size=20, + batch_size=20, g="", ), self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, user=self.user, collection=self.collection, - batch_size=20, + batch_size=20, g="", ), self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, user=self.user, collection=self.collection, - batch_size=20, + batch_size=20, g="", ) ]) @@ -262,8 +314,16 @@ class Query: subgraph.update(batch_result) async def get_subgraph(self, query): + """ + Get subgraph by extracting concepts, finding entities, and traversing. - entities = await self.get_entities(query) + Returns: + tuple: (subgraph, entities, concepts) where subgraph is a list of + (s, p, o) tuples, entities is the seed entity list, and concepts + is the extracted concept list. + """ + + entities, concepts = await self.get_entities(query) if self.verbose: logger.debug("Getting subgraph...") @@ -271,7 +331,7 @@ class Query: # Use optimized batch traversal instead of sequential processing subgraph = await self.follow_edges_batch(entities, self.max_path_length) - return list(subgraph) + return list(subgraph), entities, concepts async def resolve_labels_batch(self, entities): """Resolve labels for multiple entities in parallel""" @@ -286,11 +346,13 @@ class Query: Get subgraph with labels resolved for display. Returns: - tuple: (labeled_edges, uri_map) where: + tuple: (labeled_edges, uri_map, entities, concepts) where: - labeled_edges: list of (label_s, label_p, label_o) tuples - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) + - entities: list of seed entity URI strings + - concepts: list of concept strings extracted from query """ - subgraph = await self.get_subgraph(query) + subgraph, entities, concepts = await self.get_subgraph(query) # Filter out label triples filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] @@ -338,8 +400,125 @@ class Query: if self.verbose: logger.debug("Done.") - return labeled_edges, uri_map - + return labeled_edges, uri_map, entities, concepts + + async def trace_source_documents(self, edge_uris): + """ + Trace selected edges back to their source documents via provenance. + + Follows the chain: edge → subgraph (via tg:contains) → chunk → + page → document (via prov:wasDerivedFrom), all in urn:graph:source. + + Args: + edge_uris: List of (s, p, o) URI string tuples + + Returns: + List of unique document titles + """ + # Step 1: Find subgraphs containing these edges via tg:contains + subgraph_tasks = [] + for s, p, o in edge_uris: + quoted = Term( + type=TRIPLE, + triple=SchemaTriple( + s=Term(type=IRI, iri=s), + p=Term(type=IRI, iri=p), + o=Term(type=IRI, iri=o), + ) + ) + subgraph_tasks.append( + self.rag.triples_client.query( + s=None, p=TG_CONTAINS, o=quoted, limit=1, + user=self.user, collection=self.collection, + g=GRAPH_SOURCE, + ) + ) + + subgraph_results = await asyncio.gather( + *subgraph_tasks, return_exceptions=True + ) + + # Collect unique subgraph URIs + subgraph_uris = set() + for result in subgraph_results: + if isinstance(result, Exception) or not result: + continue + for triple in result: + subgraph_uris.add(str(triple.s)) + + if not subgraph_uris: + return [] + + # Step 2: Walk prov:wasDerivedFrom chain to find documents + # Each level: query ?entity prov:wasDerivedFrom ?parent + # Stop when we find entities typed tg:Document + current_uris = subgraph_uris + doc_uris = set() + + for depth in range(4): # Max depth: subgraph → chunk → page → doc + if not current_uris: + break + + derivation_tasks = [ + self.rag.triples_client.query( + s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5, + user=self.user, collection=self.collection, + g=GRAPH_SOURCE, + ) + for uri in current_uris + ] + + derivation_results = await asyncio.gather( + *derivation_tasks, return_exceptions=True + ) + + # URIs with no parent are root documents + next_uris = set() + for uri, result in zip(current_uris, derivation_results): + if isinstance(result, Exception) or not result: + doc_uris.add(uri) + continue + for triple in result: + next_uris.add(str(triple.o)) + + current_uris = next_uris - doc_uris + + if not doc_uris: + return [] + + # Step 3: Get all document metadata properties + # Skip structural predicates that aren't useful context + SKIP_PREDICATES = { + PROV_WAS_DERIVED_FROM, + "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", + } + + metadata_tasks = [ + self.rag.triples_client.query( + s=uri, p=None, o=None, limit=50, + user=self.user, collection=self.collection, + ) + for uri in doc_uris + ] + + metadata_results = await asyncio.gather( + *metadata_tasks, return_exceptions=True + ) + + doc_edges = [] + for result in metadata_results: + if isinstance(result, Exception) or not result: + continue + for triple in result: + p = str(triple.p) + if p in SKIP_PREDICATES: + continue + doc_edges.append(( + str(triple.s), p, str(triple.o) + )) + + return doc_edges + class GraphRag: """ CRITICAL SECURITY: @@ -371,7 +550,8 @@ class GraphRag: async def query( self, query, user = "trustgraph", collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, - max_path_length = 2, streaming = False, chunk_callback = None, + max_path_length = 2, edge_limit = 25, streaming = False, + chunk_callback = None, explain_callback = None, save_answer_callback = None, ): """ @@ -399,6 +579,7 @@ class GraphRag: # Generate explainability URIs upfront session_id = str(uuid.uuid4()) q_uri = question_uri(session_id) + gnd_uri = make_grounding_uri(session_id) exp_uri = make_exploration_uri(session_id) foc_uri = make_focus_uri(session_id) syn_uri = make_synthesis_uri(session_id) @@ -421,12 +602,23 @@ class GraphRag: max_path_length = max_path_length, ) - kg, uri_map = await q.get_labelgraph(query) + kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) + + # Emit grounding explain after concept extraction + if explain_callback: + gnd_triples = set_graph( + grounding_triples(gnd_uri, q_uri, concepts), + GRAPH_RETRIEVAL + ) + await explain_callback(gnd_triples, gnd_uri) # Emit exploration explain after graph retrieval completes if explain_callback: exp_triples = set_graph( - exploration_triples(exp_uri, q_uri, len(kg)), + exploration_triples( + exp_uri, gnd_uri, len(kg), + entities=seed_entities, + ), GRAPH_RETRIEVAL ) await explain_callback(exp_triples, exp_uri) @@ -453,9 +645,9 @@ class GraphRag: if self.verbose: logger.debug(f"Built edge map with {len(edge_map)} edges") - # Step 1: Edge Selection - LLM selects relevant edges with reasoning - selection_response = await self.prompt_client.prompt( - "kg-edge-selection", + # Step 1a: Edge Scoring - LLM scores edges for relevance + scoring_response = await self.prompt_client.prompt( + "kg-edge-scoring", variables={ "query": query, "knowledge": edges_with_ids @@ -463,52 +655,44 @@ class GraphRag: ) if self.verbose: - logger.debug(f"Edge selection response: {selection_response}") + logger.debug(f"Edge scoring response: {scoring_response}") - # Parse response to get selected edge IDs and reasoning - # Response can be a string (JSONL) or a list (JSON array) - selected_ids = set() - selected_edges_with_reasoning = [] # For explain + # Parse scoring response to get edge IDs with scores + scored_edges = [] - if isinstance(selection_response, list): - # JSON array response - for obj in selection_response: - if isinstance(obj, dict) and "id" in obj: - selected_ids.add(obj["id"]) - # Capture original URI edge (not labels) and reasoning for explain - eid = obj["id"] - if eid in uri_map: - # Use original URIs for provenance tracing - uri_s, uri_p, uri_o = uri_map[eid] - selected_edges_with_reasoning.append({ - "edge": (uri_s, uri_p, uri_o), - "reasoning": obj.get("reasoning", ""), - }) - elif isinstance(selection_response, str): - # JSONL string response - for line in selection_response.strip().split('\n'): + def parse_scored_edge(obj): + if isinstance(obj, dict) and "id" in obj and "score" in obj: + try: + score = int(obj["score"]) + except (ValueError, TypeError): + score = 0 + scored_edges.append({"id": obj["id"], "score": score}) + + if isinstance(scoring_response, list): + for obj in scoring_response: + parse_scored_edge(obj) + elif isinstance(scoring_response, str): + for line in scoring_response.strip().split('\n'): line = line.strip() if not line: continue try: - obj = json.loads(line) - if "id" in obj: - selected_ids.add(obj["id"]) - # Capture original URI edge (not labels) and reasoning for explain - eid = obj["id"] - if eid in uri_map: - # Use original URIs for provenance tracing - uri_s, uri_p, uri_o = uri_map[eid] - selected_edges_with_reasoning.append({ - "edge": (uri_s, uri_p, uri_o), - "reasoning": obj.get("reasoning", ""), - }) + parse_scored_edge(json.loads(line)) except json.JSONDecodeError: - logger.warning(f"Failed to parse edge selection line: {line}") - continue + logger.warning( + f"Failed to parse edge scoring line: {line}" + ) + + # Select top N edges by score + scored_edges.sort(key=lambda x: x["score"], reverse=True) + top_edges = scored_edges[:edge_limit] + selected_ids = {e["id"] for e in top_edges} if self.verbose: - logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}") + logger.debug( + f"Scored {len(scored_edges)} edges, " + f"selected top {len(selected_ids)}" + ) # Filter to selected edges selected_edges = [] @@ -516,6 +700,82 @@ class GraphRag: if eid in edge_map: selected_edges.append(edge_map[eid]) + # Step 1b: Edge Reasoning + Document Tracing (concurrent) + selected_edges_with_ids = [ + {"id": eid, "s": s, "p": p, "o": o} + for eid in selected_ids + if eid in edge_map + for s, p, o in [edge_map[eid]] + ] + + # Collect selected edge URIs for document tracing + selected_edge_uris = [ + uri_map[eid] + for eid in selected_ids + if eid in uri_map + ] + + # Run reasoning and document tracing concurrently + reasoning_task = self.prompt_client.prompt( + "kg-edge-reasoning", + variables={ + "query": query, + "knowledge": selected_edges_with_ids + } + ) + doc_trace_task = q.trace_source_documents(selected_edge_uris) + + reasoning_response, source_documents = await asyncio.gather( + reasoning_task, doc_trace_task, return_exceptions=True + ) + + # Handle exceptions from gather + if isinstance(reasoning_response, Exception): + logger.warning( + f"Edge reasoning failed: {reasoning_response}" + ) + reasoning_response = "" + if isinstance(source_documents, Exception): + logger.warning( + f"Document tracing failed: {source_documents}" + ) + source_documents = [] + + + if self.verbose: + logger.debug(f"Edge reasoning response: {reasoning_response}") + + # Parse reasoning response and build explainability data + reasoning_map = {} + + def parse_reasoning(obj): + if isinstance(obj, dict) and "id" in obj: + reasoning_map[obj["id"]] = obj.get("reasoning", "") + + if isinstance(reasoning_response, list): + for obj in reasoning_response: + parse_reasoning(obj) + elif isinstance(reasoning_response, str): + for line in reasoning_response.strip().split('\n'): + line = line.strip() + if not line: + continue + try: + parse_reasoning(json.loads(line)) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge reasoning line: {line}" + ) + + selected_edges_with_reasoning = [] + for eid in selected_ids: + if eid in uri_map: + uri_s, uri_p, uri_o = uri_map[eid] + selected_edges_with_reasoning.append({ + "edge": (uri_s, uri_p, uri_o), + "reasoning": reasoning_map.get(eid, ""), + }) + if self.verbose: logger.debug(f"Filtered to {len(selected_edges)} edges") @@ -534,6 +794,18 @@ class GraphRag: {"s": s, "p": p, "o": o} for s, p, o in selected_edges ] + + # Add source document metadata as knowledge edges + for s, p, o in source_documents: + selected_edge_dicts.append({ + "s": s, "p": p, "o": o, + }) + + synthesis_variables = { + "query": query, + "knowledge": selected_edge_dicts, + } + if streaming and chunk_callback: # Accumulate chunks for answer storage while forwarding to callback accumulated_chunks = [] @@ -544,10 +816,7 @@ class GraphRag: await self.prompt_client.prompt( "kg-synthesis", - variables={ - "query": query, - "knowledge": selected_edge_dicts - }, + variables=synthesis_variables, streaming=True, chunk_callback=accumulating_callback ) @@ -556,10 +825,7 @@ class GraphRag: else: resp = await self.prompt_client.prompt( "kg-synthesis", - variables={ - "query": query, - "knowledge": selected_edge_dicts - } + variables=synthesis_variables, ) if self.verbose: @@ -570,9 +836,8 @@ class GraphRag: synthesis_doc_id = None answer_text = resp if resp else "" - # Save answer to librarian if callback provided + # Save answer to librarian if save_answer_callback and answer_text: - # Generate document ID as URN matching query-time provenance format synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}" try: await save_answer_callback(synthesis_doc_id, answer_text) @@ -580,13 +845,11 @@ class GraphRag: logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") except Exception as e: logger.warning(f"Failed to save answer to librarian: {e}") - synthesis_doc_id = None # Fall back to inline content + synthesis_doc_id = None - # Generate triples with document reference or inline content syn_triples = set_graph( synthesis_triples( syn_uri, foc_uri, - answer_text="" if synthesis_doc_id else answer_text, document_id=synthesis_doc_id, ), GRAPH_RETRIEVAL diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 81be2819..ec4a806c 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -39,6 +39,7 @@ class Processor(FlowProcessor): triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) + edge_limit = params.get("edge_limit", 25) super(Processor, self).__init__( **params | { @@ -48,6 +49,7 @@ class Processor(FlowProcessor): "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, + "edge_limit": edge_limit, } ) @@ -55,6 +57,7 @@ class Processor(FlowProcessor): self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length + self.default_edge_limit = edge_limit # CRITICAL SECURITY: NEVER share data between users or collections # Each user/collection combination MUST have isolated data access @@ -292,6 +295,11 @@ class Processor(FlowProcessor): else: max_path_length = self.default_max_path_length + if v.edge_limit: + edge_limit = v.edge_limit + else: + edge_limit = self.default_edge_limit + # Callback to save answer content to librarian async def save_answer(doc_id, answer_text): await self.save_answer_content( @@ -322,6 +330,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_limit = edge_limit, streaming = True, chunk_callback = send_chunk, explain_callback = send_explainability, @@ -335,6 +344,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_limit = edge_limit, explain_callback = send_explainability, save_answer_callback = save_answer, )