Additional agent DAG tests (#750)

- test_agent_provenance.py: test_session_parent_uri,
  test_session_no_parent_uri, and 6 synthesis tests (types,
  single/multiple parents, document, label)
- test_on_action_callback.py: 3 tests — fires before tool, skipped
  for Final, works when None
- test_callback_message_id.py: 7 tests — message_id on think/observe/
  answer callbacks (streaming + non-streaming) and
  send_final_response
- test_parse_chunk_message_id.py (5 tests) - _parse_chunk propagates
  message_id for thought, observation, answer; handles missing
  gracefully
- test_explainability_parsing.py (+1) -
  test_dispatches_analysis_with_tooluse - Analysis+ToolUse mixin still
  dispatches to Analysis
- test_explainability.py (+1) -
  test_observation_found_via_subtrace_synthesis
- chain walker follows from sub-trace Synthesis to find Observation
  and
  Conclusion in correct order
- test_agent_provenance.py (+8) - session parent_uri (2), synthesis
  single/multiple parents, types, document, label (6)
This commit is contained in:
cybermaggedon 2026-04-01 13:59:34 +01:00 committed by Cyber MacGeddon
parent 3ba6a3238f
commit dbf8daa74a
7 changed files with 733 additions and 1 deletions

View file

@ -0,0 +1,227 @@
#!/usr/bin/env python3
"""
Load test triples into the triple store for testing tg-query-graph.
Tests all graph features:
- SPO with IRI objects
- SPO with literal objects
- Literals with XML datatypes
- Literals with language tags
- Quoted triples (RDF-star)
- Named graphs
"""
import asyncio
import json
import os
import websockets
# Configuration
API_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/")
TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)
FLOW = "default"
USER = "trustgraph"
COLLECTION = "default"
DOCUMENT_ID = "test-triples-001"
# Namespaces
EX = "http://example.org/"
RDF = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
RDFS = "http://www.w3.org/2000/01/rdf-schema#"
XSD = "http://www.w3.org/2001/XMLSchema#"
TG = "https://trustgraph.ai/ns/"
def iri(value):
"""Build IRI term."""
return {"t": "i", "i": value}
def literal(value, datatype=None, language=None):
"""Build literal term with optional datatype or language."""
term = {"t": "l", "v": value}
if datatype:
term["dt"] = datatype
if language:
term["ln"] = language
return term
def quoted_triple(s, p, o):
"""Build quoted triple term (RDF-star)."""
return {
"t": "t",
"tr": {"s": s, "p": p, "o": o}
}
def triple(s, p, o, g=None):
"""Build a complete triple dict."""
t = {"s": s, "p": p, "o": o}
if g:
t["g"] = g
return t
# Test triples covering all features
TEST_TRIPLES = [
# 1. Basic SPO with IRI object
triple(
iri(f"{EX}marie-curie"),
iri(f"{RDF}type"),
iri(f"{EX}Scientist")
),
# 2. SPO with IRI object (relationship)
triple(
iri(f"{EX}marie-curie"),
iri(f"{EX}discovered"),
iri(f"{EX}radium")
),
# 3. Simple literal (no datatype/language)
triple(
iri(f"{EX}marie-curie"),
iri(f"{RDFS}label"),
literal("Marie Curie")
),
# 4. Literal with language tag (English)
triple(
iri(f"{EX}marie-curie"),
iri(f"{RDFS}label"),
literal("Marie Curie", language="en")
),
# 5. Literal with language tag (French)
triple(
iri(f"{EX}marie-curie"),
iri(f"{RDFS}label"),
literal("Marie Curie", language="fr")
),
# 6. Literal with language tag (Polish)
triple(
iri(f"{EX}marie-curie"),
iri(f"{RDFS}label"),
literal("Maria Sk\u0142odowska-Curie", language="pl")
),
# 7. Literal with xsd:integer datatype
triple(
iri(f"{EX}marie-curie"),
iri(f"{EX}birthYear"),
literal("1867", datatype=f"{XSD}integer")
),
# 8. Literal with xsd:date datatype
triple(
iri(f"{EX}marie-curie"),
iri(f"{EX}birthDate"),
literal("1867-11-07", datatype=f"{XSD}date")
),
# 9. Literal with xsd:boolean datatype
triple(
iri(f"{EX}marie-curie"),
iri(f"{EX}nobelLaureate"),
literal("true", datatype=f"{XSD}boolean")
),
# 10. Quoted triple in object position (RDF 1.2 style)
# "Wikipedia asserts that Marie Curie discovered radium"
triple(
iri(f"{EX}wikipedia"),
iri(f"{TG}asserts"),
quoted_triple(
iri(f"{EX}marie-curie"),
iri(f"{EX}discovered"),
iri(f"{EX}radium")
)
),
# 11. Quoted triple with literal inside (object position)
# "NLP-v1.0 extracted that Marie Curie has label Marie Curie"
triple(
iri(f"{EX}nlp-v1"),
iri(f"{TG}extracted"),
quoted_triple(
iri(f"{EX}marie-curie"),
iri(f"{RDFS}label"),
literal("Marie Curie")
)
),
# 12. Triple in a named graph (g is plain string, not Term)
triple(
iri(f"{EX}radium"),
iri(f"{RDF}type"),
iri(f"{EX}Element"),
g=f"{EX}chemistry-graph"
),
# 13. Another triple in the same named graph
triple(
iri(f"{EX}radium"),
iri(f"{EX}atomicNumber"),
literal("88", datatype=f"{XSD}integer"),
g=f"{EX}chemistry-graph"
),
# 14. Triple in a different named graph
triple(
iri(f"{EX}pierre-curie"),
iri(f"{EX}spouseOf"),
iri(f"{EX}marie-curie"),
g=f"{EX}biography-graph"
),
]
async def load_triples():
"""Load test triples via WebSocket bulk import."""
# Convert HTTP URL to WebSocket URL
ws_url = API_URL.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url.rstrip('/')}/api/v1/flow/{FLOW}/import/triples"
if TOKEN:
ws_url = f"{ws_url}?token={TOKEN}"
metadata = {
"id": DOCUMENT_ID,
"metadata": [],
"user": USER,
"collection": COLLECTION
}
print(f"Connecting to {ws_url}...")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=60) as websocket:
message = {
"metadata": metadata,
"triples": TEST_TRIPLES
}
print(f"Sending {len(TEST_TRIPLES)} test triples...")
await websocket.send(json.dumps(message))
print("Triples sent successfully!")
print("\nTest triples loaded:")
print(" - 2 basic IRI triples (type, relationship)")
print(" - 4 literal triples (plain + 3 languages: en, fr, pl)")
print(" - 3 typed literal triples (xsd:integer, xsd:date, xsd:boolean)")
print(" - 2 quoted triples (RDF-star provenance)")
print(" - 3 triples in named graphs (chemistry-graph, biography-graph)")
print(f"\nTotal: {len(TEST_TRIPLES)} triples")
print(f"User: {USER}, Collection: {COLLECTION}")
def main():
print("Loading test triples for tg-query-graph testing\n")
asyncio.run(load_triples())
print("\nDone! Now test with:")
print(" tg-query-graph -s http://example.org/marie-curie")
print(" tg-query-graph -p http://www.w3.org/2000/01/rdf-schema#label")
print(" tg-query-graph -o 'Marie Curie' --object-language en")
print(" tg-query-graph --format json | jq .")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,122 @@
"""
Tests that streaming callbacks set message_id on AgentResponse.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.pattern_base import PatternBase
from trustgraph.schema import AgentResponse
@pytest.fixture
def pattern():
processor = MagicMock()
return PatternBase(processor)
class TestThinkCallbackMessageId:
@pytest.mark.asyncio
async def test_streaming_think_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/i1/thought"
think = pattern.make_think_callback(capture, streaming=True, message_id=msg_id)
await think("hello", is_final=False)
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "thought"
@pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/i1/thought"
think = pattern.make_think_callback(capture, streaming=False, message_id=msg_id)
await think("hello")
assert responses[0].message_id == msg_id
assert responses[0].end_of_message is True
class TestObserveCallbackMessageId:
@pytest.mark.asyncio
async def test_streaming_observe_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/i1/observation"
observe = pattern.make_observe_callback(capture, streaming=True, message_id=msg_id)
await observe("result", is_final=True)
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "observation"
class TestAnswerCallbackMessageId:
@pytest.mark.asyncio
async def test_streaming_answer_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/final"
answer = pattern.make_answer_callback(capture, streaming=True, message_id=msg_id)
await answer("the answer")
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "answer"
@pytest.mark.asyncio
async def test_no_message_id_default(self, pattern):
responses = []
async def capture(r):
responses.append(r)
answer = pattern.make_answer_callback(capture, streaming=True)
await answer("the answer")
assert responses[0].message_id == ""
class TestSendFinalResponseMessageId:
@pytest.mark.asyncio
async def test_streaming_final_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/final"
await pattern.send_final_response(
capture, streaming=True, answer_text="answer",
message_id=msg_id,
)
# Should get content chunk + end-of-dialog marker
assert all(r.message_id == msg_id for r in responses)
@pytest.mark.asyncio
async def test_non_streaming_final_has_message_id(self, pattern):
responses = []
async def capture(r):
responses.append(r)
msg_id = "urn:trustgraph:agent:sess/final"
await pattern.send_final_response(
capture, streaming=False, answer_text="answer",
message_id=msg_id,
)
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].end_of_dialog is True

View file

@ -22,6 +22,7 @@ from trustgraph.api.explainability import (
TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_OBSERVATION_TYPE,
TG_TOOL_USE,
TG_ANALYSIS,
TG_CONCLUSION,
TG_DOCUMENT,
@ -76,6 +77,13 @@ class TestFromTriplesDispatch:
entity = ExplainEntity.from_triples("urn:a", triples)
assert isinstance(entity, Analysis)
def test_dispatches_analysis_with_tooluse(self):
"""Analysis+ToolUse mixin still dispatches to Analysis."""
triples = _make_triples("urn:a",
[PROV_ENTITY, TG_ANALYSIS, TG_TOOL_USE])
entity = ExplainEntity.from_triples("urn:a", triples)
assert isinstance(entity, Analysis)
def test_dispatches_observation(self):
triples = _make_triples("urn:o", [PROV_ENTITY, TG_OBSERVATION_TYPE])
entity = ExplainEntity.from_triples("urn:o", triples)

View file

@ -0,0 +1,132 @@
"""
Tests for the on_action callback in react() verifies that it fires
after action selection but before tool execution.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.types import Action, Final, Tool, Argument
class TestOnActionCallback:
@pytest.mark.asyncio
async def test_on_action_called_for_tool_use(self):
"""on_action fires when react() selects a tool (not Final)."""
call_log = []
async def fake_on_action(act):
call_log.append(("on_action", act.name))
# Tool that records when it's invoked
async def tool_invoke(**kwargs):
call_log.append(("tool_invoke",))
return "tool result"
tool_impl = MagicMock()
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
tools = {
"search": Tool(
name="search",
description="Search",
implementation=tool_impl,
arguments=[Argument(name="query", type="string", description="q")],
config={},
),
}
agent = AgentManager(tools=tools)
# Mock reason() to return an Action
action = Action(thought="thinking", name="search", arguments={"query": "test"}, observation="")
agent.reason = AsyncMock(return_value=action)
think = AsyncMock()
observe = AsyncMock()
context = MagicMock()
await agent.react(
question="test",
history=[],
think=think,
observe=observe,
context=context,
on_action=fake_on_action,
)
# on_action should fire before tool_invoke
assert len(call_log) == 2
assert call_log[0] == ("on_action", "search")
assert call_log[1] == ("tool_invoke",)
@pytest.mark.asyncio
async def test_on_action_not_called_for_final(self):
"""on_action does not fire when react() returns Final."""
called = []
async def fake_on_action(act):
called.append(act)
agent = AgentManager(tools={})
agent.reason = AsyncMock(
return_value=Final(thought="done", final="answer")
)
think = AsyncMock()
observe = AsyncMock()
context = MagicMock()
result = await agent.react(
question="test",
history=[],
think=think,
observe=observe,
context=context,
on_action=fake_on_action,
)
assert isinstance(result, Final)
assert len(called) == 0
@pytest.mark.asyncio
async def test_on_action_none_accepted(self):
"""react() works fine when on_action is None (default)."""
async def tool_invoke(**kwargs):
return "result"
tool_impl = MagicMock()
tool_impl.return_value.invoke = AsyncMock(side_effect=tool_invoke)
tools = {
"search": Tool(
name="search",
description="Search",
implementation=tool_impl,
arguments=[],
config={},
),
}
agent = AgentManager(tools=tools)
agent.reason = AsyncMock(
return_value=Action(thought="t", name="search", arguments={}, observation="")
)
think = AsyncMock()
observe = AsyncMock()
context = MagicMock()
result = await agent.react(
question="test",
history=[],
think=think,
observe=observe,
context=context,
# on_action not passed — defaults to None
)
assert isinstance(result, Action)
assert result.observation == "result"

View file

@ -0,0 +1,74 @@
"""
Tests that _parse_chunk propagates message_id from wire format
to AgentThought, AgentObservation, and AgentAnswer.
"""
import pytest
from trustgraph.api.socket_client import SocketClient
from trustgraph.api.types import AgentThought, AgentObservation, AgentAnswer
@pytest.fixture
def client():
# We only need _parse_chunk — don't connect
c = object.__new__(SocketClient)
return c
class TestParseChunkMessageId:
def test_thought_message_id(self, client):
resp = {
"chunk_type": "thought",
"content": "thinking...",
"end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought",
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentThought)
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/thought"
def test_observation_message_id(self, client):
resp = {
"chunk_type": "observation",
"content": "result",
"end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation",
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentObservation)
assert chunk.message_id == "urn:trustgraph:agent:sess/i1/observation"
def test_answer_message_id(self, client):
resp = {
"chunk_type": "answer",
"content": "the answer",
"end_of_message": False,
"end_of_dialog": False,
"message_id": "urn:trustgraph:agent:sess/final",
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentAnswer)
assert chunk.message_id == "urn:trustgraph:agent:sess/final"
def test_thought_missing_message_id(self, client):
resp = {
"chunk_type": "thought",
"content": "thinking...",
"end_of_message": False,
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentThought)
assert chunk.message_id == ""
def test_answer_missing_message_id(self, client):
resp = {
"chunk_type": "answer",
"content": "answer",
"end_of_message": True,
"end_of_dialog": True,
}
chunk = client._parse_chunk(resp)
assert isinstance(chunk, AgentAnswer)
assert chunk.message_id == ""

View file

@ -12,6 +12,7 @@ from trustgraph.provenance.agent import (
agent_iteration_triples,
agent_observation_triples,
agent_final_triples,
agent_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
@ -21,7 +22,7 @@ from trustgraph.provenance.namespaces import (
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_TOOL_USE,
TG_TOOL_USE, TG_SYNTHESIS,
TG_AGENT_QUESTION,
)
@ -105,6 +106,25 @@ class TestAgentSessionTriples:
)
assert len(triples) == 6
def test_session_parent_uri(self):
"""Subagent sessions derive from a parent entity (e.g. Decomposition)."""
parent = "urn:trustgraph:agent:parent/decompose"
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z",
parent_uri=parent,
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
assert derived is not None
assert derived.o.iri == parent
def test_session_no_parent_uri(self):
"""Top-level sessions have no wasDerivedFrom."""
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SESSION_URI)
assert derived is None
# ---------------------------------------------------------------------------
# agent_iteration_triples
@ -358,3 +378,59 @@ class TestAgentFinalTriples:
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is None
# ---------------------------------------------------------------------------
# agent_synthesis_triples
# ---------------------------------------------------------------------------
class TestAgentSynthesisTriples:
SYNTH_URI = "urn:trustgraph:agent:test-session/synthesis"
FINDING_0 = "urn:trustgraph:agent:test-session/finding/0"
FINDING_1 = "urn:trustgraph:agent:test-session/finding/1"
FINDING_2 = "urn:trustgraph:agent:test-session/finding/2"
def test_synthesis_types(self):
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
assert has_type(triples, self.SYNTH_URI, PROV_ENTITY)
assert has_type(triples, self.SYNTH_URI, TG_SYNTHESIS)
assert has_type(triples, self.SYNTH_URI, TG_ANSWER_TYPE)
def test_synthesis_single_parent_string(self):
"""Single parent passed as string."""
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
assert len(derived) == 1
assert derived[0].o.iri == self.FINDING_0
def test_synthesis_multiple_parents(self):
"""Multiple parents for supervisor fan-in."""
parents = [self.FINDING_0, self.FINDING_1, self.FINDING_2]
triples = agent_synthesis_triples(self.SYNTH_URI, parents)
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
assert len(derived) == 3
derived_uris = {t.o.iri for t in derived}
assert derived_uris == set(parents)
def test_synthesis_single_parent_as_list(self):
"""Single parent passed as list."""
triples = agent_synthesis_triples(self.SYNTH_URI, [self.FINDING_0])
derived = find_triples(triples, PROV_WAS_DERIVED_FROM, self.SYNTH_URI)
assert len(derived) == 1
assert derived[0].o.iri == self.FINDING_0
def test_synthesis_document(self):
triples = agent_synthesis_triples(
self.SYNTH_URI, self.FINDING_0,
document_id="urn:doc:synth",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYNTH_URI)
assert doc is not None
assert doc.o.iri == "urn:doc:synth"
def test_synthesis_label(self):
triples = agent_synthesis_triples(self.SYNTH_URI, self.FINDING_0)
label = find_triple(triples, RDFS_LABEL, self.SYNTH_URI)
assert label is not None
assert label.o.value == "Synthesis"

View file

@ -558,3 +558,96 @@ class TestExplainabilityClientDetectSessionType:
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
class TestChainWalkerFollowsSubTraceTerminal:
"""Test that _follow_provenance_chain continues from a sub-trace's
Synthesis to find downstream entities like Observation."""
def test_observation_found_via_subtrace_synthesis(self):
"""
DAG: Question -> Analysis -> GraphRAG Question -> Synthesis -> Observation
The walker should find Analysis, the sub-trace, then follow from
Synthesis to discover Observation.
"""
# Entity triples (s, p, o)
entity_data = {
"urn:agent:q": [
("urn:agent:q", RDF_TYPE, TG_AGENT_QUESTION),
("urn:agent:q", TG_QUERY, "test"),
],
"urn:agent:analysis": [
("urn:agent:analysis", RDF_TYPE, TG_ANALYSIS),
("urn:agent:analysis", PROV_WAS_DERIVED_FROM, "urn:agent:q"),
],
"urn:graphrag:q": [
("urn:graphrag:q", RDF_TYPE, TG_QUESTION),
("urn:graphrag:q", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:graphrag:q", TG_QUERY, "test"),
("urn:graphrag:q", PROV_WAS_DERIVED_FROM, "urn:agent:analysis"),
],
"urn:graphrag:synth": [
("urn:graphrag:synth", RDF_TYPE, TG_SYNTHESIS),
("urn:graphrag:synth", PROV_WAS_DERIVED_FROM, "urn:graphrag:q"),
],
"urn:agent:obs": [
("urn:agent:obs", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:agent:obs", PROV_WAS_DERIVED_FROM, "urn:graphrag:synth"),
],
"urn:agent:conclusion": [
("urn:agent:conclusion", RDF_TYPE, TG_CONCLUSION),
("urn:agent:conclusion", PROV_WAS_DERIVED_FROM, "urn:agent:obs"),
],
}
# Build a mock flow that answers triples queries
# Query by s= returns that entity's triples
# Query by p=wasDerivedFrom, o=X returns entities derived from X
def mock_triples_query(s=None, p=None, o=None, **kwargs):
if s and not p:
# Fetch entity triples
tuples = entity_data.get(s, [])
return _make_wire_triples(tuples)
elif p == PROV_WAS_DERIVED_FROM and o:
# Find entities derived from o
results = []
for uri, tuples in entity_data.items():
for _, pred, obj in tuples:
if pred == PROV_WAS_DERIVED_FROM and obj == o:
results.append((uri, pred, obj))
return _make_wire_triples(results)
return []
mock_flow = MagicMock()
mock_flow.triples_query.side_effect = mock_triples_query
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
# Mock fetch_graphrag_trace to return a trace with a synthesis
synth_entity = Synthesis(uri="urn:graphrag:synth", entity_type="synthesis")
client.fetch_graphrag_trace = MagicMock(return_value={
"question": Question(uri="urn:graphrag:q", entity_type="question",
question_type="graph-rag"),
"synthesis": synth_entity,
})
trace = client.fetch_agent_trace(
"urn:agent:q",
graph="urn:graph:retrieval",
)
# Should have found all steps
step_types = [
type(s).__name__ if not isinstance(s, dict) else s.get("type")
for s in trace["steps"]
]
assert "Analysis" in step_types, f"Missing Analysis in {step_types}"
assert "sub-trace" in step_types, f"Missing sub-trace in {step_types}"
assert "Observation" in step_types, f"Missing Observation in {step_types}"
assert "Conclusion" in step_types, f"Missing Conclusion in {step_types}"
# Observation should come after the sub-trace
subtrace_idx = step_types.index("sub-trace")
obs_idx = step_types.index("Observation")
assert obs_idx > subtrace_idx, "Observation should appear after sub-trace"