mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 01:46:22 +02:00
Additional agent DAG tests (#750)
- test_agent_provenance.py: test_session_parent_uri, test_session_no_parent_uri, and 6 synthesis tests (types, single/multiple parents, document, label) - test_on_action_callback.py: 3 tests — fires before tool, skipped for Final, works when None - test_callback_message_id.py: 7 tests — message_id on think/observe/ answer callbacks (streaming + non-streaming) and send_final_response - test_parse_chunk_message_id.py (5 tests) - _parse_chunk propagates message_id for thought, observation, answer; handles missing gracefully - test_explainability_parsing.py (+1) - test_dispatches_analysis_with_tooluse - Analysis+ToolUse mixin still dispatches to Analysis - test_explainability.py (+1) - test_observation_found_via_subtrace_synthesis - chain walker follows from sub-trace Synthesis to find Observation and Conclusion in correct order - test_agent_provenance.py (+8) - session parent_uri (2), synthesis single/multiple parents, types, document, label (6)
This commit is contained in:
parent
3ba6a3238f
commit
dbf8daa74a
7 changed files with 733 additions and 1 deletions
122
tests/unit/test_agent/test_callback_message_id.py
Normal file
122
tests/unit/test_agent/test_callback_message_id.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
Tests that streaming callbacks set message_id on AgentResponse.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.orchestrator.pattern_base import PatternBase
|
||||
from trustgraph.schema import AgentResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pattern():
|
||||
processor = MagicMock()
|
||||
return PatternBase(processor)
|
||||
|
||||
|
||||
class TestThinkCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_think_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/thought"
|
||||
think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id)
|
||||
await think("hello", is_final=False)
|
||||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "thought"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_think_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/thought"
|
||||
think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id)
|
||||
await think("hello")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].end_of_message is True
|
||||
|
||||
|
||||
class TestObserveCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_observe_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/i1/observation"
|
||||
observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id)
|
||||
await observe("result", is_final=True)
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "observation"
|
||||
|
||||
|
||||
class TestAnswerCallbackMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_answer_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id)
|
||||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_default(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
answer = pattern.make_answer_callback(capture, streaming=True)
|
||||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == ""
|
||||
|
||||
|
||||
class TestSendFinalResponseMessageId:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_final_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
await pattern.send_final_response(
|
||||
capture, streaming=True, answer_text="answer",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
# Should get content chunk + end-of-dialog marker
|
||||
assert all(r.message_id == msg_id for r in responses)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_final_has_message_id(self, pattern):
|
||||
responses = []
|
||||
async def capture(r):
|
||||
responses.append(r)
|
||||
|
||||
msg_id = "urn:trustgraph:agent:sess/final"
|
||||
await pattern.send_final_response(
|
||||
capture, streaming=False, answer_text="answer",
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].end_of_dialog is True
|
||||
|
|
@ -22,6 +22,7 @@ from trustgraph.api.explainability import (
|
|||
TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_TOOL_USE,
|
||||
TG_ANALYSIS,
|
||||
TG_CONCLUSION,
|
||||
TG_DOCUMENT,
|
||||
|
|
@ -76,6 +77,13 @@ class TestFromTriplesDispatch:
|
|||
entity = ExplainEntity.from_triples("urn:a", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
|
||||
def test_dispatches_analysis_with_tooluse(self):
|
||||
"""Analysis+ToolUse mixin still dispatches to Analysis."""
|
||||
triples = _make_triples("urn:a",
|
||||
[PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE])
|
||||
entity = ExplainEntity.from_triples("urn:a", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
|
||||
def test_dispatches_observation(self):
|
||||
triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE])
|
||||
entity = ExplainEntity.from_triples("urn:o", triples)
|
||||
|
|
|
|||
132
tests/unit/test_agent/test_on_action_callback.py
Normal file
132
tests/unit/test_agent/test_on_action_callback.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Tests for the on_action callback in react() — verifies that it fires
|
||||
after action selection but before tool execution.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
|
||||
|
||||
class TestOnActionCallback:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_called_for_tool_use(self):
|
||||
"""on_action fires when react() selects a tool (not Final)."""
|
||||
call_log = []
|
||||
|
||||
async def fake_on_action(act):
|
||||
call_log.append(("on_action", act.name))
|
||||
|
||||
# Tool that records when it's invoked
|
||||
async def tool_invoke(**kwargs):
|
||||
call_log.append(("tool_invoke",))
|
||||
return "tool result"
|
||||
|
||||
tool_impl = MagicMock()
|
||||
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
|
||||
|
||||
tools = {
|
||||
"search": Tool(
|
||||
name="search",
|
||||
description="Search",
|
||||
implementation=tool_impl,
|
||||
arguments=[Argument(name="query", type="string", description="q")],
|
||||
config={},
|
||||
),
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=tools)
|
||||
|
||||
# Mock reason() to return an Action
|
||||
action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="")
|
||||
agent.reason = AsyncMock(return_value=action)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
on_action=fake_on_action,
|
||||
)
|
||||
|
||||
# on_action should fire before tool_invoke
|
||||
assert len(call_log) == 2
|
||||
assert call_log[0] == ("on_action", "search")
|
||||
assert call_log[1] == ("tool_invoke",)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_not_called_for_final(self):
|
||||
"""on_action does not fire when react() returns Final."""
|
||||
called = []
|
||||
|
||||
async def fake_on_action(act):
|
||||
called.append(act)
|
||||
|
||||
agent = AgentManager(tools={})
|
||||
agent.reason = AsyncMock(
|
||||
return_value=Final(thought="done", final="answer")
|
||||
)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
result = await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
on_action=fake_on_action,
|
||||
)
|
||||
|
||||
assert isinstance(result, Final)
|
||||
assert len(called) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_action_none_accepted(self):
|
||||
"""react() works fine when on_action is None (default)."""
|
||||
async def tool_invoke(**kwargs):
|
||||
return "result"
|
||||
|
||||
tool_impl = MagicMock()
|
||||
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
|
||||
|
||||
tools = {
|
||||
"search": Tool(
|
||||
name="search",
|
||||
description="Search",
|
||||
implementation=tool_impl,
|
||||
arguments=[],
|
||||
config={},
|
||||
),
|
||||
}
|
||||
|
||||
agent = AgentManager(tools=tools)
|
||||
agent.reason = AsyncMock(
|
||||
return_value=Action(thought="t", name="search", arguments={}, observation="")
|
||||
)
|
||||
|
||||
think = AsyncMock()
|
||||
observe = AsyncMock()
|
||||
context = MagicMock()
|
||||
|
||||
result = await agent.react(
|
||||
question="test",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=observe,
|
||||
context=context,
|
||||
# on_action not passed — defaults to None
|
||||
)
|
||||
|
||||
assert isinstance(result, Action)
|
||||
assert result.observation == "result"
|
||||
74
tests/unit/test_agent/test_parse_chunk_message_id.py
Normal file
74
tests/unit/test_agent/test_parse_chunk_message_id.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Tests that _parse_chunk propagates message_id from wire format
|
||||
to AgentThought, AgentObservation, and AgentAnswer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.api.socket_client import SocketClient
|
||||
from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# We only need _parse_chunk — don't connect
|
||||
c = object.__new__(SocketClient)
|
||||
return c
|
||||
|
||||
|
||||
class TestParseChunkMessageId:
|
||||
|
||||
def test_thought_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/thought",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentThought)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought"
|
||||
|
||||
def test_observation_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "observation",
|
||||
"content": "result",
|
||||
"end_of_message": True,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/observation",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentObservation)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation"
|
||||
|
||||
def test_answer_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"content": "the answer",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/final",
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentAnswer)
|
||||
assert chunk.message_id == "urn:trustgraph:agent:sess/final"
|
||||
|
||||
def test_thought_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentThought)
|
||||
assert chunk.message_id == ""
|
||||
|
||||
def test_answer_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"content": "answer",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
}
|
||||
chunk = client._parse_chunk(resp)
|
||||
assert isinstance(chunk, AgentAnswer)
|
||||
assert chunk.message_id == ""
|
||||
|
|
@ -12,6 +12,7 @@ from trustgraph.provenance.agent import (
|
|||
agent_iteration_triples,
|
||||
agent_observation_triples,
|
||||
agent_final_triples,
|
||||
agent_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
|
|
@ -21,7 +22,7 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_TOOL_USE,
|
||||
TG_TOOL_USE, TG_SYNTHESIS,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
|
@ -105,6 +106,25 @@ class TestAgentSessionTriples:
|
|||
)
|
||||
assert len(triples) == 6
|
||||
|
||||
def test_session_parent_uri(self):
|
||||
"""Subagent sessions derive from a parent entity (e.g. Decomposition)."""
|
||||
parent = "urn:trustgraph:agent:parent/decompose"
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z",
|
||||
parent_uri=parent,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == parent
|
||||
|
||||
def test_session_no_parent_uri(self):
|
||||
"""Top-level sessions have no wasDerivedFrom."""
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
|
||||
assert derived is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_iteration_triples
|
||||
|
|
@ -358,3 +378,59 @@ class TestAgentFinalTriples:
|
|||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_synthesis_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentSynthesisTriples:
|
||||
|
||||
SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis"
|
||||
FINDING_0 = "urn:trustgraph:agent:test-session/finding/0"
|
||||
FINDING_1 = "urn:trustgraph:agent:test-session/finding/1"
|
||||
FINDING_2 = "urn:trustgraph:agent:test-session/finding/2"
|
||||
|
||||
def test_synthesis_types(self):
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
assert has_type(triples, self.SYNTH_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS)
|
||||
assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_synthesis_single_parent_string(self):
|
||||
"""Single parent passed as string."""
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 1
|
||||
assert derived[0].o.iri == self.FINDING_0
|
||||
|
||||
def test_synthesis_multiple_parents(self):
|
||||
"""Multiple parents for supervisor fan-in."""
|
||||
parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2]
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, parents)
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 3
|
||||
derived_uris = {t.o.iri for t in derived}
|
||||
assert derived_uris == set(parents)
|
||||
|
||||
def test_synthesis_single_parent_as_list(self):
|
||||
"""Single parent passed as list."""
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0])
|
||||
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
|
||||
assert len(derived) == 1
|
||||
assert derived[0].o.iri == self.FINDING_0
|
||||
|
||||
def test_synthesis_document(self):
|
||||
triples = agent_synthesis_triples(
|
||||
self.SYNTH_URI, self.FINDING_0,
|
||||
document_id="urn:doc:synth",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == "urn:doc:synth"
|
||||
|
||||
def test_synthesis_label(self):
|
||||
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
|
||||
label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Synthesis"
|
||||
|
|
|
|||
|
|
@ -558,3 +558,96 @@ class TestExplainabilityClientDetectSessionType:
|
|||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
|
||||
|
||||
|
||||
class TestChainWalkerFollowsSubTraceTerminal:
|
||||
"""Test that _follow_provenance_chain continues from a sub-trace's
|
||||
Synthesis to find downstream entities like Observation."""
|
||||
|
||||
def test_observation_found_via_subtrace_synthesis(self):
|
||||
"""
|
||||
DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation
|
||||
The walker should find Analysis, the sub-trace, then follow from
|
||||
Synthesis to discover Observation.
|
||||
"""
|
||||
# Entity triples (s, p, o)
|
||||
entity_data = {
|
||||
"urn:agent:q": [
|
||||
("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION),
|
||||
("urn:agent:q", TG_QUERY, "test"),
|
||||
],
|
||||
"urn:agent:analysis": [
|
||||
("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"),
|
||||
],
|
||||
"urn:graphrag:q": [
|
||||
("urn:graphrag:q", RDF_TYPE, TG_QUESTION),
|
||||
("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
|
||||
("urn:graphrag:q", TG_QUERY, "test"),
|
||||
("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"),
|
||||
],
|
||||
"urn:graphrag:synth": [
|
||||
("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"),
|
||||
],
|
||||
"urn:agent:obs": [
|
||||
("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"),
|
||||
],
|
||||
"urn:agent:conclusion": [
|
||||
("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"),
|
||||
],
|
||||
}
|
||||
|
||||
# Build a mock flow that answers triples queries
|
||||
# Query by s= returns that entity's triples
|
||||
# Query by p=wasDerivedFrom, o=X returns entities derived from X
|
||||
def mock_triples_query(s=None, p=None, o=None, **kwargs):
|
||||
if s and not p:
|
||||
# Fetch entity triples
|
||||
tuples = entity_data.get(s, [])
|
||||
return _make_wire_triples(tuples)
|
||||
elif p == PROV_WAS_DERIVED_FROM and o:
|
||||
# Find entities derived from o
|
||||
results = []
|
||||
for uri, tuples in entity_data.items():
|
||||
for _, pred, obj in tuples:
|
||||
if pred == PROV_WAS_DERIVED_FROM and obj == o:
|
||||
results.append((uri, pred, obj))
|
||||
return _make_wire_triples(results)
|
||||
return []
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.side_effect = mock_triples_query
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
|
||||
|
||||
# Mock fetch_graphrag_trace to return a trace with a synthesis
|
||||
synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis")
|
||||
client.fetch_graphrag_trace = MagicMock(return_value={
|
||||
"question": Question(uri="urn:graphrag:q", entity_type="question",
|
||||
question_type="graph-rag"),
|
||||
"synthesis": synth_entity,
|
||||
})
|
||||
|
||||
trace = client.fetch_agent_trace(
|
||||
"urn:agent:q",
|
||||
graph="urn:graph:retrieval",
|
||||
)
|
||||
|
||||
# Should have found all steps
|
||||
step_types = [
|
||||
type(s).__name__ if not isinstance(s, dict) else s.get("type")
|
||||
for s in trace["steps"]
|
||||
]
|
||||
|
||||
assert "Analysis" in step_types, f"Missing Analysis in {step_types}"
|
||||
assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}"
|
||||
assert "Observation" in step_types, f"Missing Observation in {step_types}"
|
||||
assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}"
|
||||
|
||||
# Observation should come after the sub-trace
|
||||
subtrace_idx = step_types.index("sub-trace")
|
||||
obs_idx = step_types.index("Observation")
|
||||
assert obs_idx > subtrace_idx, "Observation should appear after sub-trace"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue