diff --git a/dev-tools/tests/triples/load_test_triples.py b/dev-tools/tests/triples/load_test_triples.py new file mode 100755 index 00000000..a147d041 --- /dev/null +++ b/dev-tools/tests/triples/load_test_triples.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Load test triples into the triple store for testing tg-query-graph. + +Tests all graph features: +- SPO with IRI objects +- SPO with literal objects +- Literals with XML datatypes +- Literals with language tags +- Quoted triples (RDF-star) +- Named graphs +""" + +import asyncio +import json +import os +import websockets + +# Configuration +API_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None) +FLOW = "default" +USER = "trustgraph" +COLLECTION = "default" +DOCUMENT_ID = "test-triples-001" + +# Namespaces +EX = "http://example.org/" +RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" +RDFS = "http://www.w3.org/2000/01/rdf-schema#" +XSD = "http://www.w3.org/2001/XMLSchema#" +TG = "https://trustgraph.ai/ns/" + + +def iri(value): + """Build IRI term.""" + return {"t": "i", "i": value} + + +def literal(value, datatype=None, language=None): + """Build literal term with optional datatype or language.""" + term = {"t": "l", "v": value} + if datatype: + term["dt"] = datatype + if language: + term["ln"] = language + return term + + +def quoted_triple(s, p, o): + """Build quoted triple term (RDF-star).""" + return { + "t": "t", + "tr": {"s": s, "p": p, "o": o} + } + + +def triple(s, p, o, g=None): + """Build a complete triple dict.""" + t = {"s": s, "p": p, "o": o} + if g: + t["g"] = g + return t + + +# Test triples covering all features +TEST_TRIPLES = [ + # 1. Basic SPO with IRI object + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDF}type"), + iri(f"{EX}Scientist") + ), + + # 2. SPO with IRI object (relationship) + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}discovered"), + iri(f"{EX}radium") + ), + + # 3. Simple literal (no datatype/language) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie") + ), + + # 4. Literal with language tag (English) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie", language="en") + ), + + # 5. Literal with language tag (French) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie", language="fr") + ), + + # 6. Literal with language tag (Polish) + triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Maria Sk\u0142odowska-Curie", language="pl") + ), + + # 7. Literal with xsd:integer datatype + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}birthYear"), + literal("1867", datatype=f"{XSD}integer") + ), + + # 8. Literal with xsd:date datatype + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}birthDate"), + literal("1867-11-07", datatype=f"{XSD}date") + ), + + # 9. Literal with xsd:boolean datatype + triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}nobelLaureate"), + literal("true", datatype=f"{XSD}boolean") + ), + + # 10. Quoted triple in object position (RDF 1.2 style) + # "Wikipedia asserts that Marie Curie discovered radium" + triple( + iri(f"{EX}wikipedia"), + iri(f"{TG}asserts"), + quoted_triple( + iri(f"{EX}marie-curie"), + iri(f"{EX}discovered"), + iri(f"{EX}radium") + ) + ), + + # 11. Quoted triple with literal inside (object position) + # "NLP-v1.0 extracted that Marie Curie has label Marie Curie" + triple( + iri(f"{EX}nlp-v1"), + iri(f"{TG}extracted"), + quoted_triple( + iri(f"{EX}marie-curie"), + iri(f"{RDFS}label"), + literal("Marie Curie") + ) + ), + + # 12. Triple in a named graph (g is plain string, not Term) + triple( + iri(f"{EX}radium"), + iri(f"{RDF}type"), + iri(f"{EX}Element"), + g=f"{EX}chemistry-graph" + ), + + # 13. Another triple in the same named graph + triple( + iri(f"{EX}radium"), + iri(f"{EX}atomicNumber"), + literal("88", datatype=f"{XSD}integer"), + g=f"{EX}chemistry-graph" + ), + + # 14. Triple in a different named graph + triple( + iri(f"{EX}pierre-curie"), + iri(f"{EX}spouseOf"), + iri(f"{EX}marie-curie"), + g=f"{EX}biography-graph" + ), +] + + +async def load_triples(): + """Load test triples via WebSocket bulk import.""" + # Convert HTTP URL to WebSocket URL + ws_url = API_URL.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url.rstrip('/')}/api/v1/flow/{FLOW}/import/triples" + if TOKEN: + ws_url = f"{ws_url}?token={TOKEN}" + + metadata = { + "id": DOCUMENT_ID, + "metadata": [], + "user": USER, + "collection": COLLECTION + } + + print(f"Connecting to {ws_url}...") + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as websocket: + message = { + "metadata": metadata, + "triples": TEST_TRIPLES + } + print(f"Sending {len(TEST_TRIPLES)} test triples...") + await websocket.send(json.dumps(message)) + print("Triples sent successfully!") + + print("\nTest triples loaded:") + print(" - 2 basic IRI triples (type, relationship)") + print(" - 4 literal triples (plain + 3 languages: en, fr, pl)") + print(" - 3 typed literal triples (xsd:integer, xsd:date, xsd:boolean)") + print(" - 2 quoted triples (RDF-star provenance)") + print(" - 3 triples in named graphs (chemistry-graph, biography-graph)") + print(f"\nTotal: {len(TEST_TRIPLES)} triples") + print(f"User: {USER}, Collection: {COLLECTION}") + + +def main(): + print("Loading test triples for tg-query-graph testing\n") + asyncio.run(load_triples()) + print("\nDone! Now test with:") + print(" tg-query-graph -s http://example.org/marie-curie") + print(" tg-query-graph -p http://www.w3.org/2000/01/rdf-schema#label") + print(" tg-query-graph -o 'Marie Curie' --object-language en") + print(" tg-query-graph --format json | jq .") + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_agent/test_callback_message_id.py b/tests/unit/test_agent/test_callback_message_id.py new file mode 100644 index 00000000..7cb0ee54 --- /dev/null +++ b/tests/unit/test_agent/test_callback_message_id.py @@ -0,0 +1,122 @@ +""" +Tests that streaming callbacks set message_id on AgentResponse. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.agent.orchestrator.pattern_base import PatternBase +from trustgraph.schema import AgentResponse + + +@pytest.fixture +def pattern(): + processor = MagicMock() + return PatternBase(processor) + + +class TestThinkCallbackMessageId: + + @pytest.mark.asyncio + async def test_streaming_think_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/i1/thought" + think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id) + await think("hello", is_final=False) + + assert len(responses) == 1 + assert responses[0].message_id == msg_id + assert responses[0].chunk_type == "thought" + + @pytest.mark.asyncio + async def test_non_streaming_think_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/i1/thought" + think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id) + await think("hello") + + assert responses[0].message_id == msg_id + assert responses[0].end_of_message is True + + +class TestObserveCallbackMessageId: + + @pytest.mark.asyncio + async def test_streaming_observe_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/i1/observation" + observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id) + await observe("result", is_final=True) + + assert responses[0].message_id == msg_id + assert responses[0].chunk_type == "observation" + + +class TestAnswerCallbackMessageId: + + @pytest.mark.asyncio + async def test_streaming_answer_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/final" + answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id) + await answer("the answer") + + assert responses[0].message_id == msg_id + assert responses[0].chunk_type == "answer" + + @pytest.mark.asyncio + async def test_no_message_id_default(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + answer = pattern.make_answer_callback(capture, streaming=True) + await answer("the answer") + + assert responses[0].message_id == "" + + +class TestSendFinalResponseMessageId: + + @pytest.mark.asyncio + async def test_streaming_final_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/final" + await pattern.send_final_response( + capture, streaming=True, answer_text="answer", + message_id=msg_id, + ) + + # Should get content chunk + end-of-dialog marker + assert all(r.message_id == msg_id for r in responses) + + @pytest.mark.asyncio + async def test_non_streaming_final_has_message_id(self, pattern): + responses = [] + async def capture(r): + responses.append(r) + + msg_id = "urn:trustgraph:agent:sess/final" + await pattern.send_final_response( + capture, streaming=False, answer_text="answer", + message_id=msg_id, + ) + + assert len(responses) == 1 + assert responses[0].message_id == msg_id + assert responses[0].end_of_dialog is True diff --git a/tests/unit/test_agent/test_explainability_parsing.py b/tests/unit/test_agent/test_explainability_parsing.py index 7035318d..d75ea604 100644 --- a/tests/unit/test_agent/test_explainability_parsing.py +++ b/tests/unit/test_agent/test_explainability_parsing.py @@ -22,6 +22,7 @@ from trustgraph.api.explainability import ( TG_SYNTHESIS, TG_ANSWER_TYPE, TG_OBSERVATION_TYPE, + TG_TOOL_USE, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, @@ -76,6 +77,13 @@ class TestFromTriplesDispatch: entity = ExplainEntity.from_triples("urn:a", triples) assert isinstance(entity, Analysis) + def test_dispatches_analysis_with_tooluse(self): + """Analysis+ToolUse mixin still dispatches to Analysis.""" + triples = _make_triples("urn:a", + [PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE]) + entity = ExplainEntity.from_triples("urn:a", triples) + assert isinstance(entity, Analysis) + def test_dispatches_observation(self): triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE]) entity = ExplainEntity.from_triples("urn:o", triples) diff --git a/tests/unit/test_agent/test_on_action_callback.py b/tests/unit/test_agent/test_on_action_callback.py new file mode 100644 index 00000000..4a1c0c3b --- /dev/null +++ b/tests/unit/test_agent/test_on_action_callback.py @@ -0,0 +1,132 @@ +""" +Tests for the on_action callback in react() — verifies that it fires +after action selection but before tool execution. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from trustgraph.agent.react.agent_manager import AgentManager +from trustgraph.agent.react.types import Action, Final, Tool, Argument + + +class TestOnActionCallback: + + @pytest.mark.asyncio + async def test_on_action_called_for_tool_use(self): + """on_action fires when react() selects a tool (not Final).""" + call_log = [] + + async def fake_on_action(act): + call_log.append(("on_action", act.name)) + + # Tool that records when it's invoked + async def tool_invoke(**kwargs): + call_log.append(("tool_invoke",)) + return "tool result" + + tool_impl = MagicMock() + tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke) + + tools = { + "search": Tool( + name="search", + description="Search", + implementation=tool_impl, + arguments=[Argument(name="query", type="string", description="q")], + config={}, + ), + } + + agent = AgentManager(tools=tools) + + # Mock reason() to return an Action + action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="") + agent.reason = AsyncMock(return_value=action) + + think = AsyncMock() + observe = AsyncMock() + context = MagicMock() + + await agent.react( + question="test", + history=[], + think=think, + observe=observe, + context=context, + on_action=fake_on_action, + ) + + # on_action should fire before tool_invoke + assert len(call_log) == 2 + assert call_log[0] == ("on_action", "search") + assert call_log[1] == ("tool_invoke",) + + @pytest.mark.asyncio + async def test_on_action_not_called_for_final(self): + """on_action does not fire when react() returns Final.""" + called = [] + + async def fake_on_action(act): + called.append(act) + + agent = AgentManager(tools={}) + agent.reason = AsyncMock( + return_value=Final(thought="done", final="answer") + ) + + think = AsyncMock() + observe = AsyncMock() + context = MagicMock() + + result = await agent.react( + question="test", + history=[], + think=think, + observe=observe, + context=context, + on_action=fake_on_action, + ) + + assert isinstance(result, Final) + assert len(called) == 0 + + @pytest.mark.asyncio + async def test_on_action_none_accepted(self): + """react() works fine when on_action is None (default).""" + async def tool_invoke(**kwargs): + return "result" + + tool_impl = MagicMock() + tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke) + + tools = { + "search": Tool( + name="search", + description="Search", + implementation=tool_impl, + arguments=[], + config={}, + ), + } + + agent = AgentManager(tools=tools) + agent.reason = AsyncMock( + return_value=Action(thought="t", name="search", arguments={}, observation="") + ) + + think = AsyncMock() + observe = AsyncMock() + context = MagicMock() + + result = await agent.react( + question="test", + history=[], + think=think, + observe=observe, + context=context, + # on_action not passed — defaults to None + ) + + assert isinstance(result, Action) + assert result.observation == "result" diff --git a/tests/unit/test_agent/test_parse_chunk_message_id.py b/tests/unit/test_agent/test_parse_chunk_message_id.py new file mode 100644 index 00000000..38942f1e --- /dev/null +++ b/tests/unit/test_agent/test_parse_chunk_message_id.py @@ -0,0 +1,74 @@ +""" +Tests that _parse_chunk propagates message_id from wire format +to AgentThought, AgentObservation, and AgentAnswer. +""" + +import pytest + +from trustgraph.api.socket_client import SocketClient +from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer + + +@pytest.fixture +def client(): + # We only need _parse_chunk — don't connect + c = object.__new__(SocketClient) + return c + + +class TestParseChunkMessageId: + + def test_thought_message_id(self, client): + resp = { + "chunk_type": "thought", + "content": "thinking...", + "end_of_message": False, + "message_id": "urn:trustgraph:agent:sess/i1/thought", + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentThought) + assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought" + + def test_observation_message_id(self, client): + resp = { + "chunk_type": "observation", + "content": "result", + "end_of_message": True, + "message_id": "urn:trustgraph:agent:sess/i1/observation", + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentObservation) + assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation" + + def test_answer_message_id(self, client): + resp = { + "chunk_type": "answer", + "content": "the answer", + "end_of_message": False, + "end_of_dialog": False, + "message_id": "urn:trustgraph:agent:sess/final", + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentAnswer) + assert chunk.message_id == "urn:trustgraph:agent:sess/final" + + def test_thought_missing_message_id(self, client): + resp = { + "chunk_type": "thought", + "content": "thinking...", + "end_of_message": False, + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentThought) + assert chunk.message_id == "" + + def test_answer_missing_message_id(self, client): + resp = { + "chunk_type": "answer", + "content": "answer", + "end_of_message": True, + "end_of_dialog": True, + } + chunk = client._parse_chunk(resp) + assert isinstance(chunk, AgentAnswer) + assert chunk.message_id == "" diff --git a/tests/unit/test_provenance/test_agent_provenance.py b/tests/unit/test_provenance/test_agent_provenance.py index d3f0ef8c..c548ef9d 100644 --- a/tests/unit/test_provenance/test_agent_provenance.py +++ b/tests/unit/test_provenance/test_agent_provenance.py @@ -12,6 +12,7 @@ from trustgraph.provenance.agent import ( agent_iteration_triples, agent_observation_triples, agent_final_triples, + agent_synthesis_triples, ) from trustgraph.provenance.namespaces import ( @@ -21,7 +22,7 @@ from trustgraph.provenance.namespaces import ( TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT, TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, - TG_TOOL_USE, + TG_TOOL_USE, TG_SYNTHESIS, TG_AGENT_QUESTION, ) @@ -105,6 +106,25 @@ class TestAgentSessionTriples: ) assert len(triples) == 6 + def test_session_parent_uri(self): + """Subagent sessions derive from a parent entity (e.g. Decomposition).""" + parent = "urn:trustgraph:agent:parent/decompose" + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-01-01T00:00:00Z", + parent_uri=parent, + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI) + assert derived is not None + assert derived.o.iri == parent + + def test_session_no_parent_uri(self): + """Top-level sessions have no wasDerivedFrom.""" + triples = agent_session_triples( + self.SESSION_URI, "Q", "2024-01-01T00:00:00Z" + ) + derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI) + assert derived is None + # --------------------------------------------------------------------------- # agent_iteration_triples @@ -358,3 +378,59 @@ class TestAgentFinalTriples: ) doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI) assert doc is None + + +# --------------------------------------------------------------------------- +# agent_synthesis_triples +# --------------------------------------------------------------------------- + +class TestAgentSynthesisTriples: + + SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis" + FINDING_0 = "urn:trustgraph:agent:test-session/finding/0" + FINDING_1 = "urn:trustgraph:agent:test-session/finding/1" + FINDING_2 = "urn:trustgraph:agent:test-session/finding/2" + + def test_synthesis_types(self): + triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0) + assert has_type(triples, self.SYNTH_URI, PROV_ENTITY) + assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS) + assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE) + + def test_synthesis_single_parent_string(self): + """Single parent passed as string.""" + triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0) + derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI) + assert len(derived) == 1 + assert derived[0].o.iri == self.FINDING_0 + + def test_synthesis_multiple_parents(self): + """Multiple parents for supervisor fan-in.""" + parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2] + triples = agent_synthesis_triples(self.SYNTH_URI, parents) + derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI) + assert len(derived) == 3 + derived_uris = {t.o.iri for t in derived} + assert derived_uris == set(parents) + + def test_synthesis_single_parent_as_list(self): + """Single parent passed as list.""" + triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0]) + derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI) + assert len(derived) == 1 + assert derived[0].o.iri == self.FINDING_0 + + def test_synthesis_document(self): + triples = agent_synthesis_triples( + self.SYNTH_URI, self.FINDING_0, + document_id="urn:doc:synth", + ) + doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI) + assert doc is not None + assert doc.o.iri == "urn:doc:synth" + + def test_synthesis_label(self): + triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0) + label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI) + assert label is not None + assert label.o.value == "Synthesis" diff --git a/tests/unit/test_provenance/test_explainability.py b/tests/unit/test_provenance/test_explainability.py index e2c7fcd1..a6d655a7 100644 --- a/tests/unit/test_provenance/test_explainability.py +++ b/tests/unit/test_provenance/test_explainability.py @@ -558,3 +558,96 @@ class TestExplainabilityClientDetectSessionType: mock_flow = MagicMock() client = ExplainabilityClient(mock_flow, retry_delay=0.0) assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag" + + +class TestChainWalkerFollowsSubTraceTerminal: + """Test that _follow_provenance_chain continues from a sub-trace's + Synthesis to find downstream entities like Observation.""" + + def test_observation_found_via_subtrace_synthesis(self): + """ + DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation + The walker should find Analysis, the sub-trace, then follow from + Synthesis to discover Observation. + """ + # Entity triples (s, p, o) + entity_data = { + "urn:agent:q": [ + ("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION), + ("urn:agent:q", TG_QUERY, "test"), + ], + "urn:agent:analysis": [ + ("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS), + ("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"), + ], + "urn:graphrag:q": [ + ("urn:graphrag:q", RDF_TYPE, TG_QUESTION), + ("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION), + ("urn:graphrag:q", TG_QUERY, "test"), + ("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"), + ], + "urn:graphrag:synth": [ + ("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS), + ("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"), + ], + "urn:agent:obs": [ + ("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE), + ("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"), + ], + "urn:agent:conclusion": [ + ("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION), + ("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"), + ], + } + + # Build a mock flow that answers triples queries + # Query by s= returns that entity's triples + # Query by p=wasDerivedFrom, o=X returns entities derived from X + def mock_triples_query(s=None, p=None, o=None, **kwargs): + if s and not p: + # Fetch entity triples + tuples = entity_data.get(s, []) + return _make_wire_triples(tuples) + elif p == PROV_WAS_DERIVED_FROM and o: + # Find entities derived from o + results = [] + for uri, tuples in entity_data.items(): + for _, pred, obj in tuples: + if pred == PROV_WAS_DERIVED_FROM and obj == o: + results.append((uri, pred, obj)) + return _make_wire_triples(results) + return [] + + mock_flow = MagicMock() + mock_flow.triples_query.side_effect = mock_triples_query + + client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2) + + # Mock fetch_graphrag_trace to return a trace with a synthesis + synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis") + client.fetch_graphrag_trace = MagicMock(return_value={ + "question": Question(uri="urn:graphrag:q", entity_type="question", + question_type="graph-rag"), + "synthesis": synth_entity, + }) + + trace = client.fetch_agent_trace( + "urn:agent:q", + graph="urn:graph:retrieval", + ) + + # Should have found all steps + step_types = [ + type(s).__name__ if not isinstance(s, dict) else s.get("type") + for s in trace["steps"] + ] + + assert "Analysis" in step_types, f"Missing Analysis in {step_types}" + assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}" + assert "Observation" in step_types, f"Missing Observation in {step_types}" + assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}" + + # Observation should come after the sub-trace + subtrace_idx = step_types.index("sub-trace") + obs_idx = step_types.index("Observation") + assert obs_idx > subtrace_idx, "Observation should appear after sub-trace"