Add agent explainability instrumentation and unify envelope field naming (#795)

Addresses recommendations from the UX developer's agent experience report.
Adds provenance predicates, DAG structure changes, error resilience, and
a published OWL ontology.

Explainability additions:

- Tool candidates: tg:toolCandidate on Analysis events lists the tools
  visible to the LLM for each iteration (names only, descriptions in config)
- Termination reason: tg:terminationReason on Conclusion/Synthesis events
  (final-answer, plan-complete, subagents-complete)
- Step counter: tg:stepNumber on iteration events
- Pattern decision: new tg:PatternDecision entity in the DAG between
  session and first iteration, carrying tg:pattern and tg:taskType
- Latency: tg:llmDurationMs on Analysis events, tg:toolDurationMs on
  Observation events
- Token counts on events: tg:inToken/tg:outToken/tg:llmModel on
  Grounding, Focus, Synthesis, and Analysis events
- Tool/parse errors: tg:toolError on Observation events with tg:Error
  mixin type. Parse failures return as error observations instead of
  crashing the agent, giving it a chance to retry.

Envelope unification:

- Rename chunk_type to message_type across AgentResponse schema,
  translator, SDK types, socket clients, CLI, and all tests.
  Agent and RAG services now both use message_type on the wire.

Ontology:

- specs/ontology/trustgraph.ttl — OWL vocabulary covering all 26 classes,
  7 object properties, and 36+ datatype properties including new predicates.

DAG structure tests:

- tests/unit/test_provenance/test_dag_structure.py verifies the
  wasDerivedFrom chain for GraphRAG, DocumentRAG, and all three agent
  patterns (react, plan, supervisor) including the pattern-decision link.
This commit is contained in:
cybermaggedon 2026-04-13 16:16:42 +01:00 committed by GitHub
parent 14e49d83c7
commit d2751553a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 1577 additions and 205 deletions

View file

@ -131,21 +131,21 @@ async def analyse(path, url, flow, user, collection):
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
resp = msg.get("response", {}) resp = msg.get("response", {})
chunk_type = resp.get("chunk_type", "?") message_type = resp.get("message_type", "?")
if chunk_type == "explain": if message_type == "explain":
explain_id = resp.get("explain_id", "") explain_id = resp.get("explain_id", "")
explain_ids.append(explain_id) explain_ids.append(explain_id)
print(f" {i:3d} {chunk_type} {explain_id}") print(f" {i:3d} {message_type} {explain_id}")
else: else:
print(f" {i:3d} {chunk_type}") print(f" {i:3d} {message_type}")
# Rule 7: message_id on content chunks # Rule 7: message_id on content chunks
if chunk_type in ("thought", "observation", "answer"): if message_type in ("thought", "observation", "answer"):
mid = resp.get("message_id", "") mid = resp.get("message_id", "")
if not mid: if not mid:
errors.append( errors.append(
f"[msg {i}] {chunk_type} chunk missing message_id" f"[msg {i}] {message_type} chunk missing message_id"
) )
print() print()

View file

@ -0,0 +1,415 @@
@prefix tg: <https://trustgraph.ai/ns/> .
@prefix owl: <http://www.w3.org/2002/07/owl#> .
@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
@prefix prov: <http://www.w3.org/ns/prov#> .
# =============================================================================
# Ontology declaration
# =============================================================================
<https://trustgraph.ai/ns/>
a owl:Ontology ;
rdfs:label "TrustGraph Ontology" ;
rdfs:comment "Vocabulary for TrustGraph provenance, extraction metadata, and explainability." ;
owl:versionInfo "2.3" .
# =============================================================================
# Classes — Extraction provenance
# =============================================================================
tg:Document a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Document" ;
rdfs:comment "A loaded document (PDF, text, etc.)." .
tg:Page a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Page" ;
rdfs:comment "A page within a document." .
tg:Section a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Section" ;
rdfs:comment "A structural section within a document." .
tg:Chunk a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Chunk" ;
rdfs:comment "A text chunk produced by the chunker." .
tg:Image a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Image" ;
rdfs:comment "An image extracted from a document." .
tg:Subgraph a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Subgraph" ;
rdfs:comment "A set of triples extracted from a chunk." .
# =============================================================================
# Classes — Query-time explainability (shared)
# =============================================================================
tg:Question a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Question" ;
rdfs:comment "Root entity for a query session." .
tg:GraphRagQuestion a owl:Class ;
rdfs:subClassOf tg:Question ;
rdfs:label "Graph RAG Question" ;
rdfs:comment "A question answered via graph-based RAG." .
tg:DocRagQuestion a owl:Class ;
rdfs:subClassOf tg:Question ;
rdfs:label "Document RAG Question" ;
rdfs:comment "A question answered via document-based RAG." .
tg:AgentQuestion a owl:Class ;
rdfs:subClassOf tg:Question ;
rdfs:label "Agent Question" ;
rdfs:comment "A question answered via the agent orchestrator." .
tg:Grounding a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Grounding" ;
rdfs:comment "Concept extraction step (query decomposition into search terms)." .
tg:Exploration a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Exploration" ;
rdfs:comment "Entity/chunk retrieval step." .
tg:Focus a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Focus" ;
rdfs:comment "Edge selection and scoring step (GraphRAG)." .
tg:Synthesis a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Synthesis" ;
rdfs:comment "Final answer synthesis from retrieved context." .
# =============================================================================
# Classes — Agent provenance
# =============================================================================
tg:Analysis a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Analysis" ;
rdfs:comment "One agent iteration: reasoning followed by tool selection." .
tg:ToolUse a owl:Class ;
rdfs:label "ToolUse" ;
rdfs:comment "Mixin type applied to Analysis when a tool is invoked." .
tg:Error a owl:Class ;
rdfs:label "Error" ;
rdfs:comment "Mixin type applied to events where a failure occurred (tool error, parse error)." .
tg:Conclusion a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Conclusion" ;
rdfs:comment "Agent final answer (ReAct pattern)." .
tg:PatternDecision a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Pattern Decision" ;
rdfs:comment "Meta-router decision recording which execution pattern was selected." .
# --- Unifying types ---
tg:Answer a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Answer" ;
rdfs:comment "Unifying type for any terminal answer (Synthesis, Conclusion, Finding, StepResult)." .
tg:Reflection a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Reflection" ;
rdfs:comment "Unifying type for intermediate commentary (Thought, Observation)." .
tg:Thought a owl:Class ;
rdfs:subClassOf tg:Reflection ;
rdfs:label "Thought" ;
rdfs:comment "Agent reasoning text within an iteration." .
tg:Observation a owl:Class ;
rdfs:subClassOf tg:Reflection ;
rdfs:label "Observation" ;
rdfs:comment "Tool execution result." .
# --- Orchestrator types ---
tg:Decomposition a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Decomposition" ;
rdfs:comment "Supervisor pattern: question decomposed into sub-goals." .
tg:Finding a owl:Class ;
rdfs:subClassOf tg:Answer ;
rdfs:label "Finding" ;
rdfs:comment "Result from a sub-agent execution." .
tg:Plan a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Plan" ;
rdfs:comment "Plan-then-execute pattern: structured plan of steps." .
tg:StepResult a owl:Class ;
rdfs:subClassOf tg:Answer ;
rdfs:label "Step Result" ;
rdfs:comment "Result from executing one plan step." .
# =============================================================================
# Properties — Extraction metadata
# =============================================================================
tg:contains a owl:ObjectProperty ;
rdfs:label "contains" ;
rdfs:comment "Links a parent entity to a child (e.g. Document contains Page, Subgraph contains triple)." .
tg:pageCount a owl:DatatypeProperty ;
rdfs:label "page count" ;
rdfs:range xsd:integer ;
rdfs:domain tg:Document .
tg:mimeType a owl:DatatypeProperty ;
rdfs:label "MIME type" ;
rdfs:range xsd:string ;
rdfs:domain tg:Document .
tg:pageNumber a owl:DatatypeProperty ;
rdfs:label "page number" ;
rdfs:range xsd:integer ;
rdfs:domain tg:Page .
tg:chunkIndex a owl:DatatypeProperty ;
rdfs:label "chunk index" ;
rdfs:range xsd:integer ;
rdfs:domain tg:Chunk .
tg:charOffset a owl:DatatypeProperty ;
rdfs:label "character offset" ;
rdfs:range xsd:integer .
tg:charLength a owl:DatatypeProperty ;
rdfs:label "character length" ;
rdfs:range xsd:integer .
tg:chunkSize a owl:DatatypeProperty ;
rdfs:label "chunk size" ;
rdfs:range xsd:integer .
tg:chunkOverlap a owl:DatatypeProperty ;
rdfs:label "chunk overlap" ;
rdfs:range xsd:integer .
tg:componentVersion a owl:DatatypeProperty ;
rdfs:label "component version" ;
rdfs:range xsd:string .
tg:llmModel a owl:DatatypeProperty ;
rdfs:label "LLM model" ;
rdfs:range xsd:string .
tg:ontology a owl:DatatypeProperty ;
rdfs:label "ontology" ;
rdfs:range xsd:string .
tg:embeddingModel a owl:DatatypeProperty ;
rdfs:label "embedding model" ;
rdfs:range xsd:string .
tg:sourceText a owl:DatatypeProperty ;
rdfs:label "source text" ;
rdfs:range xsd:string .
tg:sourceCharOffset a owl:DatatypeProperty ;
rdfs:label "source character offset" ;
rdfs:range xsd:integer .
tg:sourceCharLength a owl:DatatypeProperty ;
rdfs:label "source character length" ;
rdfs:range xsd:integer .
tg:elementTypes a owl:DatatypeProperty ;
rdfs:label "element types" ;
rdfs:range xsd:string .
tg:tableCount a owl:DatatypeProperty ;
rdfs:label "table count" ;
rdfs:range xsd:integer .
tg:imageCount a owl:DatatypeProperty ;
rdfs:label "image count" ;
rdfs:range xsd:integer .
# =============================================================================
# Properties — Query-time provenance (GraphRAG / DocumentRAG)
# =============================================================================
tg:query a owl:DatatypeProperty ;
rdfs:label "query" ;
rdfs:comment "The user's query text." ;
rdfs:range xsd:string ;
rdfs:domain tg:Question .
tg:concept a owl:DatatypeProperty ;
rdfs:label "concept" ;
rdfs:comment "An extracted concept from the query." ;
rdfs:range xsd:string ;
rdfs:domain tg:Grounding .
tg:entity a owl:ObjectProperty ;
rdfs:label "entity" ;
rdfs:comment "A seed entity retrieved during exploration." ;
rdfs:domain tg:Exploration .
tg:edgeCount a owl:DatatypeProperty ;
rdfs:label "edge count" ;
rdfs:comment "Number of edges explored." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Exploration .
tg:selectedEdge a owl:ObjectProperty ;
rdfs:label "selected edge" ;
rdfs:comment "Link to an edge selection entity within a Focus event." ;
rdfs:domain tg:Focus .
tg:edge a owl:ObjectProperty ;
rdfs:label "edge" ;
rdfs:comment "A quoted triple representing a knowledge graph edge." ;
rdfs:domain tg:Focus .
tg:reasoning a owl:DatatypeProperty ;
rdfs:label "reasoning" ;
rdfs:comment "LLM-generated reasoning for an edge selection." ;
rdfs:range xsd:string .
tg:document a owl:ObjectProperty ;
rdfs:label "document" ;
rdfs:comment "Reference to a document stored in the librarian." .
tg:chunkCount a owl:DatatypeProperty ;
rdfs:label "chunk count" ;
rdfs:comment "Number of document chunks retrieved (DocumentRAG)." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Exploration .
tg:selectedChunk a owl:DatatypeProperty ;
rdfs:label "selected chunk" ;
rdfs:comment "A selected chunk ID (DocumentRAG)." ;
rdfs:range xsd:string ;
rdfs:domain tg:Exploration .
# =============================================================================
# Properties — Agent provenance
# =============================================================================
tg:thought a owl:ObjectProperty ;
rdfs:label "thought" ;
rdfs:comment "Links an Analysis iteration to its Thought sub-entity." ;
rdfs:domain tg:Analysis ;
rdfs:range tg:Thought .
tg:action a owl:DatatypeProperty ;
rdfs:label "action" ;
rdfs:comment "The tool/action name selected by the agent." ;
rdfs:range xsd:string ;
rdfs:domain tg:Analysis .
tg:arguments a owl:DatatypeProperty ;
rdfs:label "arguments" ;
rdfs:comment "JSON-encoded arguments passed to the tool." ;
rdfs:range xsd:string ;
rdfs:domain tg:Analysis .
tg:observation a owl:ObjectProperty ;
rdfs:label "observation" ;
rdfs:comment "Links an Analysis iteration to its Observation sub-entity." ;
rdfs:domain tg:Analysis ;
rdfs:range tg:Observation .
tg:toolCandidate a owl:DatatypeProperty ;
rdfs:label "tool candidate" ;
rdfs:comment "Name of a tool available to the LLM for this iteration. One triple per candidate." ;
rdfs:range xsd:string ;
rdfs:domain tg:Analysis .
tg:stepNumber a owl:DatatypeProperty ;
rdfs:label "step number" ;
rdfs:comment "Explicit 1-based step counter for iteration events." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Analysis .
tg:terminationReason a owl:DatatypeProperty ;
rdfs:label "termination reason" ;
rdfs:comment "Why the agent loop stopped: final-answer, plan-complete, subagents-complete, max-iterations, error." ;
rdfs:range xsd:string .
tg:pattern a owl:DatatypeProperty ;
rdfs:label "pattern" ;
rdfs:comment "Selected execution pattern (react, plan-then-execute, supervisor)." ;
rdfs:range xsd:string ;
rdfs:domain tg:PatternDecision .
tg:taskType a owl:DatatypeProperty ;
rdfs:label "task type" ;
rdfs:comment "Identified task type from the meta-router (general, research, etc.)." ;
rdfs:range xsd:string ;
rdfs:domain tg:PatternDecision .
tg:llmDurationMs a owl:DatatypeProperty ;
rdfs:label "LLM duration (ms)" ;
rdfs:comment "Time spent in the LLM prompt call, in milliseconds." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Analysis .
tg:toolDurationMs a owl:DatatypeProperty ;
rdfs:label "tool duration (ms)" ;
rdfs:comment "Time spent executing the tool, in milliseconds." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Observation .
tg:toolError a owl:DatatypeProperty ;
rdfs:label "tool error" ;
rdfs:comment "Error message from a failed tool execution." ;
rdfs:range xsd:string ;
rdfs:domain tg:Observation .
# --- Token usage predicates (on any event that involves an LLM call) ---
tg:inToken a owl:DatatypeProperty ;
rdfs:label "input tokens" ;
rdfs:comment "Input token count for the LLM call associated with this event." ;
rdfs:range xsd:integer .
tg:outToken a owl:DatatypeProperty ;
rdfs:label "output tokens" ;
rdfs:comment "Output token count for the LLM call associated with this event." ;
rdfs:range xsd:integer .
# --- Orchestrator predicates ---
tg:subagentGoal a owl:DatatypeProperty ;
rdfs:label "sub-agent goal" ;
rdfs:comment "Goal string assigned to a sub-agent (Decomposition, Finding)." ;
rdfs:range xsd:string .
tg:planStep a owl:DatatypeProperty ;
rdfs:label "plan step" ;
rdfs:comment "Goal string for a plan step (Plan, StepResult)." ;
rdfs:range xsd:string .
# =============================================================================
# Named graphs
# =============================================================================
# These are not OWL classes but documented here for reference:
#
# (default graph) — Core knowledge facts (extracted triples)
# urn:graph:source — Extraction provenance (document → chunk → triple)
# urn:graph:retrieval — Query-time explainability (question → exploration → synthesis)

View file

@ -87,7 +87,7 @@ def sample_message_data():
"history": [] "history": []
}, },
"AgentResponse": { "AgentResponse": {
"chunk_type": "answer", "message_type": "answer",
"content": "Machine learning is a subset of AI.", "content": "Machine learning is a subset of AI.",
"end_of_message": True, "end_of_message": True,
"end_of_dialog": True, "end_of_dialog": True,

View file

@ -212,7 +212,7 @@ class TestAgentMessageContracts:
# Test required fields # Test required fields
response = AgentResponse(**response_data) response = AgentResponse(**response_data)
assert hasattr(response, 'chunk_type') assert hasattr(response, 'message_type')
assert hasattr(response, 'content') assert hasattr(response, 'content')
assert hasattr(response, 'end_of_message') assert hasattr(response, 'end_of_message')
assert hasattr(response, 'end_of_dialog') assert hasattr(response, 'end_of_dialog')

View file

@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange # Arrange
translator = TranslatorRegistry.get_response_translator("agent") translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse( response = AgentResponse(
chunk_type="answer", message_type="answer",
content="4", content="4",
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,
@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange # Arrange
translator = TranslatorRegistry.get_response_translator("agent") translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse( response = AgentResponse(
chunk_type="thought", message_type="thought",
content="I need to solve this.", content="I need to solve this.",
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags:
# Test thought message # Test thought message
thought_response = AgentResponse( thought_response = AgentResponse(
chunk_type="thought", message_type="thought",
content="Processing...", content="Processing...",
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags:
# Test observation message # Test observation message
observation_response = AgentResponse( observation_response = AgentResponse(
chunk_type="observation", message_type="observation",
content="Result found", content="Result found",
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags:
# Streaming format with end_of_dialog=True # Streaming format with end_of_dialog=True
response = AgentResponse( response = AgentResponse(
chunk_type="answer", message_type="answer",
content="", content="",
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,

View file

@ -418,55 +418,55 @@ def sample_streaming_agent_response():
"""Sample streaming agent response chunks""" """Sample streaming agent response chunks"""
return [ return [
{ {
"chunk_type": "thought", "message_type": "thought",
"content": "I need to search", "content": "I need to search",
"end_of_message": False, "end_of_message": False,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "thought", "message_type": "thought",
"content": " for information", "content": " for information",
"end_of_message": False, "end_of_message": False,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "thought", "message_type": "thought",
"content": " about machine learning.", "content": " about machine learning.",
"end_of_message": True, "end_of_message": True,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "action", "message_type": "action",
"content": "knowledge_query", "content": "knowledge_query",
"end_of_message": True, "end_of_message": True,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "observation", "message_type": "observation",
"content": "Machine learning is", "content": "Machine learning is",
"end_of_message": False, "end_of_message": False,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "observation", "message_type": "observation",
"content": " a subset of AI.", "content": " a subset of AI.",
"end_of_message": True, "end_of_message": True,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "final-answer", "message_type": "final-answer",
"content": "Machine learning", "content": "Machine learning",
"end_of_message": False, "end_of_message": False,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "final-answer", "message_type": "final-answer",
"content": " is a subset", "content": " is a subset",
"end_of_message": False, "end_of_message": False,
"end_of_dialog": False "end_of_dialog": False
}, },
{ {
"chunk_type": "final-answer", "message_type": "final-answer",
"content": " of artificial intelligence.", "content": " of artificial intelligence.",
"end_of_message": True, "end_of_message": True,
"end_of_dialog": True "end_of_dialog": True
@ -494,10 +494,10 @@ def streaming_chunk_collector():
"""Concatenate all chunk content""" """Concatenate all chunk content"""
return "".join(self.chunks) return "".join(self.chunks)
def get_chunk_types(self): def get_message_types(self):
"""Get list of chunk types if chunks are dicts""" """Get list of chunk types if chunks are dicts"""
if self.chunks and isinstance(self.chunks[0], dict): if self.chunks and isinstance(self.chunks[0], dict):
return [c.get("chunk_type") for c in self.chunks] return [c.get("message_type") for c in self.chunks]
return [] return []
def verify_streaming_protocol(self): def verify_streaming_protocol(self):

View file

@ -327,11 +327,13 @@ Args: {
think_callback = AsyncMock() think_callback = AsyncMock()
observe_callback = AsyncMock() observe_callback = AsyncMock()
# Act & Assert # Act - tool errors are now caught and returned as observations
with pytest.raises(Exception) as exc_info: result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
# Assert - error captured on the action, not raised
assert "Tool execution failed" in str(exc_info.value) assert result.tool_error is not None
assert "Tool execution failed" in result.tool_error
assert "Error:" in result.observation
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context): async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
@ -538,12 +540,11 @@ Args: {
) )
if test_case["error_contains"]: if test_case["error_contains"]:
# Should raise an error # Parse errors now return an Action with tool_error
with pytest.raises(RuntimeError) as exc_info: result = await agent_manager.reason("test question", [], mock_flow_context)
await agent_manager.reason("test question", [], mock_flow_context) assert isinstance(result, Action)
assert result.name == "__parse_error__"
assert "Failed to parse agent response" in str(exc_info.value) assert result.tool_error is not None
assert test_case["error_contains"] in str(exc_info.value)
else: else:
# Should succeed # Should succeed
action = await agent_manager.reason("test question", [], mock_flow_context) action = await agent_manager.reason("test question", [], mock_flow_context)

View file

@ -15,7 +15,7 @@ from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks, assert_agent_streaming_chunks,
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
assert_callback_invoked, assert_callback_invoked,
assert_chunk_types_valid, assert_message_types_valid,
) )

View file

@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now # Filter out explain events — those are always sent now
content_responses = [ content_responses = [
r for r in sent_responses if r.chunk_type != "explain" r for r in sent_responses if r.message_type != "explain"
] ]
explain_responses = [ explain_responses = [
r for r in sent_responses if r.chunk_type == "explain" r for r in sent_responses if r.message_type == "explain"
] ]
# Should have explain events for session, iteration, observation, and final # Should have explain events for session, iteration, observation, and final
@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming:
# Check thought message # Check thought message
thought_response = content_responses[0] thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse) assert isinstance(thought_response, AgentResponse)
assert thought_response.chunk_type == "thought" assert thought_response.message_type == "thought"
assert thought_response.content == "I need to solve this." assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True" assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False" assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming:
# Check observation message # Check observation message
observation_response = content_responses[1] observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse) assert isinstance(observation_response, AgentResponse)
assert observation_response.chunk_type == "observation" assert observation_response.message_type == "observation"
assert observation_response.content == "The answer is 4." assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True" assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False" assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now # Filter out explain events — those are always sent now
content_responses = [ content_responses = [
r for r in sent_responses if r.chunk_type != "explain" r for r in sent_responses if r.message_type != "explain"
] ]
explain_responses = [ explain_responses = [
r for r in sent_responses if r.chunk_type == "explain" r for r in sent_responses if r.message_type == "explain"
] ]
# Should have explain events for session and final # Should have explain events for session and final
@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming:
# Check final answer message # Check final answer message
answer_response = content_responses[0] answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse) assert isinstance(answer_response, AgentResponse)
assert answer_response.chunk_type == "answer" assert answer_response.message_type == "answer"
assert answer_response.content == "4" assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True" assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True" assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"

View file

@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
assert len(responses) == 1 assert len(responses) == 1
assert responses[0].message_id == msg_id assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "thought" assert responses[0].message_type == "thought"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern): async def test_non_streaming_think_has_message_id(self, pattern):
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
await observe("result", is_final=True) await observe("result", is_final=True)
assert responses[0].message_id == msg_id assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "observation" assert responses[0].message_type == "observation"
class TestAnswerCallbackMessageId: class TestAnswerCallbackMessageId:
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
await answer("the answer") await answer("the answer")
assert responses[0].message_id == msg_id assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "answer" assert responses[0].message_type == "answer"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_message_id_default(self, pattern): async def test_no_message_id_default(self, pattern):

View file

@ -69,7 +69,7 @@ def collect_explain_events(respond_mock):
events = [] events = []
for call in respond_mock.call_args_list: for call in respond_mock.call_args_list:
resp = call[0][0] resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain": if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({ events.append({
"explain_id": resp.explain_id, "explain_id": resp.explain_id,
"explain_graph": resp.explain_graph, "explain_graph": resp.explain_graph,

View file

@ -20,7 +20,7 @@ class TestParseChunkMessageId:
def test_thought_message_id(self, client): def test_thought_message_id(self, client):
resp = { resp = {
"chunk_type": "thought", "message_type": "thought",
"content": "thinking...", "content": "thinking...",
"end_of_message": False, "end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought", "message_id": "urn:trustgraph:agent:sess/i1/thought",
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
def test_observation_message_id(self, client): def test_observation_message_id(self, client):
resp = { resp = {
"chunk_type": "observation", "message_type": "observation",
"content": "result", "content": "result",
"end_of_message": True, "end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation", "message_id": "urn:trustgraph:agent:sess/i1/observation",
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
def test_answer_message_id(self, client): def test_answer_message_id(self, client):
resp = { resp = {
"chunk_type": "answer", "message_type": "answer",
"content": "the answer", "content": "the answer",
"end_of_message": False, "end_of_message": False,
"end_of_dialog": False, "end_of_dialog": False,
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
def test_thought_missing_message_id(self, client): def test_thought_missing_message_id(self, client):
resp = { resp = {
"chunk_type": "thought", "message_type": "thought",
"content": "thinking...", "content": "thinking...",
"end_of_message": False, "end_of_message": False,
} }
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
def test_answer_missing_message_id(self, client): def test_answer_missing_message_id(self, client):
resp = { resp = {
"chunk_type": "answer", "message_type": "answer",
"content": "answer", "content": "answer",
"end_of_message": True, "end_of_message": True,
"end_of_dialog": True, "end_of_dialog": True,

View file

@ -158,7 +158,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator() translator = AgentResponseTranslator()
response = AgentResponse( response = AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id="urn:trustgraph:agent:session:abc123", explain_id="urn:trustgraph:agent:session:abc123",
explain_graph="urn:graph:retrieval", explain_graph="urn:graph:retrieval",
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator() translator = AgentResponseTranslator()
response = AgentResponse( response = AgentResponse(
chunk_type="thought", message_type="thought",
content="I need to think...", content="I need to think...",
) )
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator() translator = AgentResponseTranslator()
response = AgentResponse( response = AgentResponse(
chunk_type="explain", message_type="explain",
explain_id="urn:trustgraph:agent:session:abc123", explain_id="urn:trustgraph:agent:session:abc123",
explain_triples=sample_triples(), explain_triples=sample_triples(),
end_of_dialog=False, end_of_dialog=False,
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator() translator = AgentResponseTranslator()
response = AgentResponse( response = AgentResponse(
chunk_type="answer", message_type="answer",
content="The answer is...", content="The answer is...",
end_of_dialog=True, end_of_dialog=True,
) )

View file

@ -0,0 +1,592 @@
"""
DAG structure tests for provenance chains.
Verifies that the wasDerivedFrom chain has the expected shape for each
service. These tests catch structural regressions when new entities are
inserted into the chain (e.g. PatternDecision between session and first
iteration).
Expected chains:

GraphRAG:     question → grounding → exploration → focus → synthesis
DocumentRAG:  question → grounding → exploration → synthesis
Agent React:  session → pattern-decision → iteration → (observation → iteration)* → final
Agent Plan:   session → pattern-decision → plan → step-result(s) → synthesis
Agent Super:  session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis
"""
import json
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
TG_OBSERVATION_TYPE,
TG_PATTERN, TG_TASK_TYPE,
)
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_events(events):
    """Index explain events by explain_id → {types, derived_from, triples}."""
    dag = {}
    for event in events:
        entity_id = event["explain_id"]
        event_triples = event["triples"]
        type_iris = set()
        parent_iris = []
        # Single pass over the triples: pick out rdf:type and
        # prov:wasDerivedFrom statements about this entity.
        for triple in event_triples:
            if triple.s.iri != entity_id:
                continue
            if triple.p.iri == RDF_TYPE:
                type_iris.add(triple.o.iri)
            elif triple.p.iri == PROV_WAS_DERIVED_FROM:
                parent_iris.append(triple.o.iri)
        dag[entity_id] = {
            "types": type_iris,
            # Only the first parent is tracked — the chains under test
            # are linear.
            "derived_from": parent_iris[0] if parent_iris else None,
            "triples": event_triples,
        }
    return dag
def _find_by_type(dag, rdf_type):
"""Find all event IDs that have the given rdf:type."""
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
def _assert_chain(dag, chain_types):
    """Assert a linear wasDerivedFrom chain runs through *chain_types* in order."""
    # Walk consecutive (parent, child) type pairs along the chain.
    for parent_type, child_type in zip(chain_types, chain_types[1:]):
        parents = _find_by_type(dag, parent_type)
        children = _find_by_type(dag, child_type)
        assert parents, f"No entity with type {parent_type}"
        assert children, f"No entity with type {child_type}"
        # At least one child must derive from at least one parent.
        linked = any(
            dag[child_id]["derived_from"] in parents
            for child_id in children
        )
        assert linked, (
            f"No {child_type} derives from {parent_type}. "
            f"Children derive from: "
            f"{[dag[c]['derived_from'] for c in children]}"
        )
# ---------------------------------------------------------------------------
# GraphRAG DAG structure
# ---------------------------------------------------------------------------
class TestGraphRagDagStructure:
    """Verify: question → grounding → exploration → focus → synthesis"""

    @pytest.fixture
    def mock_clients(self):
        # Stand-ins for the four GraphRag collaborators; each returns the
        # minimal canned data needed to drive one full retrieval pass.
        prompt_client = AsyncMock()
        embeddings_client = AsyncMock()
        graph_embeddings_client = AsyncMock()
        triples_client = AsyncMock()
        embeddings_client.embed.return_value = [[0.1, 0.2]]
        graph_embeddings_client.query.return_value = [
            MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
        ]
        triples_client.query_stream.return_value = [
            Triple(
                s=Term(type=IRI, iri="http://example.com/e1"),
                p=Term(type=IRI, iri="http://example.com/p"),
                o=Term(type=LITERAL, value="value"),
            )
        ]
        triples_client.query.return_value = []

        async def mock_prompt(template_id, variables=None, **kwargs):
            # Route each prompt template to a fixed response so the RAG
            # pipeline runs end-to-end without an LLM.
            if template_id == "extract-concepts":
                return PromptResult(response_type="text", text="concept")
            elif template_id == "kg-edge-scoring":
                edges = variables.get("knowledge", [])
                return PromptResult(
                    response_type="jsonl",
                    objects=[{"id": e["id"], "score": 10} for e in edges],
                )
            elif template_id == "kg-edge-reasoning":
                edges = variables.get("knowledge", [])
                return PromptResult(
                    response_type="jsonl",
                    objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
                )
            elif template_id == "kg-synthesis":
                return PromptResult(response_type="text", text="Answer.")
            return PromptResult(response_type="text", text="")

        prompt_client.prompt.side_effect = mock_prompt
        return prompt_client, embeddings_client, graph_embeddings_client, triples_client

    @pytest.mark.asyncio
    async def test_dag_chain(self, mock_clients):
        rag = GraphRag(*mock_clients)
        events = []

        async def explain_cb(triples, explain_id):
            # Capture every provenance event emitted during the query.
            events.append({"explain_id": explain_id, "triples": triples})

        await rag.query(
            query="test", explain_callback=explain_cb, edge_score_limit=0,
        )
        dag = _collect_events(events)
        # Exactly one provenance event per stage of the GraphRAG pipeline.
        assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
        _assert_chain(dag, [
            TG_GRAPH_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_FOCUS,
            TG_SYNTHESIS,
        ])
# ---------------------------------------------------------------------------
# DocumentRAG DAG structure
# ---------------------------------------------------------------------------
class TestDocumentRagDagStructure:
    """Verify: question → grounding → exploration → synthesis"""

    @pytest.fixture
    def mock_clients(self):
        from trustgraph.schema import ChunkMatch

        # Stand-ins for the four DocumentRag collaborators.
        prompt_client = AsyncMock()
        embeddings_client = AsyncMock()
        doc_embeddings_client = AsyncMock()
        fetch_chunk = AsyncMock(return_value="Chunk content.")
        embeddings_client.embed.return_value = [[0.1, 0.2]]
        doc_embeddings_client.query.return_value = [
            ChunkMatch(chunk_id="doc/c1", score=0.9),
        ]

        async def mock_prompt(template_id, variables=None, **kwargs):
            # Only concept extraction goes through prompt(); synthesis uses
            # document_prompt() below.
            if template_id == "extract-concepts":
                return PromptResult(response_type="text", text="concept")
            return PromptResult(response_type="text", text="")

        prompt_client.prompt.side_effect = mock_prompt
        prompt_client.document_prompt.return_value = PromptResult(
            response_type="text", text="Answer.",
        )
        return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk

    @pytest.mark.asyncio
    async def test_dag_chain(self, mock_clients):
        rag = DocumentRag(*mock_clients)
        events = []

        async def explain_cb(triples, explain_id):
            # Capture every provenance event emitted during the query.
            events.append({"explain_id": explain_id, "triples": triples})

        await rag.query(
            query="test", explain_callback=explain_cb,
        )
        dag = _collect_events(events)
        # One event per DocumentRAG stage — no Focus stage in this pipeline.
        assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
        _assert_chain(dag, [
            TG_DOC_RAG_QUESTION,
            TG_GROUNDING,
            TG_EXPLORATION,
            TG_SYNTHESIS,
        ])
# ---------------------------------------------------------------------------
# Agent DAG structure — tested via service.agent_request()
# ---------------------------------------------------------------------------
def _make_processor(tools=None):
processor = MagicMock()
processor.max_iterations = 10
processor.save_answer_content = AsyncMock()
def mock_session_uri(sid):
return f"urn:trustgraph:agent:session:{sid}"
processor.provenance_session_uri.side_effect = mock_session_uri
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agent = agent
processor.aggregator = MagicMock()
return processor
def _make_flow():
producers = {}
def factory(name):
if name not in producers:
producers[name] = AsyncMock()
return producers[name]
flow = MagicMock(side_effect=factory)
return flow
def _collect_agent_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"triples": resp.explain_triples,
})
return events
class TestAgentReactDagStructure:
    """
    Via service.agent_request(), full two-iteration react chain:

        session → pattern-decision → iteration(1) → observation(1) → final

    Iteration 1: tool call → observation
    Iteration 2: final answer
    """

    def _make_service(self):
        # Build a Processor without running __init__ (which presumably needs
        # live infrastructure — confirm), wiring in only the attributes the
        # patterns read.
        from trustgraph.agent.orchestrator.service import Processor
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        # Single mock tool that the react loop can invoke.
        mock_tool = MagicMock()
        mock_tool.name = "lookup"
        mock_tool.description = "Look things up"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="42")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = _make_processor(tools={"lookup": mock_tool})

        service = Processor.__new__(Processor)  # bypass __init__
        service.max_iterations = 10
        service.save_answer_content = AsyncMock()
        service.provenance_session_uri = processor.provenance_session_uri
        service.agent = processor.agent
        service.aggregator = processor.aggregator
        service.react_pattern = ReactPattern(service)
        service.plan_pattern = PlanThenExecutePattern(service)
        service.supervisor_pattern = SupervisorPattern(service)
        service.meta_router = None
        return service

    @pytest.mark.asyncio
    async def test_dag_chain(self):
        from trustgraph.agent.react.types import Action, Final

        service = self._make_service()
        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = _make_flow()
        session_id = str(uuid.uuid4())

        # Iteration 1: tool call → returns Action, triggers on_action + tool exec
        action = Action(
            thought="I need to look this up",
            name="lookup",
            arguments={"question": "6x7"},
            observation="",
        )
        with patch(
            "trustgraph.agent.orchestrator.react_pattern.AgentManager"
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react_iter1(on_action=None, **kwargs):
                # Simulate the LLM choosing the tool, then observing "42".
                if on_action:
                    await on_action(action)
                action.observation = "42"
                return action

            mock_am.react.side_effect = mock_react_iter1
            request1 = AgentRequest(
                question="What is 6x7?",
                user="testuser",
                collection="default",
                streaming=False,
                session_id=session_id,
                pattern="react",
                history=[],
            )
            await service.agent_request(request1, respond, next_fn, flow)

        # next_fn should have been called with updated history
        assert next_fn.called

        # Iteration 2: final answer
        final = Final(thought="The answer is 42", final="42")
        next_request = next_fn.call_args[0][0]
        with patch(
            "trustgraph.agent.orchestrator.react_pattern.AgentManager"
        ) as MockAM:
            mock_am = AsyncMock()
            MockAM.return_value = mock_am

            async def mock_react_iter2(**kwargs):
                return final

            mock_am.react.side_effect = mock_react_iter2
            await service.agent_request(next_request, respond, next_fn, flow)

        # Collect and verify DAG
        events = _collect_agent_events(respond)
        dag = _collect_events(events)
        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
        analysis_ids = _find_by_type(dag, TG_ANALYSIS)
        observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
        final_ids = _find_by_type(dag, TG_CONCLUSION)
        assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
        assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
        assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
        assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
        assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"

        # Full chain:
        # session → pattern-decision
        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
        # pattern-decision → iteration(1)
        assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
        # iteration(1) → observation(1)
        assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
        # observation(1) → final
        assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
class TestAgentPlanDagStructure:
    """
    Via service.agent_request():

        session → pattern-decision → plan → step-result → synthesis
    """

    @pytest.mark.asyncio
    async def test_dag_chain(self):
        from trustgraph.agent.orchestrator.service import Processor
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        # Mock tool
        mock_tool = MagicMock()
        mock_tool.name = "knowledge-query"
        mock_tool.description = "Query KB"
        mock_tool.arguments = []
        mock_tool.groups = []
        mock_tool.states = {}
        mock_tool_impl = AsyncMock(return_value="Found it")
        mock_tool.implementation = MagicMock(return_value=mock_tool_impl)

        processor = _make_processor(tools={"knowledge-query": mock_tool})

        # Build a Processor without running __init__, wiring in only the
        # attributes the patterns read.
        service = Processor.__new__(Processor)
        service.max_iterations = 10
        service.save_answer_content = AsyncMock()
        service.provenance_session_uri = processor.provenance_session_uri
        service.agent = processor.agent
        service.aggregator = processor.aggregator
        service.react_pattern = ReactPattern(service)
        service.plan_pattern = PlanThenExecutePattern(service)
        service.supervisor_pattern = SupervisorPattern(service)
        service.meta_router = None

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = _make_flow()

        # Mock prompt client
        mock_prompt_client = AsyncMock()
        call_count = 0

        async def mock_prompt(id, variables=None, **kwargs):
            # Canned responses for each stage of the plan pattern.
            nonlocal call_count
            call_count += 1
            if id == "plan-create":
                return PromptResult(
                    response_type="jsonl",
                    objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
                )
            elif id == "plan-step-execute":
                return PromptResult(
                    response_type="json",
                    object={"tool": "knowledge-query", "arguments": {"question": "test"}},
                )
            elif id == "plan-synthesise":
                return PromptResult(response_type="text", text="Final answer.")
            return PromptResult(response_type="text", text="")

        mock_prompt_client.prompt.side_effect = mock_prompt

        def flow_factory(name):
            # The plan pattern fetches the prompt service via the flow.
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory
        session_id = str(uuid.uuid4())

        # Iteration 1: planning
        request1 = AgentRequest(
            question="Test?",
            user="testuser",
            collection="default",
            streaming=False,
            session_id=session_id,
            pattern="plan-then-execute",
            history=[],
        )
        await service.agent_request(request1, respond, next_fn, flow)

        # Iteration 2: execute step (next_fn was called with updated request)
        assert next_fn.called
        next_request = next_fn.call_args[0][0]

        # Iteration 3: all steps done → synthesis
        # Simulate completed step in history
        next_request.history[-1].plan[0].status = "completed"
        next_request.history[-1].plan[0].result = "Found it"
        await service.agent_request(next_request, respond, next_fn, flow)

        events = _collect_agent_events(respond)
        dag = _collect_events(events)
        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
        plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
        synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
        assert len(session_ids) == 1
        assert len(pd_ids) == 1
        assert len(plan_ids) == 1
        assert len(synthesis_ids) == 1

        # Chain: session → pattern-decision → plan → ... → synthesis
        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
        assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
class TestAgentSupervisorDagStructure:
    """
    Via service.agent_request():

        session → pattern-decision → decomposition → (fan-out)
    """

    @pytest.mark.asyncio
    async def test_dag_chain(self):
        from trustgraph.agent.orchestrator.service import Processor
        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern

        processor = _make_processor()

        # Build a Processor without running __init__, wiring in only the
        # attributes the patterns read.
        service = Processor.__new__(Processor)
        service.max_iterations = 10
        service.save_answer_content = AsyncMock()
        service.provenance_session_uri = processor.provenance_session_uri
        service.agent = processor.agent
        service.aggregator = processor.aggregator
        service.react_pattern = ReactPattern(service)
        service.plan_pattern = PlanThenExecutePattern(service)
        service.supervisor_pattern = SupervisorPattern(service)
        service.meta_router = None

        respond = AsyncMock()
        next_fn = AsyncMock()
        flow = _make_flow()

        # Decomposition prompt yields two goals, so two fan-out requests
        # are expected below.
        mock_prompt_client = AsyncMock()
        mock_prompt_client.prompt.return_value = PromptResult(
            response_type="jsonl",
            objects=["Goal A", "Goal B"],
        )

        def flow_factory(name):
            # The supervisor pattern fetches the prompt service via the flow.
            if name == "prompt-request":
                return mock_prompt_client
            return AsyncMock()

        flow.side_effect = flow_factory

        request = AgentRequest(
            question="Research quantum computing",
            user="testuser",
            collection="default",
            streaming=False,
            session_id=str(uuid.uuid4()),
            pattern="supervisor",
            history=[],
        )
        await service.agent_request(request, respond, next_fn, flow)

        events = _collect_agent_events(respond)
        dag = _collect_events(events)
        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
        decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
        assert len(session_ids) == 1
        assert len(pd_ids) == 1
        assert len(decomp_ids) == 1

        # Chain: session → pattern-decision → decomposition
        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
        assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
        # Fan-out should have been called
        assert next_fn.call_count == 2  # One per goal

View file

@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY) assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE) assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
def test_chunk_entity_has_chunk_type(self): def test_chunk_entity_has_message_type(self):
triples = derived_entity_triples( triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI, self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0", "chunker", "1.0",

View file

@ -304,14 +304,14 @@ class TestStreamingTypes:
assert chunk.content == "thinking..." assert chunk.content == "thinking..."
assert chunk.end_of_message is False assert chunk.end_of_message is False
assert chunk.chunk_type == "thought" assert chunk.message_type == "thought"
def test_agent_observation_creation(self): def test_agent_observation_creation(self):
"""Test creating AgentObservation chunk""" """Test creating AgentObservation chunk"""
chunk = AgentObservation(content="observing...", end_of_message=False) chunk = AgentObservation(content="observing...", end_of_message=False)
assert chunk.content == "observing..." assert chunk.content == "observing..."
assert chunk.chunk_type == "observation" assert chunk.message_type == "observation"
def test_agent_answer_creation(self): def test_agent_answer_creation(self):
"""Test creating AgentAnswer chunk""" """Test creating AgentAnswer chunk"""
@ -324,7 +324,7 @@ class TestStreamingTypes:
assert chunk.content == "answer" assert chunk.content == "answer"
assert chunk.end_of_message is True assert chunk.end_of_message is True
assert chunk.end_of_dialog is True assert chunk.end_of_dialog is True
assert chunk.chunk_type == "final-answer" assert chunk.message_type == "final-answer"
def test_rag_chunk_creation(self): def test_rag_chunk_creation(self):
"""Test creating RAGChunk""" """Test creating RAGChunk"""

View file

@ -9,7 +9,7 @@ from .streaming_assertions import (
assert_streaming_content_matches, assert_streaming_content_matches,
assert_no_empty_chunks, assert_no_empty_chunks,
assert_streaming_error_handled, assert_streaming_error_handled,
assert_chunk_types_valid, assert_message_types_valid,
assert_streaming_latency_acceptable, assert_streaming_latency_acceptable,
assert_callback_invoked, assert_callback_invoked,
) )
@ -23,7 +23,7 @@ __all__ = [
"assert_streaming_content_matches", "assert_streaming_content_matches",
"assert_no_empty_chunks", "assert_no_empty_chunks",
"assert_streaming_error_handled", "assert_streaming_error_handled",
"assert_chunk_types_valid", "assert_message_types_valid",
"assert_streaming_latency_acceptable", "assert_streaming_latency_acceptable",
"assert_callback_invoked", "assert_callback_invoked",
] ]

View file

@ -20,14 +20,14 @@ def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
assert all(chunk is not None for chunk in chunks), "All chunks should be non-None" assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"): def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "message_type"):
""" """
Assert that streaming chunks follow an expected sequence. Assert that streaming chunks follow an expected sequence.
Args: Args:
chunks: List of chunk dictionaries chunks: List of chunk dictionaries
expected_sequence: Expected sequence of chunk types/values expected_sequence: Expected sequence of chunk types/values
key: Dictionary key to check (default: "chunk_type") key: Dictionary key to check (default: "message_type")
""" """
actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk] actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
assert actual_sequence == expected_sequence, \ assert actual_sequence == expected_sequence, \
@ -39,7 +39,7 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
Assert that agent streaming chunks have valid structure. Assert that agent streaming chunks have valid structure.
Validates: Validates:
- All chunks have chunk_type field - All chunks have message_type field
- All chunks have content field - All chunks have content field
- All chunks have end_of_message field - All chunks have end_of_message field
- All chunks have end_of_dialog field - All chunks have end_of_dialog field
@ -51,15 +51,15 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
assert len(chunks) > 0, "Expected at least one chunk" assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type" assert "message_type" in chunk, f"Chunk {i} missing message_type"
assert "content" in chunk, f"Chunk {i} missing content" assert "content" in chunk, f"Chunk {i} missing content"
assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message" assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog" assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"
# Validate chunk_type values # Validate message_type values
valid_types = ["thought", "action", "observation", "final-answer"] valid_types = ["thought", "action", "observation", "final-answer"]
assert chunk["chunk_type"] in valid_types, \ assert chunk["message_type"] in valid_types, \
f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}" f"Invalid message_type '{chunk['message_type']}' at index {i}"
# Last chunk should signal end of dialog # Last chunk should signal end of dialog
assert chunks[-1]["end_of_dialog"] is True, \ assert chunks[-1]["end_of_dialog"] is True, \
@ -175,7 +175,7 @@ def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str
"Error chunk should have completion flag set to True" "Error chunk should have completion flag set to True"
def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"): def assert_message_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "message_type"):
""" """
Assert that all chunk types are from a valid set. Assert that all chunk types are from a valid set.
@ -185,9 +185,9 @@ def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str
type_key: Dictionary key for chunk type type_key: Dictionary key for chunk type
""" """
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
chunk_type = chunk.get(type_key) message_type = chunk.get(type_key)
assert chunk_type in valid_types, \ assert message_type in valid_types, \
f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}" f"Chunk {i} has invalid type '{message_type}', expected one of {valid_types}"
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0): def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):

View file

@ -178,24 +178,23 @@ class AsyncSocketClient:
def _parse_chunk(self, resp: Dict[str, Any]): def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type. Returns None for non-content messages.""" """Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type") message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type # Handle new GraphRAG message format with message_type
if message_type == "provenance": if message_type == "provenance":
return None return None
if chunk_type == "thought": if message_type == "thought":
return AgentThought( return AgentThought(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)
) )
elif chunk_type == "observation": elif message_type == "observation":
return AgentObservation( return AgentObservation(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)
) )
elif chunk_type == "answer" or chunk_type == "final-answer": elif message_type == "answer" or message_type == "final-answer":
return AgentAnswer( return AgentAnswer(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
@ -204,7 +203,7 @@ class AsyncSocketClient:
out_token=resp.get("out_token"), out_token=resp.get("out_token"),
model=resp.get("model"), model=resp.get("model"),
) )
elif chunk_type == "action": elif message_type == "action":
return AgentThought( return AgentThought(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)

View file

@ -360,34 +360,26 @@ class SocketClient:
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
"""Parse response chunk into appropriate type. Returns None for non-content messages.""" """Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type") message_type = resp.get("message_type")
# Handle GraphRAG/DocRAG message format with message_type
if message_type == "explain": if message_type == "explain":
if include_provenance: if include_provenance:
return self._build_provenance_event(resp) return self._build_provenance_event(resp)
return None return None
# Handle Agent message format with chunk_type="explain" if message_type == "thought":
if chunk_type == "explain":
if include_provenance:
return self._build_provenance_event(resp)
return None
if chunk_type == "thought":
return AgentThought( return AgentThought(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""), message_id=resp.get("message_id", ""),
) )
elif chunk_type == "observation": elif message_type == "observation":
return AgentObservation( return AgentObservation(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""), message_id=resp.get("message_id", ""),
) )
elif chunk_type == "answer" or chunk_type == "final-answer": elif message_type == "answer" or message_type == "final-answer":
return AgentAnswer( return AgentAnswer(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
@ -397,7 +389,7 @@ class SocketClient:
out_token=resp.get("out_token"), out_token=resp.get("out_token"),
model=resp.get("model"), model=resp.get("model"),
) )
elif chunk_type == "action": elif message_type == "action":
return AgentThought( return AgentThought(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False) end_of_message=resp.get("end_of_message", False)

View file

@ -149,10 +149,10 @@ class AgentThought(StreamingChunk):
Attributes: Attributes:
content: Agent's thought text content: Agent's thought text
end_of_message: True if this completes the current thought end_of_message: True if this completes the current thought
chunk_type: Always "thought" message_type: Always "thought"
message_id: Provenance URI of the entity being built message_id: Provenance URI of the entity being built
""" """
chunk_type: str = "thought" message_type: str = "thought"
message_id: str = "" message_id: str = ""
@dataclasses.dataclass @dataclasses.dataclass
@ -166,10 +166,10 @@ class AgentObservation(StreamingChunk):
Attributes: Attributes:
content: Observation text describing tool results content: Observation text describing tool results
end_of_message: True if this completes the current observation end_of_message: True if this completes the current observation
chunk_type: Always "observation" message_type: Always "observation"
message_id: Provenance URI of the entity being built message_id: Provenance URI of the entity being built
""" """
chunk_type: str = "observation" message_type: str = "observation"
message_id: str = "" message_id: str = ""
@dataclasses.dataclass @dataclasses.dataclass
@ -184,9 +184,9 @@ class AgentAnswer(StreamingChunk):
content: Answer text content: Answer text
end_of_message: True if this completes the current answer segment end_of_message: True if this completes the current answer segment
end_of_dialog: True if this completes the entire agent interaction end_of_dialog: True if this completes the entire agent interaction
chunk_type: Always "final-answer" message_type: Always "final-answer"
""" """
chunk_type: str = "final-answer" message_type: str = "final-answer"
end_of_dialog: bool = False end_of_dialog: bool = False
message_id: str = "" message_id: str = ""
in_token: Optional[int] = None in_token: Optional[int] = None
@ -208,9 +208,9 @@ class RAGChunk(StreamingChunk):
in_token: Input token count (populated on the final chunk, 0 otherwise) in_token: Input token count (populated on the final chunk, 0 otherwise)
out_token: Output token count (populated on the final chunk, 0 otherwise) out_token: Output token count (populated on the final chunk, 0 otherwise)
model: Model identifier (populated on the final chunk, empty otherwise) model: Model identifier (populated on the final chunk, empty otherwise)
chunk_type: Always "rag" message_type: Always "rag"
""" """
chunk_type: str = "rag" message_type: str = "rag"
end_of_stream: bool = False end_of_stream: bool = False
error: Optional[Dict[str, str]] = None error: Optional[Dict[str, str]] = None
in_token: Optional[int] = None in_token: Optional[int] = None

View file

@ -30,19 +30,19 @@ class AgentClient(RequestResponse):
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
# Handle thought chunks # Handle thought chunks
if resp.chunk_type == 'thought': if resp.message_type == 'thought':
if think: if think:
await think(resp.content, resp.end_of_message) await think(resp.content, resp.end_of_message)
return False # Continue receiving return False # Continue receiving
# Handle observation chunks # Handle observation chunks
if resp.chunk_type == 'observation': if resp.message_type == 'observation':
if observe: if observe:
await observe(resp.content, resp.end_of_message) await observe(resp.content, resp.end_of_message)
return False # Continue receiving return False # Continue receiving
# Handle answer chunks # Handle answer chunks
if resp.chunk_type == 'answer': if resp.message_type == 'answer':
if resp.content: if resp.content:
accumulated_answer.append(resp.content) accumulated_answer.append(resp.content)
if answer_callback: if answer_callback:

View file

@ -58,23 +58,23 @@ class AgentClient(BaseClient):
def inspect(x): def inspect(x):
# Handle errors # Handle errors
if x.chunk_type == 'error' or x.error: if x.message_type == 'error' or x.error:
if error_callback: if error_callback:
error_callback(x.content or (x.error.message if x.error else "")) error_callback(x.content or (x.error.message if x.error else ""))
# Continue to check end_of_dialog # Continue to check end_of_dialog
# Handle thought chunks # Handle thought chunks
elif x.chunk_type == 'thought': elif x.message_type == 'thought':
if think: if think:
think(x.content, x.end_of_message) think(x.content, x.end_of_message)
# Handle observation chunks # Handle observation chunks
elif x.chunk_type == 'observation': elif x.message_type == 'observation':
if observe: if observe:
observe(x.content, x.end_of_message) observe(x.content, x.end_of_message)
# Handle answer chunks # Handle answer chunks
elif x.chunk_type == 'answer': elif x.message_type == 'answer':
if x.content: if x.content:
accumulated_answer.append(x.content) accumulated_answer.append(x.content)
if answer_callback: if answer_callback:

View file

@ -60,8 +60,8 @@ class AgentResponseTranslator(MessageTranslator):
def encode(self, obj: AgentResponse) -> Dict[str, Any]: def encode(self, obj: AgentResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.chunk_type: if obj.message_type:
result["chunk_type"] = obj.chunk_type result["message_type"] = obj.message_type
if obj.content: if obj.content:
result["content"] = obj.content result["content"] = obj.content
result["end_of_message"] = getattr(obj, "end_of_message", False) result["end_of_message"] = getattr(obj, "end_of_message", False)

View file

@ -59,6 +59,7 @@ from . uris import (
agent_plan_uri, agent_plan_uri,
agent_step_result_uri, agent_step_result_uri,
agent_synthesis_uri, agent_synthesis_uri,
agent_pattern_decision_uri,
# Document RAG provenance URIs # Document RAG provenance URIs
docrag_question_uri, docrag_question_uri,
docrag_grounding_uri, docrag_grounding_uri,
@ -102,6 +103,11 @@ from . namespaces import (
# Agent provenance predicates # Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_SUBAGENT_GOAL, TG_PLAN_STEP, TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
TG_IN_TOKEN, TG_OUT_TOKEN,
TG_ERROR_TYPE,
# Orchestrator entity types # Orchestrator entity types
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
# Document reference predicate # Document reference predicate
@ -141,6 +147,7 @@ from . agent import (
agent_plan_triples, agent_plan_triples,
agent_step_result_triples, agent_step_result_triples,
agent_synthesis_triples, agent_synthesis_triples,
agent_pattern_decision_triples,
) )
# Vocabulary bootstrap # Vocabulary bootstrap
@ -182,6 +189,7 @@ __all__ = [
"agent_plan_uri", "agent_plan_uri",
"agent_step_result_uri", "agent_step_result_uri",
"agent_synthesis_uri", "agent_synthesis_uri",
"agent_pattern_decision_uri",
# Document RAG provenance URIs # Document RAG provenance URIs
"docrag_question_uri", "docrag_question_uri",
"docrag_grounding_uri", "docrag_grounding_uri",
@ -218,6 +226,11 @@ __all__ = [
# Agent provenance predicates # Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
"TG_SUBAGENT_GOAL", "TG_PLAN_STEP", "TG_SUBAGENT_GOAL", "TG_PLAN_STEP",
"TG_TOOL_CANDIDATE", "TG_TERMINATION_REASON",
"TG_STEP_NUMBER", "TG_PATTERN_DECISION", "TG_PATTERN", "TG_TASK_TYPE",
"TG_LLM_DURATION_MS", "TG_TOOL_DURATION_MS", "TG_TOOL_ERROR",
"TG_IN_TOKEN", "TG_OUT_TOKEN",
"TG_ERROR_TYPE",
# Orchestrator entity types # Orchestrator entity types
"TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT", "TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT",
# Document reference predicate # Document reference predicate
@ -249,6 +262,7 @@ __all__ = [
"agent_plan_triples", "agent_plan_triples",
"agent_step_result_triples", "agent_step_result_triples",
"agent_synthesis_triples", "agent_synthesis_triples",
"agent_pattern_decision_triples",
# Utility # Utility
"set_graph", "set_graph",
# Vocabulary # Vocabulary

View file

@ -29,6 +29,11 @@ from . namespaces import (
TG_AGENT_QUESTION, TG_AGENT_QUESTION,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP, TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
TG_ERROR_TYPE,
TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL,
) )
@ -47,6 +52,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple:
return Triple(s=_iri(s), p=_iri(p), o=o_term) return Triple(s=_iri(s), p=_iri(p), o=o_term)
def _append_token_triples(triples, uri, in_token=None, out_token=None,
                          model=None):
    """
    Append optional LLM token-usage triples to *triples* for entity *uri*.

    Emits tg:inToken / tg:outToken (stringified counts) and tg:llmModel,
    each only when the corresponding argument is not None, so callers can
    pass through whatever usage data the LLM client reported.
    """
    # (predicate, literal-value-or-None) pairs; a None value means "omit".
    optional_fields = (
        (TG_IN_TOKEN, None if in_token is None else str(in_token)),
        (TG_OUT_TOKEN, None if out_token is None else str(out_token)),
        (TG_LLM_MODEL, model),
    )
    for predicate, value in optional_fields:
        if value is not None:
            triples.append(_triple(uri, predicate, _literal(value)))
def agent_session_triples( def agent_session_triples(
session_uri: str, session_uri: str,
query: str, query: str,
@ -90,6 +106,43 @@ def agent_session_triples(
return triples return triples
def agent_pattern_decision_triples(
    uri: str,
    session_uri: str,
    pattern: str,
    task_type: str = "",
) -> List[Triple]:
    """
    Build provenance triples recording a meta-router pattern decision.

    Emits a prov:Entity additionally typed tg:PatternDecision, labelled
    with the chosen pattern, linked back to its session with
    prov:wasDerivedFrom, and annotated with tg:pattern (plus tg:taskType
    when a task type was identified).

    Args:
        uri: URI of this decision (from agent_pattern_decision_uri)
        session_uri: URI of the parent session
        pattern: Selected execution pattern (e.g. "react", "plan-then-execute")
        task_type: Identified task type (e.g. "general", "research")

    Returns:
        List of Triple objects
    """
    out: List[Triple] = []
    out.append(_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)))
    out.append(_triple(uri, RDF_TYPE, _iri(TG_PATTERN_DECISION)))
    out.append(_triple(uri, RDFS_LABEL, _literal(f"Pattern: {pattern}")))
    out.append(_triple(uri, TG_PATTERN, _literal(pattern)))
    out.append(_triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)))
    # taskType is optional metadata: empty string means "not identified".
    if task_type:
        out.append(_triple(uri, TG_TASK_TYPE, _literal(task_type)))
    return out
def agent_iteration_triples( def agent_iteration_triples(
iteration_uri: str, iteration_uri: str,
question_uri: Optional[str] = None, question_uri: Optional[str] = None,
@ -98,6 +151,12 @@ def agent_iteration_triples(
arguments: Dict[str, Any] = None, arguments: Dict[str, Any] = None,
thought_uri: Optional[str] = None, thought_uri: Optional[str] = None,
thought_document_id: Optional[str] = None, thought_document_id: Optional[str] = None,
tool_candidates: Optional[List[str]] = None,
step_number: Optional[int] = None,
llm_duration_ms: Optional[int] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for one agent iteration (Analysis+ToolUse). Build triples for one agent iteration (Analysis+ToolUse).
@ -106,6 +165,7 @@ def agent_iteration_triples(
- Entity declaration with tg:Analysis and tg:ToolUse types - Entity declaration with tg:Analysis and tg:ToolUse types
- wasDerivedFrom link to question (if first iteration) or previous - wasDerivedFrom link to question (if first iteration) or previous
- Action and arguments metadata - Action and arguments metadata
- Tool candidates (names of tools visible to the LLM)
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document - Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
Args: Args:
@ -116,6 +176,7 @@ def agent_iteration_triples(
arguments: Arguments passed to the tool (will be JSON-encoded) arguments: Arguments passed to the tool (will be JSON-encoded)
thought_uri: URI for the thought sub-entity thought_uri: URI for the thought sub-entity
thought_document_id: Document URI for thought in librarian thought_document_id: Document URI for thought in librarian
tool_candidates: List of tool names available to the LLM
Returns: Returns:
List of Triple objects List of Triple objects
@ -132,6 +193,23 @@ def agent_iteration_triples(
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
] ]
if tool_candidates:
for name in tool_candidates:
triples.append(
_triple(iteration_uri, TG_TOOL_CANDIDATE, _literal(name))
)
if step_number is not None:
triples.append(
_triple(iteration_uri, TG_STEP_NUMBER, _literal(str(step_number)))
)
if llm_duration_ms is not None:
triples.append(
_triple(iteration_uri, TG_LLM_DURATION_MS,
_literal(str(llm_duration_ms)))
)
if question_uri: if question_uri:
triples.append( triples.append(
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri)) _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri))
@ -155,6 +233,8 @@ def agent_iteration_triples(
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id)) _triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
) )
_append_token_triples(triples, iteration_uri, in_token, out_token, model)
return triples return triples
@ -162,6 +242,8 @@ def agent_observation_triples(
observation_uri: str, observation_uri: str,
iteration_uri: str, iteration_uri: str,
document_id: Optional[str] = None, document_id: Optional[str] = None,
tool_duration_ms: Optional[int] = None,
tool_error: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for an agent observation (standalone entity). Build triples for an agent observation (standalone entity).
@ -170,11 +252,15 @@ def agent_observation_triples(
- Entity declaration with prov:Entity and tg:Observation types - Entity declaration with prov:Entity and tg:Observation types
- wasDerivedFrom link to the iteration (Analysis+ToolUse) - wasDerivedFrom link to the iteration (Analysis+ToolUse)
- Document reference to librarian (if provided) - Document reference to librarian (if provided)
- Tool execution duration (if provided)
- Tool error message (if the tool failed)
Args: Args:
observation_uri: URI of the observation entity observation_uri: URI of the observation entity
iteration_uri: URI of the iteration this observation derives from iteration_uri: URI of the iteration this observation derives from
document_id: Librarian document ID for the observation content document_id: Librarian document ID for the observation content
tool_duration_ms: Tool execution time in milliseconds
tool_error: Error message if the tool failed
Returns: Returns:
List of Triple objects List of Triple objects
@ -191,6 +277,20 @@ def agent_observation_triples(
_triple(observation_uri, TG_DOCUMENT, _iri(document_id)) _triple(observation_uri, TG_DOCUMENT, _iri(document_id))
) )
if tool_duration_ms is not None:
triples.append(
_triple(observation_uri, TG_TOOL_DURATION_MS,
_literal(str(tool_duration_ms)))
)
if tool_error:
triples.append(
_triple(observation_uri, TG_TOOL_ERROR, _literal(tool_error))
)
triples.append(
_triple(observation_uri, RDF_TYPE, _iri(TG_ERROR_TYPE))
)
return triples return triples
@ -199,6 +299,10 @@ def agent_final_triples(
question_uri: Optional[str] = None, question_uri: Optional[str] = None,
previous_uri: Optional[str] = None, previous_uri: Optional[str] = None,
document_id: Optional[str] = None, document_id: Optional[str] = None,
termination_reason: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for an agent final answer (Conclusion). Build triples for an agent final answer (Conclusion).
@ -208,12 +312,15 @@ def agent_final_triples(
- wasGeneratedBy link to question (if no iterations) - wasGeneratedBy link to question (if no iterations)
- wasDerivedFrom link to last iteration (if iterations exist) - wasDerivedFrom link to last iteration (if iterations exist)
- Document reference to librarian - Document reference to librarian
- Termination reason (why the agent loop stopped)
Args: Args:
final_uri: URI of the final answer (from agent_final_uri) final_uri: URI of the final answer (from agent_final_uri)
question_uri: URI of the question activity (if no iterations) question_uri: URI of the question activity (if no iterations)
previous_uri: URI of the last iteration (if iterations exist) previous_uri: URI of the last iteration (if iterations exist)
document_id: Librarian document ID for the answer content document_id: Librarian document ID for the answer content
termination_reason: Why the loop stopped, e.g. "final-answer",
"max-iterations", "error"
Returns: Returns:
List of Triple objects List of Triple objects
@ -237,6 +344,14 @@ def agent_final_triples(
if document_id: if document_id:
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
if termination_reason:
triples.append(
_triple(final_uri, TG_TERMINATION_REASON,
_literal(termination_reason))
)
_append_token_triples(triples, final_uri, in_token, out_token, model)
return triples return triples
@ -244,6 +359,9 @@ def agent_decomposition_triples(
uri: str, uri: str,
session_uri: str, session_uri: str,
goals: List[str], goals: List[str],
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
"""Build triples for a supervisor decomposition step.""" """Build triples for a supervisor decomposition step."""
triples = [ triples = [
@ -255,6 +373,7 @@ def agent_decomposition_triples(
] ]
for goal in goals: for goal in goals:
triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal))) triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal)))
_append_token_triples(triples, uri, in_token, out_token, model)
return triples return triples
@ -282,6 +401,9 @@ def agent_plan_triples(
uri: str, uri: str,
session_uri: str, session_uri: str,
steps: List[str], steps: List[str],
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
"""Build triples for a plan-then-execute plan.""" """Build triples for a plan-then-execute plan."""
triples = [ triples = [
@ -293,6 +415,7 @@ def agent_plan_triples(
] ]
for step in steps: for step in steps:
triples.append(_triple(uri, TG_PLAN_STEP, _literal(step))) triples.append(_triple(uri, TG_PLAN_STEP, _literal(step)))
_append_token_triples(triples, uri, in_token, out_token, model)
return triples return triples
@ -301,6 +424,9 @@ def agent_step_result_triples(
plan_uri: str, plan_uri: str,
goal: str, goal: str,
document_id: Optional[str] = None, document_id: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
"""Build triples for a plan step result.""" """Build triples for a plan step result."""
triples = [ triples = [
@ -313,6 +439,7 @@ def agent_step_result_triples(
] ]
if document_id: if document_id:
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
_append_token_triples(triples, uri, in_token, out_token, model)
return triples return triples
@ -320,6 +447,10 @@ def agent_synthesis_triples(
uri: str, uri: str,
previous_uris, previous_uris,
document_id: Optional[str] = None, document_id: Optional[str] = None,
termination_reason: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
"""Build triples for a synthesis answer. """Build triples for a synthesis answer.
@ -327,6 +458,8 @@ def agent_synthesis_triples(
uri: URI of the synthesis entity uri: URI of the synthesis entity
previous_uris: Single URI string or list of URIs to derive from previous_uris: Single URI string or list of URIs to derive from
document_id: Librarian document ID for the answer content document_id: Librarian document ID for the answer content
termination_reason: Why the agent loop stopped
in_token/out_token/model: Token usage for the synthesis LLM call
""" """
triples = [ triples = [
_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
@ -342,4 +475,12 @@ def agent_synthesis_triples(
if document_id: if document_id:
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
if termination_reason:
triples.append(
_triple(uri, TG_TERMINATION_REASON, _literal(termination_reason))
)
_append_token_triples(triples, uri, in_token, out_token, model)
return triples return triples

View file

@ -119,6 +119,18 @@ TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity
TG_SUBAGENT_GOAL = TG + "subagentGoal" # Goal string on Decomposition/Finding TG_SUBAGENT_GOAL = TG + "subagentGoal" # Goal string on Decomposition/Finding
TG_PLAN_STEP = TG + "planStep" # Step goal string on Plan/StepResult TG_PLAN_STEP = TG + "planStep" # Step goal string on Plan/StepResult
# Explainability predicates (#795): per-event instrumentation emitted by the
# agent patterns so the provenance DAG records tool visibility, timing,
# token usage, and failure details for each step.
TG_TOOL_CANDIDATE = TG + "toolCandidate" # Tool name on Analysis events
TG_TERMINATION_REASON = TG + "terminationReason" # Why the agent loop stopped
TG_STEP_NUMBER = TG + "stepNumber" # Explicit step counter on iteration events
TG_PATTERN_DECISION = TG + "PatternDecision" # Meta-router routing decision entity type
TG_PATTERN = TG + "pattern" # Selected execution pattern
TG_TASK_TYPE = TG + "taskType" # Identified task type
TG_LLM_DURATION_MS = TG + "llmDurationMs" # LLM call duration in milliseconds
TG_TOOL_DURATION_MS = TG + "toolDurationMs" # Tool execution duration in milliseconds
TG_TOOL_ERROR = TG + "toolError" # Error message from a failed tool execution
TG_ERROR_TYPE = TG + "Error" # Mixin type for failure events
TG_IN_TOKEN = TG + "inToken" # Input token count for an LLM call
TG_OUT_TOKEN = TG + "outToken" # Output token count for an LLM call
# Named graph URIs for RDF datasets # Named graph URIs for RDF datasets
# These separate different types of data while keeping them in the same collection # These separate different types of data while keeping them in the same collection

View file

@ -34,6 +34,8 @@ from . namespaces import (
TG_ANSWER_TYPE, TG_ANSWER_TYPE,
# Question subtypes # Question subtypes
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
# Token usage
TG_IN_TOKEN, TG_OUT_TOKEN,
) )
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
@ -74,6 +76,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple:
return Triple(s=_iri(s), p=_iri(p), o=o_term) return Triple(s=_iri(s), p=_iri(p), o=o_term)
def _append_token_triples(triples, uri, in_token=None, out_token=None,
                          model=None):
    """
    Attach optional token-usage triples (tg:inToken, tg:outToken,
    tg:llmModel) to the entity at *uri*, skipping any value that is None.
    """
    # NOTE(review): this body references TG_LLM_MODEL, but the token-usage
    # import block above adds only TG_IN_TOKEN/TG_OUT_TOKEN — confirm
    # TG_LLM_MODEL is imported elsewhere in this module.
    def _maybe(predicate, value):
        if value is not None:
            triples.append(_triple(uri, predicate, _literal(value)))

    _maybe(TG_IN_TOKEN, None if in_token is None else str(in_token))
    _maybe(TG_OUT_TOKEN, None if out_token is None else str(out_token))
    _maybe(TG_LLM_MODEL, model)
def document_triples( def document_triples(
doc_uri: str, doc_uri: str,
title: Optional[str] = None, title: Optional[str] = None,
@ -396,6 +409,9 @@ def grounding_triples(
grounding_uri: str, grounding_uri: str,
question_uri: str, question_uri: str,
concepts: List[str], concepts: List[str],
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for a grounding entity (concept decomposition of query). Build triples for a grounding entity (concept decomposition of query).
@ -423,6 +439,8 @@ def grounding_triples(
for concept in concepts: for concept in concepts:
triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept))) triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept)))
_append_token_triples(triples, grounding_uri, in_token, out_token, model)
return triples return triples
@ -485,6 +503,9 @@ def focus_triples(
exploration_uri: str, exploration_uri: str,
selected_edges_with_reasoning: List[dict], selected_edges_with_reasoning: List[dict],
session_id: str = "", session_id: str = "",
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for a focus entity (selected edges with reasoning). Build triples for a focus entity (selected edges with reasoning).
@ -543,6 +564,8 @@ def focus_triples(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) _triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
) )
_append_token_triples(triples, focus_uri, in_token, out_token, model)
return triples return triples
@ -550,6 +573,9 @@ def synthesis_triples(
synthesis_uri: str, synthesis_uri: str,
focus_uri: str, focus_uri: str,
document_id: Optional[str] = None, document_id: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for a synthesis entity (final answer). Build triples for a synthesis entity (final answer).
@ -578,6 +604,8 @@ def synthesis_triples(
if document_id: if document_id:
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
_append_token_triples(triples, synthesis_uri, in_token, out_token, model)
return triples return triples
@ -674,6 +702,9 @@ def docrag_synthesis_triples(
synthesis_uri: str, synthesis_uri: str,
exploration_uri: str, exploration_uri: str,
document_id: Optional[str] = None, document_id: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]: ) -> List[Triple]:
""" """
Build triples for a document RAG synthesis entity (final answer). Build triples for a document RAG synthesis entity (final answer).
@ -702,4 +733,6 @@ def docrag_synthesis_triples(
if document_id: if document_id:
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
_append_token_triples(triples, synthesis_uri, in_token, out_token, model)
return triples return triples

View file

@ -259,6 +259,11 @@ def agent_synthesis_uri(session_id: str) -> str:
return f"urn:trustgraph:agent:{session_id}/synthesis" return f"urn:trustgraph:agent:{session_id}/synthesis"
def agent_pattern_decision_uri(session_id: str) -> str:
    """Generate URI for a meta-router pattern decision."""
    # Same urn:trustgraph:agent: namespace as the other per-session URIs,
    # with a fixed "pattern-decision" path segment.
    return "urn:trustgraph:agent:" + session_id + "/pattern-decision"
# Document RAG provenance URIs # Document RAG provenance URIs
# These URIs use the urn:trustgraph:docrag: namespace to distinguish # These URIs use the urn:trustgraph:docrag: namespace to distinguish
# document RAG provenance from graph RAG provenance # document RAG provenance from graph RAG provenance

View file

@ -51,8 +51,8 @@ class AgentRequest:
@dataclass @dataclass
class AgentResponse: class AgentResponse:
# Streaming-first design # Streaming-first design
chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" message_type: str = "" # "thought", "action", "observation", "answer", "explain", "error"
content: str = "" # The actual content (interpretation depends on chunk_type) content: str = "" # The actual content (interpretation depends on message_type)
end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete
end_of_dialog: bool = False # Entire agent dialog is complete end_of_dialog: bool = False # Entire agent dialog is complete

View file

@ -126,7 +126,7 @@ def question_explainable(
try: try:
# Track last chunk type for formatting # Track last chunk type for formatting
last_chunk_type = None last_message_type = None
current_outputter = None current_outputter = None
# Stream agent with explainability - process events as they arrive # Stream agent with explainability - process events as they arrive
@ -138,7 +138,7 @@ def question_explainable(
group=group, group=group,
): ):
if isinstance(item, AgentThought): if isinstance(item, AgentThought):
if last_chunk_type != "thought": if last_message_type != "thought":
if current_outputter: if current_outputter:
current_outputter.__exit__(None, None, None) current_outputter.__exit__(None, None, None)
current_outputter = None current_outputter = None
@ -146,7 +146,7 @@ def question_explainable(
if verbose: if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f914 ") current_outputter = Outputter(width=78, prefix="\U0001f914 ")
current_outputter.__enter__() current_outputter.__enter__()
last_chunk_type = "thought" last_message_type = "thought"
if current_outputter: if current_outputter:
current_outputter.output(item.content) current_outputter.output(item.content)
if current_outputter.word_buffer: if current_outputter.word_buffer:
@ -155,7 +155,7 @@ def question_explainable(
current_outputter.word_buffer = "" current_outputter.word_buffer = ""
elif isinstance(item, AgentObservation): elif isinstance(item, AgentObservation):
if last_chunk_type != "observation": if last_message_type != "observation":
if current_outputter: if current_outputter:
current_outputter.__exit__(None, None, None) current_outputter.__exit__(None, None, None)
current_outputter = None current_outputter = None
@ -163,7 +163,7 @@ def question_explainable(
if verbose: if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
current_outputter.__enter__() current_outputter.__enter__()
last_chunk_type = "observation" last_message_type = "observation"
if current_outputter: if current_outputter:
current_outputter.output(item.content) current_outputter.output(item.content)
if current_outputter.word_buffer: if current_outputter.word_buffer:
@ -172,12 +172,12 @@ def question_explainable(
current_outputter.word_buffer = "" current_outputter.word_buffer = ""
elif isinstance(item, AgentAnswer): elif isinstance(item, AgentAnswer):
if last_chunk_type != "answer": if last_message_type != "answer":
if current_outputter: if current_outputter:
current_outputter.__exit__(None, None, None) current_outputter.__exit__(None, None, None)
current_outputter = None current_outputter = None
print() print()
last_chunk_type = "answer" last_message_type = "answer"
# Print answer content directly # Print answer content directly
print(item.content, end="", flush=True) print(item.content, end="", flush=True)
@ -261,7 +261,7 @@ def question_explainable(
current_outputter = None current_outputter = None
# Final newline if we ended with answer # Final newline if we ended with answer
if last_chunk_type == "answer": if last_message_type == "answer":
print() print()
finally: finally:
@ -322,16 +322,16 @@ def question(
# Handle streaming response # Handle streaming response
if streaming: if streaming:
# Track last chunk type and current outputter for streaming # Track last chunk type and current outputter for streaming
last_chunk_type = None last_message_type = None
current_outputter = None current_outputter = None
last_answer_chunk = None last_answer_chunk = None
for chunk in response: for chunk in response:
chunk_type = chunk.chunk_type message_type = chunk.message_type
content = chunk.content content = chunk.content
# Check if we're switching to a new message type # Check if we're switching to a new message type
if last_chunk_type != chunk_type: if last_message_type != message_type:
# Close previous outputter if exists # Close previous outputter if exists
if current_outputter: if current_outputter:
current_outputter.__exit__(None, None, None) current_outputter.__exit__(None, None, None)
@ -339,15 +339,15 @@ def question(
print() # Blank line between message types print() # Blank line between message types
# Create new outputter for new message type # Create new outputter for new message type
if chunk_type == "thought" and verbose: if message_type == "thought" and verbose:
current_outputter = Outputter(width=78, prefix="\U0001f914 ") current_outputter = Outputter(width=78, prefix="\U0001f914 ")
current_outputter.__enter__() current_outputter.__enter__()
elif chunk_type == "observation" and verbose: elif message_type == "observation" and verbose:
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
current_outputter.__enter__() current_outputter.__enter__()
# For answer, don't use Outputter - just print as-is # For answer, don't use Outputter - just print as-is
last_chunk_type = chunk_type last_message_type = message_type
# Output the chunk # Output the chunk
if current_outputter: if current_outputter:
@ -357,7 +357,7 @@ def question(
print(current_outputter.word_buffer, end="", flush=True) print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer) current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = "" current_outputter.word_buffer = ""
elif chunk_type == "final-answer": elif message_type == "final-answer":
print(content, end="", flush=True) print(content, end="", flush=True)
last_answer_chunk = chunk last_answer_chunk = chunk
@ -366,7 +366,7 @@ def question(
current_outputter.__exit__(None, None, None) current_outputter.__exit__(None, None, None)
current_outputter = None current_outputter = None
# Add final newline if we were outputting answer # Add final newline if we were outputting answer
elif last_chunk_type == "final-answer": elif last_message_type == "final-answer":
print() print()
if show_usage and last_answer_chunk: if show_usage and last_answer_chunk:
@ -382,17 +382,17 @@ def question(
# so we iterate through the chunks (which are complete messages, not text chunks) # so we iterate through the chunks (which are complete messages, not text chunks)
for chunk in response: for chunk in response:
# Display thoughts if verbose # Display thoughts if verbose
if chunk.chunk_type == "thought" and verbose: if chunk.message_type == "thought" and verbose:
output(wrap(chunk.content), "\U0001f914 ") output(wrap(chunk.content), "\U0001f914 ")
print() print()
# Display observations if verbose # Display observations if verbose
elif chunk.chunk_type == "observation" and verbose: elif chunk.message_type == "observation" and verbose:
output(wrap(chunk.content), "\U0001f4a1 ") output(wrap(chunk.content), "\U0001f4a1 ")
print() print()
# Display answer # Display answer
elif chunk.chunk_type == "final-answer" or chunk.chunk_type == "answer": elif chunk.message_type == "final-answer" or chunk.message_type == "answer":
print(chunk.content) print(chunk.content)
finally: finally:

View file

@ -25,6 +25,7 @@ from trustgraph.provenance import (
agent_plan_uri, agent_plan_uri,
agent_step_result_uri, agent_step_result_uri,
agent_synthesis_uri, agent_synthesis_uri,
agent_pattern_decision_uri,
agent_session_triples, agent_session_triples,
agent_iteration_triples, agent_iteration_triples,
agent_observation_triples, agent_observation_triples,
@ -34,6 +35,7 @@ from trustgraph.provenance import (
agent_plan_triples, agent_plan_triples,
agent_step_result_triples, agent_step_result_triples,
agent_synthesis_triples, agent_synthesis_triples,
agent_pattern_decision_triples,
set_graph, set_graph,
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL,
) )
@ -182,7 +184,7 @@ class PatternBase:
logger.debug(f"Think: {x} (is_final={is_final})") logger.debug(f"Think: {x} (is_final={is_final})")
if streaming: if streaming:
r = AgentResponse( r = AgentResponse(
chunk_type="thought", message_type="thought",
content=x, content=x,
end_of_message=is_final, end_of_message=is_final,
end_of_dialog=False, end_of_dialog=False,
@ -190,7 +192,7 @@ class PatternBase:
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="thought", message_type="thought",
content=x, content=x,
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -205,7 +207,7 @@ class PatternBase:
logger.debug(f"Observe: {x} (is_final={is_final})") logger.debug(f"Observe: {x} (is_final={is_final})")
if streaming: if streaming:
r = AgentResponse( r = AgentResponse(
chunk_type="observation", message_type="observation",
content=x, content=x,
end_of_message=is_final, end_of_message=is_final,
end_of_dialog=False, end_of_dialog=False,
@ -213,7 +215,7 @@ class PatternBase:
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="observation", message_type="observation",
content=x, content=x,
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -228,7 +230,7 @@ class PatternBase:
logger.debug(f"Answer: {x}") logger.debug(f"Answer: {x}")
if streaming: if streaming:
r = AgentResponse( r = AgentResponse(
chunk_type="answer", message_type="answer",
content=x, content=x,
end_of_message=False, end_of_message=False,
end_of_dialog=False, end_of_dialog=False,
@ -236,7 +238,7 @@ class PatternBase:
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="answer", message_type="answer",
content=x, content=x,
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -270,16 +272,43 @@ class PatternBase:
logger.debug(f"Emitted session triples for {session_uri}") logger.debug(f"Emitted session triples for {session_uri}")
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=session_uri, explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples, explain_triples=triples,
)) ))
async def emit_pattern_decision_triples(
self, flow, session_id, session_uri, pattern, task_type,
user, collection, respond,
):
"""Emit provenance triples for a meta-router pattern decision."""
uri = agent_pattern_decision_uri(session_id)
triples = set_graph(
agent_pattern_decision_triples(
uri, session_uri, pattern, task_type,
),
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
metadata=Metadata(id=uri, user=user, collection=collection),
triples=triples,
))
await respond(AgentResponse(
message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
return uri
async def emit_iteration_triples(self, flow, session_id, iteration_num, async def emit_iteration_triples(self, flow, session_id, iteration_num,
session_uri, act, request, respond, session_uri, act, request, respond,
streaming): streaming, tool_candidates=None,
step_number=None,
llm_duration_ms=None,
in_token=None, out_token=None,
model=None):
"""Emit provenance triples for an iteration (Analysis+ToolUse).""" """Emit provenance triples for an iteration (Analysis+ToolUse)."""
iteration_uri = agent_iteration_uri(session_id, iteration_num) iteration_uri = agent_iteration_uri(session_id, iteration_num)
@ -319,6 +348,12 @@ class PatternBase:
arguments=act.arguments, arguments=act.arguments,
thought_uri=thought_entity_uri if thought_doc_id else None, thought_uri=thought_entity_uri if thought_doc_id else None,
thought_document_id=thought_doc_id, thought_document_id=thought_doc_id,
tool_candidates=tool_candidates,
step_number=step_number,
llm_duration_ms=llm_duration_ms,
in_token=in_token,
out_token=out_token,
model=model,
), ),
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL,
) )
@ -333,7 +368,7 @@ class PatternBase:
logger.debug(f"Emitted iteration triples for {iteration_uri}") logger.debug(f"Emitted iteration triples for {iteration_uri}")
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=iteration_uri, explain_id=iteration_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -342,7 +377,9 @@ class PatternBase:
async def emit_observation_triples(self, flow, session_id, iteration_num, async def emit_observation_triples(self, flow, session_id, iteration_num,
observation_text, request, respond, observation_text, request, respond,
context=None): context=None,
tool_duration_ms=None,
tool_error=None):
"""Emit provenance triples for a standalone Observation entity.""" """Emit provenance triples for a standalone Observation entity."""
iteration_uri = agent_iteration_uri(session_id, iteration_num) iteration_uri = agent_iteration_uri(session_id, iteration_num)
observation_entity_uri = agent_observation_uri(session_id, iteration_num) observation_entity_uri = agent_observation_uri(session_id, iteration_num)
@ -375,6 +412,8 @@ class PatternBase:
observation_entity_uri, observation_entity_uri,
parent_uri, parent_uri,
document_id=observation_doc_id, document_id=observation_doc_id,
tool_duration_ms=tool_duration_ms,
tool_error=tool_error,
), ),
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL,
) )
@ -389,7 +428,7 @@ class PatternBase:
logger.debug(f"Emitted observation triples for {observation_entity_uri}") logger.debug(f"Emitted observation triples for {observation_entity_uri}")
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=observation_entity_uri, explain_id=observation_entity_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -398,7 +437,7 @@ class PatternBase:
async def emit_final_triples(self, flow, session_id, iteration_num, async def emit_final_triples(self, flow, session_id, iteration_num,
session_uri, answer_text, request, respond, session_uri, answer_text, request, respond,
streaming): streaming, termination_reason=None):
"""Emit provenance triples for the final answer and save to librarian.""" """Emit provenance triples for the final answer and save to librarian."""
final_uri = agent_final_uri(session_id) final_uri = agent_final_uri(session_id)
@ -432,6 +471,7 @@ class PatternBase:
question_uri=final_question_uri, question_uri=final_question_uri,
previous_uri=final_previous_uri, previous_uri=final_previous_uri,
document_id=answer_doc_id, document_id=answer_doc_id,
termination_reason=termination_reason,
), ),
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL,
) )
@ -446,7 +486,7 @@ class PatternBase:
logger.debug(f"Emitted final triples for {final_uri}") logger.debug(f"Emitted final triples for {final_uri}")
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=final_uri, explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -470,7 +510,7 @@ class PatternBase:
triples=triples, triples=triples,
)) ))
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", content="", message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples, explain_triples=triples,
)) ))
@ -509,7 +549,7 @@ class PatternBase:
triples=triples, triples=triples,
)) ))
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", content="", message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples, explain_triples=triples,
)) ))
@ -529,7 +569,7 @@ class PatternBase:
triples=triples, triples=triples,
)) ))
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", content="", message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples, explain_triples=triples,
)) ))
@ -562,14 +602,14 @@ class PatternBase:
triples=triples, triples=triples,
)) ))
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", content="", message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples, explain_triples=triples,
)) ))
async def emit_synthesis_triples( async def emit_synthesis_triples(
self, flow, session_id, previous_uris, answer_text, user, collection, self, flow, session_id, previous_uris, answer_text, user, collection,
respond, streaming, respond, streaming, termination_reason=None,
): ):
"""Emit provenance for a synthesis answer.""" """Emit provenance for a synthesis answer."""
uri = agent_synthesis_uri(session_id) uri = agent_synthesis_uri(session_id)
@ -586,7 +626,10 @@ class PatternBase:
doc_id = None doc_id = None
triples = set_graph( triples = set_graph(
agent_synthesis_triples(uri, previous_uris, doc_id), agent_synthesis_triples(
uri, previous_uris, doc_id,
termination_reason=termination_reason,
),
GRAPH_RETRIEVAL, GRAPH_RETRIEVAL,
) )
await flow("explainability").send(Triples( await flow("explainability").send(Triples(
@ -594,7 +637,7 @@ class PatternBase:
triples=triples, triples=triples,
)) ))
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", content="", message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples, explain_triples=triples,
)) ))
@ -616,7 +659,7 @@ class PatternBase:
if text: if text:
accumulated.append(text) accumulated.append(text)
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="answer", message_type="answer",
content=text, content=text,
end_of_message=False, end_of_message=False,
end_of_dialog=False, end_of_dialog=False,
@ -666,7 +709,7 @@ class PatternBase:
# Answer wasn't streamed yet — send it as a chunk first # Answer wasn't streamed yet — send it as a chunk first
if answer_text: if answer_text:
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="answer", message_type="answer",
content=answer_text, content=answer_text,
end_of_message=False, end_of_message=False,
end_of_dialog=False, end_of_dialog=False,
@ -675,7 +718,7 @@ class PatternBase:
if streaming: if streaming:
# End-of-dialog marker with usage # End-of-dialog marker with usage
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="answer", message_type="answer",
content="", content="",
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,
@ -684,7 +727,7 @@ class PatternBase:
)) ))
else: else:
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="answer", message_type="answer",
content=answer_text, content=answer_text,
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,

View file

@ -35,7 +35,8 @@ class PlanThenExecutePattern(PatternBase):
Subsequent calls execute the next pending plan step via ReACT. Subsequent calls execute the next pending plan step via ReACT.
""" """
async def iterate(self, request, respond, next, flow, usage=None): async def iterate(self, request, respond, next, flow, usage=None,
pattern_decision_uri=None):
if usage is None: if usage is None:
usage = UsageTracker() usage = UsageTracker()
@ -66,16 +67,18 @@ class PlanThenExecutePattern(PatternBase):
# Determine current phase by checking history for a plan step # Determine current phase by checking history for a plan step
plan = self._extract_plan(request.history) plan = self._extract_plan(request.history)
derive_from_uri = pattern_decision_uri or session_uri
if plan is None: if plan is None:
await self._planning_iteration( await self._planning_iteration(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_uri, session_id, collection, streaming, derive_from_uri,
iteration_num, usage=usage, iteration_num, usage=usage,
) )
else: else:
await self._execution_iteration( await self._execution_iteration(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_uri, session_id, collection, streaming, derive_from_uri,
iteration_num, plan, usage=usage, iteration_num, plan, usage=usage,
) )
@ -385,6 +388,7 @@ class PlanThenExecutePattern(PatternBase):
await self.emit_synthesis_triples( await self.emit_synthesis_triples(
flow, session_id, last_step_uri, flow, session_id, last_step_uri,
response_text, request.user, collection, respond, streaming, response_text, request.user, collection, respond, streaming,
termination_reason="plan-complete",
) )
if self.is_subagent(request): if self.is_subagent(request):

View file

@ -37,7 +37,8 @@ class ReactPattern(PatternBase):
result is appended to history and a next-request is emitted. result is appended to history and a next-request is emitted.
""" """
async def iterate(self, request, respond, next, flow, usage=None): async def iterate(self, request, respond, next, flow, usage=None,
pattern_decision_uri=None):
if usage is None: if usage is None:
usage = UsageTracker() usage = UsageTracker()
@ -108,11 +109,23 @@ class ReactPattern(PatternBase):
session_id, iteration_num, session_id, iteration_num,
) )
# Tool names available to the LLM for this iteration
tool_candidates = [t.name for t in filtered_tools.values()]
# Use pattern decision as derivation source if available
derive_from_uri = pattern_decision_uri or session_uri
# Callback: emit Analysis+ToolUse triples before tool executes # Callback: emit Analysis+ToolUse triples before tool executes
async def on_action(act): async def on_action(act):
await self.emit_iteration_triples( await self.emit_iteration_triples(
flow, session_id, iteration_num, session_uri, flow, session_id, iteration_num, derive_from_uri,
act, request, respond, streaming, act, request, respond, streaming,
tool_candidates=tool_candidates,
step_number=iteration_num,
llm_duration_ms=getattr(act, 'llm_duration_ms', None),
in_token=getattr(act, 'in_token', None),
out_token=getattr(act, 'out_token', None),
model=getattr(act, 'llm_model', None),
) )
act = await temp_agent.react( act = await temp_agent.react(
@ -138,8 +151,9 @@ class ReactPattern(PatternBase):
# Emit final provenance # Emit final provenance
await self.emit_final_triples( await self.emit_final_triples(
flow, session_id, iteration_num, session_uri, flow, session_id, iteration_num, derive_from_uri,
f, request, respond, streaming, f, request, respond, streaming,
termination_reason="final-answer",
) )
if self.is_subagent(request): if self.is_subagent(request):
@ -157,6 +171,8 @@ class ReactPattern(PatternBase):
flow, session_id, iteration_num, flow, session_id, iteration_num,
act.observation, request, respond, act.observation, request, respond,
context=context, context=context,
tool_duration_ms=getattr(act, 'tool_duration_ms', None),
tool_error=getattr(act, 'tool_error', None),
) )
history.append(act) history.append(act)

View file

@ -23,7 +23,7 @@ from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ..orchestrator.pattern_base import UsageTracker from ..orchestrator.pattern_base import UsageTracker, PatternBase
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue from ... schema import librarian_request_queue, librarian_response_queue
@ -537,19 +537,31 @@ class Processor(AgentService):
) )
# Dispatch to the selected pattern # Dispatch to the selected pattern
selected = self.react_pattern
if pattern == "plan-then-execute": if pattern == "plan-then-execute":
await self.plan_pattern.iterate( selected = self.plan_pattern
request, respond, next, flow, usage=usage,
)
elif pattern == "supervisor": elif pattern == "supervisor":
await self.supervisor_pattern.iterate( selected = self.supervisor_pattern
request, respond, next, flow, usage=usage,
) # Emit pattern decision provenance on first iteration
else: pattern_decision_uri = None
# Default to react if not request.history and pattern:
await self.react_pattern.iterate( session_id = getattr(request, 'session_id', '')
request, respond, next, flow, usage=usage, if session_id:
) session_uri = self.provenance_session_uri(session_id)
pattern_decision_uri = \
await selected.emit_pattern_decision_triples(
flow, session_id, session_uri,
pattern, getattr(request, 'task_type', ''),
request.user,
getattr(request, 'collection', 'default'),
respond,
)
await selected.iterate(
request, respond, next, flow, usage=usage,
pattern_decision_uri=pattern_decision_uri,
)
except Exception as e: except Exception as e:
@ -565,7 +577,7 @@ class Processor(AgentService):
) )
r = AgentResponse( r = AgentResponse(
chunk_type="error", message_type="error",
content=str(e), content=str(e),
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,

View file

@ -38,7 +38,8 @@ class SupervisorPattern(PatternBase):
- "synthesise": triggered by aggregator with results in subagent_results - "synthesise": triggered by aggregator with results in subagent_results
""" """
async def iterate(self, request, respond, next, flow, usage=None): async def iterate(self, request, respond, next, flow, usage=None,
pattern_decision_uri=None):
if usage is None: if usage is None:
usage = UsageTracker() usage = UsageTracker()
@ -70,18 +71,20 @@ class SupervisorPattern(PatternBase):
) )
) )
derive_from_uri = pattern_decision_uri or session_uri
if has_results: if has_results:
await self._synthesise( await self._synthesise(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, derive_from_uri, iteration_num,
usage=usage, usage=usage,
) )
else: else:
await self._decompose_and_fanout( await self._decompose_and_fanout(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, derive_from_uri, iteration_num,
usage=usage, usage=usage,
) )
@ -235,6 +238,7 @@ class SupervisorPattern(PatternBase):
await self.emit_synthesis_triples( await self.emit_synthesis_triples(
flow, session_id, finding_uris, flow, session_id, finding_uris,
response_text, request.user, collection, respond, streaming, response_text, request.user, collection, respond, streaming,
termination_reason="subagents-complete",
) )
await self.send_final_response( await self.send_final_response(

View file

@ -3,6 +3,7 @@ import logging
import json import json
import re import re
import asyncio import asyncio
import time
from . types import Action, Final from . types import Action, Final
@ -260,6 +261,7 @@ class AgentManager:
streaming=True, streaming=True,
chunk_callback=on_chunk chunk_callback=on_chunk
) )
self._last_prompt_result = prompt_result
if usage: if usage:
usage.track(prompt_result) usage.track(prompt_result)
@ -269,7 +271,13 @@ class AgentManager:
# Get result # Get result
result = parser.get_result() result = parser.get_result()
if result is None: if result is None:
raise RuntimeError("Parser failed to produce a result") return Action(
thought="",
name="__parse_error__",
arguments={},
observation="",
tool_error="LLM response could not be parsed (streaming)",
)
return result return result
@ -281,6 +289,7 @@ class AgentManager:
variables=variables, variables=variables,
streaming=False streaming=False
) )
self._last_prompt_result = prompt_result
if usage: if usage:
usage.track(prompt_result) usage.track(prompt_result)
response_text = prompt_result.text response_text = prompt_result.text
@ -294,12 +303,19 @@ class AgentManager:
except ValueError as e: except ValueError as e:
logger.error(f"Failed to parse response: {e}") logger.error(f"Failed to parse response: {e}")
logger.error(f"Response was: {response_text}") logger.error(f"Response was: {response_text}")
raise RuntimeError(f"Failed to parse agent response: {e}") return Action(
thought="",
name="__parse_error__",
arguments={},
observation="",
tool_error=f"LLM parse error: {e}",
)
async def react(self, question, history, think, observe, context, async def react(self, question, history, think, observe, context,
streaming=False, answer=None, on_action=None, streaming=False, answer=None, on_action=None,
usage=None): usage=None):
t0 = time.monotonic()
act = await self.reason( act = await self.reason(
question = question, question = question,
history = history, history = history,
@ -310,6 +326,12 @@ class AgentManager:
answer = answer, answer = answer,
usage = usage, usage = usage,
) )
act.llm_duration_ms = int((time.monotonic() - t0) * 1000)
pr = getattr(self, '_last_prompt_result', None)
if pr:
act.in_token = pr.in_token
act.out_token = pr.out_token
act.llm_model = pr.model
if isinstance(act, Final): if isinstance(act, Final):
@ -328,24 +350,43 @@ class AgentManager:
logger.debug(f"ACTION: {act.name}") logger.debug(f"ACTION: {act.name}")
# Notify caller before tool execution (for provenance)
if on_action:
await on_action(act)
# Handle parse errors — skip tool execution
if act.name == "__parse_error__":
resp = f"Error: {act.tool_error}"
act.tool_duration_ms = 0
await observe(resp, is_final=True)
act.observation = resp
return act
if act.name in self.tools: if act.name in self.tools:
action = self.tools[act.name] action = self.tools[act.name]
else: else:
raise RuntimeError(f"No action for {act.name}!") raise RuntimeError(f"No action for {act.name}!")
# Notify caller before tool execution (for provenance) t0 = time.monotonic()
if on_action: try:
await on_action(act) resp = await action.implementation(context).invoke(
**act.arguments
)
resp = await action.implementation(context).invoke( if isinstance(resp, str):
**act.arguments resp = resp.strip()
) else:
resp = str(resp)
resp = resp.strip()
if isinstance(resp, str): act.tool_error = None
resp = resp.strip()
else: except Exception as e:
resp = str(resp) logger.error(f"Tool execution error ({act.name}): {e}")
resp = resp.strip() resp = f"Error: {e}"
act.tool_error = str(e)
act.tool_duration_ms = int((time.monotonic() - t0) * 1000)
await observe(resp, is_final=True) await observe(resp, is_final=True)

View file

@ -469,7 +469,7 @@ class Processor(AgentService):
# Send explain event for session # Send explain event for session
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=session_uri, explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -492,7 +492,7 @@ class Processor(AgentService):
if streaming: if streaming:
r = AgentResponse( r = AgentResponse(
chunk_type="thought", message_type="thought",
content=x, content=x,
end_of_message=is_final, end_of_message=is_final,
end_of_dialog=False, end_of_dialog=False,
@ -500,7 +500,7 @@ class Processor(AgentService):
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="thought", message_type="thought",
content=x, content=x,
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -515,7 +515,7 @@ class Processor(AgentService):
if streaming: if streaming:
r = AgentResponse( r = AgentResponse(
chunk_type="observation", message_type="observation",
content=x, content=x,
end_of_message=is_final, end_of_message=is_final,
end_of_dialog=False, end_of_dialog=False,
@ -523,7 +523,7 @@ class Processor(AgentService):
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="observation", message_type="observation",
content=x, content=x,
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -540,7 +540,7 @@ class Processor(AgentService):
if streaming: if streaming:
r = AgentResponse( r = AgentResponse(
chunk_type="answer", message_type="answer",
content=x, content=x,
end_of_message=False, end_of_message=False,
end_of_dialog=False, end_of_dialog=False,
@ -548,7 +548,7 @@ class Processor(AgentService):
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="answer", message_type="answer",
content=x, content=x,
end_of_message=True, end_of_message=True,
end_of_dialog=False, end_of_dialog=False,
@ -637,7 +637,7 @@ class Processor(AgentService):
logger.debug(f"Emitted iteration triples for {iter_uri}") logger.debug(f"Emitted iteration triples for {iter_uri}")
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=iter_uri, explain_id=iter_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -715,7 +715,7 @@ class Processor(AgentService):
# Send explain event for conclusion # Send explain event for conclusion
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=final_uri, explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -725,7 +725,7 @@ class Processor(AgentService):
if streaming: if streaming:
# End-of-dialog marker — answer chunks already sent via callback # End-of-dialog marker — answer chunks already sent via callback
r = AgentResponse( r = AgentResponse(
chunk_type="answer", message_type="answer",
content="", content="",
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,
@ -733,7 +733,7 @@ class Processor(AgentService):
) )
else: else:
r = AgentResponse( r = AgentResponse(
chunk_type="answer", message_type="answer",
content=f, content=f,
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,
@ -792,7 +792,7 @@ class Processor(AgentService):
# Send explain event for observation # Send explain event for observation
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=observation_entity_uri, explain_id=observation_entity_uri,
explain_graph=GRAPH_RETRIEVAL, explain_graph=GRAPH_RETRIEVAL,
@ -847,7 +847,7 @@ class Processor(AgentService):
streaming = getattr(request, 'streaming', False) if 'request' in locals() else False streaming = getattr(request, 'streaming', False) if 'request' in locals() else False
r = AgentResponse( r = AgentResponse(
chunk_type="error", message_type="error",
content=str(e), content=str(e),
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,

View file

@ -42,7 +42,7 @@ class KnowledgeQueryImpl:
async def explain_callback(explain_id, explain_graph, explain_triples=None): async def explain_callback(explain_id, explain_graph, explain_triples=None):
self.context.last_sub_explain_uri = explain_id self.context.last_sub_explain_uri = explain_id
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="explain", message_type="explain",
content="", content="",
explain_id=explain_id, explain_id=explain_id,
explain_graph=explain_graph, explain_graph=explain_graph,

View file

@ -22,9 +22,19 @@ class Action:
name : str name : str
arguments : dict arguments : dict
observation : str observation : str
llm_duration_ms : int = None
tool_duration_ms : int = None
tool_error : str = None
in_token : int = None
out_token : int = None
llm_model : str = None
@dataclasses.dataclass @dataclasses.dataclass
class Final: class Final:
thought : str thought : str
final : str final : str
llm_duration_ms : int = None
in_token : int = None
out_token : int = None
llm_model : str = None

View file

@ -56,6 +56,8 @@ class Query:
if not concepts: if not concepts:
concepts = [query] concepts = [query]
self.concepts_usage = result
if self.verbose: if self.verbose:
logger.debug(f"Extracted concepts: {concepts}") logger.debug(f"Extracted concepts: {concepts}")
@ -217,8 +219,14 @@ class DocumentRag:
# Emit grounding explainability after concept extraction # Emit grounding explainability after concept extraction
if explain_callback: if explain_callback:
cu = getattr(q, 'concepts_usage', None)
gnd_triples = set_graph( gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts), grounding_triples(
gnd_uri, q_uri, concepts,
in_token=cu.in_token if cu else None,
out_token=cu.out_token if cu else None,
model=cu.model if cu else None,
),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
await explain_callback(gnd_triples, gnd_uri) await explain_callback(gnd_triples, gnd_uri)
@ -286,6 +294,9 @@ class DocumentRag:
docrag_synthesis_triples( docrag_synthesis_triples(
syn_uri, exp_uri, syn_uri, exp_uri,
document_id=synthesis_doc_id, document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,
model=synthesis_result.model if synthesis_result else None,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )

View file

@ -152,6 +152,8 @@ class Query:
if self.verbose: if self.verbose:
logger.debug(f"Extracted concepts: {concepts}") logger.debug(f"Extracted concepts: {concepts}")
self.concepts_usage = result
# Fall back to raw query if extraction returns nothing # Fall back to raw query if extraction returns nothing
return concepts if concepts else [query] return concepts if concepts else [query]
@ -667,8 +669,14 @@ class GraphRag:
# Emit grounding explain after concept extraction # Emit grounding explain after concept extraction
if explain_callback: if explain_callback:
cu = getattr(q, 'concepts_usage', None)
gnd_triples = set_graph( gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts), grounding_triples(
gnd_uri, q_uri, concepts,
in_token=cu.in_token if cu else None,
out_token=cu.out_token if cu else None,
model=cu.model if cu else None,
),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
await explain_callback(gnd_triples, gnd_uri) await explain_callback(gnd_triples, gnd_uri)
@ -883,9 +891,25 @@ class GraphRag:
# Emit focus explain after edge selection completes # Emit focus explain after edge selection completes
if explain_callback: if explain_callback:
# Sum scoring + reasoning token usage for focus event
focus_in = 0
focus_out = 0
focus_model = None
for r in [scoring_result, reasoning_result]:
if r is not None:
if r.in_token is not None:
focus_in += r.in_token
if r.out_token is not None:
focus_out += r.out_token
if r.model is not None:
focus_model = r.model
foc_triples = set_graph( foc_triples = set_graph(
focus_triples( focus_triples(
foc_uri, exp_uri, selected_edges_with_reasoning, session_id foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
in_token=focus_in or None,
out_token=focus_out or None,
model=focus_model,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )
@ -956,6 +980,9 @@ class GraphRag:
synthesis_triples( synthesis_triples(
syn_uri, foc_uri, syn_uri, foc_uri,
document_id=synthesis_doc_id, document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,
model=synthesis_result.model if synthesis_result else None,
), ),
GRAPH_RETRIEVAL GRAPH_RETRIEVAL
) )