diff --git a/dev-tools/tests/agent_dag/analyse_trace.py b/dev-tools/tests/agent_dag/analyse_trace.py index b71cdebe..42cca118 100644 --- a/dev-tools/tests/agent_dag/analyse_trace.py +++ b/dev-tools/tests/agent_dag/analyse_trace.py @@ -131,21 +131,21 @@ async def analyse(path, url, flow, user, collection): for i, msg in enumerate(messages): resp = msg.get("response", {}) - chunk_type = resp.get("chunk_type", "?") + message_type = resp.get("message_type", "?") - if chunk_type == "explain": + if message_type == "explain": explain_id = resp.get("explain_id", "") explain_ids.append(explain_id) - print(f" {i:3d} {chunk_type} {explain_id}") + print(f" {i:3d} {message_type} {explain_id}") else: - print(f" {i:3d} {chunk_type}") + print(f" {i:3d} {message_type}") # Rule 7: message_id on content chunks - if chunk_type in ("thought", "observation", "answer"): + if message_type in ("thought", "observation", "answer"): mid = resp.get("message_id", "") if not mid: errors.append( - f"[msg {i}] {chunk_type} chunk missing message_id" + f"[msg {i}] {message_type} chunk missing message_id" ) print() diff --git a/specs/ontology/trustgraph.ttl b/specs/ontology/trustgraph.ttl new file mode 100644 index 00000000..4c7de612 --- /dev/null +++ b/specs/ontology/trustgraph.ttl @@ -0,0 +1,415 @@ +@prefix tg: <https://trustgraph.ai/ont/trustgraph#> . +@prefix owl: <http://www.w3.org/2002/07/owl#> . +@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> . +@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> . +@prefix xsd: <http://www.w3.org/2001/XMLSchema#> . +@prefix prov: <http://www.w3.org/ns/prov#> . + +# ============================================================================= +# Ontology declaration +# ============================================================================= + +<https://trustgraph.ai/ont/trustgraph> a owl:Ontology ; + rdfs:label "TrustGraph Ontology" ; + rdfs:comment "Vocabulary for TrustGraph provenance, extraction metadata, and explainability." ; + owl:versionInfo "2.3" . 
+ +# ============================================================================= +# Classes — Extraction provenance +# ============================================================================= + +tg:Document a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Document" ; + rdfs:comment "A loaded document (PDF, text, etc.)." . + +tg:Page a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Page" ; + rdfs:comment "A page within a document." . + +tg:Section a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Section" ; + rdfs:comment "A structural section within a document." . + +tg:Chunk a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Chunk" ; + rdfs:comment "A text chunk produced by the chunker." . + +tg:Image a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Image" ; + rdfs:comment "An image extracted from a document." . + +tg:Subgraph a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Subgraph" ; + rdfs:comment "A set of triples extracted from a chunk." . + +# ============================================================================= +# Classes — Query-time explainability (shared) +# ============================================================================= + +tg:Question a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Question" ; + rdfs:comment "Root entity for a query session." . + +tg:GraphRagQuestion a owl:Class ; + rdfs:subClassOf tg:Question ; + rdfs:label "Graph RAG Question" ; + rdfs:comment "A question answered via graph-based RAG." . + +tg:DocRagQuestion a owl:Class ; + rdfs:subClassOf tg:Question ; + rdfs:label "Document RAG Question" ; + rdfs:comment "A question answered via document-based RAG." . + +tg:AgentQuestion a owl:Class ; + rdfs:subClassOf tg:Question ; + rdfs:label "Agent Question" ; + rdfs:comment "A question answered via the agent orchestrator." . 
+ +tg:Grounding a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Grounding" ; + rdfs:comment "Concept extraction step (query decomposition into search terms)." . + +tg:Exploration a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Exploration" ; + rdfs:comment "Entity/chunk retrieval step." . + +tg:Focus a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Focus" ; + rdfs:comment "Edge selection and scoring step (GraphRAG)." . + +tg:Synthesis a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Synthesis" ; + rdfs:comment "Final answer synthesis from retrieved context." . + +# ============================================================================= +# Classes — Agent provenance +# ============================================================================= + +tg:Analysis a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Analysis" ; + rdfs:comment "One agent iteration: reasoning followed by tool selection." . + +tg:ToolUse a owl:Class ; + rdfs:label "ToolUse" ; + rdfs:comment "Mixin type applied to Analysis when a tool is invoked." . + +tg:Error a owl:Class ; + rdfs:label "Error" ; + rdfs:comment "Mixin type applied to events where a failure occurred (tool error, parse error)." . + +tg:Conclusion a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Conclusion" ; + rdfs:comment "Agent final answer (ReAct pattern)." . + +tg:PatternDecision a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Pattern Decision" ; + rdfs:comment "Meta-router decision recording which execution pattern was selected." . + +# --- Unifying types --- + +tg:Answer a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Answer" ; + rdfs:comment "Unifying type for any terminal answer (Synthesis, Conclusion, Finding, StepResult)." . + +tg:Reflection a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Reflection" ; + rdfs:comment "Unifying type for intermediate commentary (Thought, Observation)." . 
+ +tg:Thought a owl:Class ; + rdfs:subClassOf tg:Reflection ; + rdfs:label "Thought" ; + rdfs:comment "Agent reasoning text within an iteration." . + +tg:Observation a owl:Class ; + rdfs:subClassOf tg:Reflection ; + rdfs:label "Observation" ; + rdfs:comment "Tool execution result." . + +# --- Orchestrator types --- + +tg:Decomposition a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Decomposition" ; + rdfs:comment "Supervisor pattern: question decomposed into sub-goals." . + +tg:Finding a owl:Class ; + rdfs:subClassOf tg:Answer ; + rdfs:label "Finding" ; + rdfs:comment "Result from a sub-agent execution." . + +tg:Plan a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Plan" ; + rdfs:comment "Plan-then-execute pattern: structured plan of steps." . + +tg:StepResult a owl:Class ; + rdfs:subClassOf tg:Answer ; + rdfs:label "Step Result" ; + rdfs:comment "Result from executing one plan step." . + +# ============================================================================= +# Properties — Extraction metadata +# ============================================================================= + +tg:contains a owl:ObjectProperty ; + rdfs:label "contains" ; + rdfs:comment "Links a parent entity to a child (e.g. Document contains Page, Subgraph contains triple)." . + +tg:pageCount a owl:DatatypeProperty ; + rdfs:label "page count" ; + rdfs:range xsd:integer ; + rdfs:domain tg:Document . + +tg:mimeType a owl:DatatypeProperty ; + rdfs:label "MIME type" ; + rdfs:range xsd:string ; + rdfs:domain tg:Document . + +tg:pageNumber a owl:DatatypeProperty ; + rdfs:label "page number" ; + rdfs:range xsd:integer ; + rdfs:domain tg:Page . + +tg:chunkIndex a owl:DatatypeProperty ; + rdfs:label "chunk index" ; + rdfs:range xsd:integer ; + rdfs:domain tg:Chunk . + +tg:charOffset a owl:DatatypeProperty ; + rdfs:label "character offset" ; + rdfs:range xsd:integer . + +tg:charLength a owl:DatatypeProperty ; + rdfs:label "character length" ; + rdfs:range xsd:integer . 
+ +tg:chunkSize a owl:DatatypeProperty ; + rdfs:label "chunk size" ; + rdfs:range xsd:integer . + +tg:chunkOverlap a owl:DatatypeProperty ; + rdfs:label "chunk overlap" ; + rdfs:range xsd:integer . + +tg:componentVersion a owl:DatatypeProperty ; + rdfs:label "component version" ; + rdfs:range xsd:string . + +tg:llmModel a owl:DatatypeProperty ; + rdfs:label "LLM model" ; + rdfs:range xsd:string . + +tg:ontology a owl:DatatypeProperty ; + rdfs:label "ontology" ; + rdfs:range xsd:string . + +tg:embeddingModel a owl:DatatypeProperty ; + rdfs:label "embedding model" ; + rdfs:range xsd:string . + +tg:sourceText a owl:DatatypeProperty ; + rdfs:label "source text" ; + rdfs:range xsd:string . + +tg:sourceCharOffset a owl:DatatypeProperty ; + rdfs:label "source character offset" ; + rdfs:range xsd:integer . + +tg:sourceCharLength a owl:DatatypeProperty ; + rdfs:label "source character length" ; + rdfs:range xsd:integer . + +tg:elementTypes a owl:DatatypeProperty ; + rdfs:label "element types" ; + rdfs:range xsd:string . + +tg:tableCount a owl:DatatypeProperty ; + rdfs:label "table count" ; + rdfs:range xsd:integer . + +tg:imageCount a owl:DatatypeProperty ; + rdfs:label "image count" ; + rdfs:range xsd:integer . + +# ============================================================================= +# Properties — Query-time provenance (GraphRAG / DocumentRAG) +# ============================================================================= + +tg:query a owl:DatatypeProperty ; + rdfs:label "query" ; + rdfs:comment "The user's query text." ; + rdfs:range xsd:string ; + rdfs:domain tg:Question . + +tg:concept a owl:DatatypeProperty ; + rdfs:label "concept" ; + rdfs:comment "An extracted concept from the query." ; + rdfs:range xsd:string ; + rdfs:domain tg:Grounding . + +tg:entity a owl:ObjectProperty ; + rdfs:label "entity" ; + rdfs:comment "A seed entity retrieved during exploration." ; + rdfs:domain tg:Exploration . 
+ +tg:edgeCount a owl:DatatypeProperty ; + rdfs:label "edge count" ; + rdfs:comment "Number of edges explored." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Exploration . + +tg:selectedEdge a owl:ObjectProperty ; + rdfs:label "selected edge" ; + rdfs:comment "Link to an edge selection entity within a Focus event." ; + rdfs:domain tg:Focus . + +tg:edge a owl:ObjectProperty ; + rdfs:label "edge" ; + rdfs:comment "A quoted triple representing a knowledge graph edge." ; + rdfs:domain tg:Focus . + +tg:reasoning a owl:DatatypeProperty ; + rdfs:label "reasoning" ; + rdfs:comment "LLM-generated reasoning for an edge selection." ; + rdfs:range xsd:string . + +tg:document a owl:ObjectProperty ; + rdfs:label "document" ; + rdfs:comment "Reference to a document stored in the librarian." . + +tg:chunkCount a owl:DatatypeProperty ; + rdfs:label "chunk count" ; + rdfs:comment "Number of document chunks retrieved (DocumentRAG)." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Exploration . + +tg:selectedChunk a owl:DatatypeProperty ; + rdfs:label "selected chunk" ; + rdfs:comment "A selected chunk ID (DocumentRAG)." ; + rdfs:range xsd:string ; + rdfs:domain tg:Exploration . + +# ============================================================================= +# Properties — Agent provenance +# ============================================================================= + +tg:thought a owl:ObjectProperty ; + rdfs:label "thought" ; + rdfs:comment "Links an Analysis iteration to its Thought sub-entity." ; + rdfs:domain tg:Analysis ; + rdfs:range tg:Thought . + +tg:action a owl:DatatypeProperty ; + rdfs:label "action" ; + rdfs:comment "The tool/action name selected by the agent." ; + rdfs:range xsd:string ; + rdfs:domain tg:Analysis . + +tg:arguments a owl:DatatypeProperty ; + rdfs:label "arguments" ; + rdfs:comment "JSON-encoded arguments passed to the tool." ; + rdfs:range xsd:string ; + rdfs:domain tg:Analysis . 
+ +tg:observation a owl:ObjectProperty ; + rdfs:label "observation" ; + rdfs:comment "Links an Analysis iteration to its Observation sub-entity." ; + rdfs:domain tg:Analysis ; + rdfs:range tg:Observation . + +tg:toolCandidate a owl:DatatypeProperty ; + rdfs:label "tool candidate" ; + rdfs:comment "Name of a tool available to the LLM for this iteration. One triple per candidate." ; + rdfs:range xsd:string ; + rdfs:domain tg:Analysis . + +tg:stepNumber a owl:DatatypeProperty ; + rdfs:label "step number" ; + rdfs:comment "Explicit 1-based step counter for iteration events." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Analysis . + +tg:terminationReason a owl:DatatypeProperty ; + rdfs:label "termination reason" ; + rdfs:comment "Why the agent loop stopped: final-answer, plan-complete, subagents-complete, max-iterations, error." ; + rdfs:range xsd:string . + +tg:pattern a owl:DatatypeProperty ; + rdfs:label "pattern" ; + rdfs:comment "Selected execution pattern (react, plan-then-execute, supervisor)." ; + rdfs:range xsd:string ; + rdfs:domain tg:PatternDecision . + +tg:taskType a owl:DatatypeProperty ; + rdfs:label "task type" ; + rdfs:comment "Identified task type from the meta-router (general, research, etc.)." ; + rdfs:range xsd:string ; + rdfs:domain tg:PatternDecision . + +tg:llmDurationMs a owl:DatatypeProperty ; + rdfs:label "LLM duration (ms)" ; + rdfs:comment "Time spent in the LLM prompt call, in milliseconds." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Analysis . + +tg:toolDurationMs a owl:DatatypeProperty ; + rdfs:label "tool duration (ms)" ; + rdfs:comment "Time spent executing the tool, in milliseconds." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Observation . + +tg:toolError a owl:DatatypeProperty ; + rdfs:label "tool error" ; + rdfs:comment "Error message from a failed tool execution." ; + rdfs:range xsd:string ; + rdfs:domain tg:Observation . 
+ +# --- Token usage predicates (on any event that involves an LLM call) --- + +tg:inToken a owl:DatatypeProperty ; + rdfs:label "input tokens" ; + rdfs:comment "Input token count for the LLM call associated with this event." ; + rdfs:range xsd:integer . + +tg:outToken a owl:DatatypeProperty ; + rdfs:label "output tokens" ; + rdfs:comment "Output token count for the LLM call associated with this event." ; + rdfs:range xsd:integer . + +# --- Orchestrator predicates --- + +tg:subagentGoal a owl:DatatypeProperty ; + rdfs:label "sub-agent goal" ; + rdfs:comment "Goal string assigned to a sub-agent (Decomposition, Finding)." ; + rdfs:range xsd:string . + +tg:planStep a owl:DatatypeProperty ; + rdfs:label "plan step" ; + rdfs:comment "Goal string for a plan step (Plan, StepResult)." ; + rdfs:range xsd:string . + +# ============================================================================= +# Named graphs +# ============================================================================= +# These are not OWL classes but documented here for reference: +# +# (default graph) — Core knowledge facts (extracted triples) +# urn:graph:source — Extraction provenance (document → chunk → triple) +# urn:graph:retrieval — Query-time explainability (question → exploration → synthesis) diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index 15082437..4fdfe83b 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -87,7 +87,7 @@ def sample_message_data(): "history": [] }, "AgentResponse": { - "chunk_type": "answer", + "message_type": "answer", "content": "Machine learning is a subset of AI.", "end_of_message": True, "end_of_dialog": True, diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index bc5bece1..6b7f82e7 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -212,7 +212,7 @@ class TestAgentMessageContracts: # Test required fields response = 
AgentResponse(**response_data) - assert hasattr(response, 'chunk_type') + assert hasattr(response, 'message_type') assert hasattr(response, 'content') assert hasattr(response, 'end_of_message') assert hasattr(response, 'end_of_dialog') diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index 91ce1b77..606061f9 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags: # Arrange translator = TranslatorRegistry.get_response_translator("agent") response = AgentResponse( - chunk_type="answer", + message_type="answer", content="4", end_of_message=True, end_of_dialog=True, @@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags: # Arrange translator = TranslatorRegistry.get_response_translator("agent") response = AgentResponse( - chunk_type="thought", + message_type="thought", content="I need to solve this.", end_of_message=True, end_of_dialog=False, @@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags: # Test thought message thought_response = AgentResponse( - chunk_type="thought", + message_type="thought", content="Processing...", end_of_message=True, end_of_dialog=False, @@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags: # Test observation message observation_response = AgentResponse( - chunk_type="observation", + message_type="observation", content="Result found", end_of_message=True, end_of_dialog=False, @@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags: # Streaming format with end_of_dialog=True response = AgentResponse( - chunk_type="answer", + message_type="answer", content="", end_of_message=True, end_of_dialog=True, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7e18f0de..44a9f127 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -418,55 +418,55 @@ def 
sample_streaming_agent_response(): """Sample streaming agent response chunks""" return [ { - "chunk_type": "thought", + "message_type": "thought", "content": "I need to search", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "thought", + "message_type": "thought", "content": " for information", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "thought", + "message_type": "thought", "content": " about machine learning.", "end_of_message": True, "end_of_dialog": False }, { - "chunk_type": "action", + "message_type": "action", "content": "knowledge_query", "end_of_message": True, "end_of_dialog": False }, { - "chunk_type": "observation", + "message_type": "observation", "content": "Machine learning is", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "observation", + "message_type": "observation", "content": " a subset of AI.", "end_of_message": True, "end_of_dialog": False }, { - "chunk_type": "final-answer", + "message_type": "final-answer", "content": "Machine learning", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "final-answer", + "message_type": "final-answer", "content": " is a subset", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "final-answer", + "message_type": "final-answer", "content": " of artificial intelligence.", "end_of_message": True, "end_of_dialog": True @@ -494,10 +494,10 @@ def streaming_chunk_collector(): """Concatenate all chunk content""" return "".join(self.chunks) - def get_chunk_types(self): + def get_message_types(self): """Get list of chunk types if chunks are dicts""" if self.chunks and isinstance(self.chunks[0], dict): - return [c.get("chunk_type") for c in self.chunks] + return [c.get("message_type") for c in self.chunks] return [] def verify_streaming_protocol(self): diff --git a/tests/integration/test_agent_manager_integration.py b/tests/integration/test_agent_manager_integration.py index a19f4c36..743ab4d2 100644 --- 
a/tests/integration/test_agent_manager_integration.py +++ b/tests/integration/test_agent_manager_integration.py @@ -327,11 +327,13 @@ Args: { think_callback = AsyncMock() observe_callback = AsyncMock() - # Act & Assert - with pytest.raises(Exception) as exc_info: - await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context) - - assert "Tool execution failed" in str(exc_info.value) + # Act - tool errors are now caught and returned as observations + result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context) + + # Assert - error captured on the action, not raised + assert result.tool_error is not None + assert "Tool execution failed" in result.tool_error + assert "Error:" in result.observation @pytest.mark.asyncio async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context): @@ -538,12 +540,11 @@ Args: { ) if test_case["error_contains"]: - # Should raise an error - with pytest.raises(RuntimeError) as exc_info: - await agent_manager.reason("test question", [], mock_flow_context) - - assert "Failed to parse agent response" in str(exc_info.value) - assert test_case["error_contains"] in str(exc_info.value) + # Parse errors now return an Action with tool_error + result = await agent_manager.reason("test question", [], mock_flow_context) + assert isinstance(result, Action) + assert result.name == "__parse_error__" + assert result.tool_error is not None else: # Should succeed action = await agent_manager.reason("test question", [], mock_flow_context) diff --git a/tests/integration/test_agent_streaming_integration.py b/tests/integration/test_agent_streaming_integration.py index 5c82eb8b..de7372f1 100644 --- a/tests/integration/test_agent_streaming_integration.py +++ b/tests/integration/test_agent_streaming_integration.py @@ -15,7 +15,7 @@ from tests.utils.streaming_assertions import ( assert_agent_streaming_chunks, assert_streaming_chunks_valid, 
assert_callback_invoked, - assert_chunk_types_valid, + assert_message_types_valid, ) diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py index 0b9b283a..bb58e5ee 100644 --- a/tests/unit/test_agent/test_agent_service_non_streaming.py +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming: # Filter out explain events — those are always sent now content_responses = [ - r for r in sent_responses if r.chunk_type != "explain" + r for r in sent_responses if r.message_type != "explain" ] explain_responses = [ - r for r in sent_responses if r.chunk_type == "explain" + r for r in sent_responses if r.message_type == "explain" ] # Should have explain events for session, iteration, observation, and final @@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming: # Check thought message thought_response = content_responses[0] assert isinstance(thought_response, AgentResponse) - assert thought_response.chunk_type == "thought" + assert thought_response.message_type == "thought" assert thought_response.content == "I need to solve this." assert thought_response.end_of_message is True, "Thought message must have end_of_message=True" assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False" @@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming: # Check observation message observation_response = content_responses[1] assert isinstance(observation_response, AgentResponse) - assert observation_response.chunk_type == "observation" + assert observation_response.message_type == "observation" assert observation_response.content == "The answer is 4." 
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True" assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False" @@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming: # Filter out explain events — those are always sent now content_responses = [ - r for r in sent_responses if r.chunk_type != "explain" + r for r in sent_responses if r.message_type != "explain" ] explain_responses = [ - r for r in sent_responses if r.chunk_type == "explain" + r for r in sent_responses if r.message_type == "explain" ] # Should have explain events for session and final @@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming: # Check final answer message answer_response = content_responses[0] assert isinstance(answer_response, AgentResponse) - assert answer_response.chunk_type == "answer" + assert answer_response.message_type == "answer" assert answer_response.content == "4" assert answer_response.end_of_message is True, "Final answer must have end_of_message=True" assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True" diff --git a/tests/unit/test_agent/test_callback_message_id.py b/tests/unit/test_agent/test_callback_message_id.py index 7cb0ee54..2c4964a5 100644 --- a/tests/unit/test_agent/test_callback_message_id.py +++ b/tests/unit/test_agent/test_callback_message_id.py @@ -29,7 +29,7 @@ class TestThinkCallbackMessageId: assert len(responses) == 1 assert responses[0].message_id == msg_id - assert responses[0].chunk_type == "thought" + assert responses[0].message_type == "thought" @pytest.mark.asyncio async def test_non_streaming_think_has_message_id(self, pattern): @@ -58,7 +58,7 @@ class TestObserveCallbackMessageId: await observe("result", is_final=True) assert responses[0].message_id == msg_id - assert responses[0].chunk_type == "observation" + assert responses[0].message_type == "observation" class TestAnswerCallbackMessageId: @@ -74,7 +74,7 @@ class 
TestAnswerCallbackMessageId: await answer("the answer") assert responses[0].message_id == msg_id - assert responses[0].chunk_type == "answer" + assert responses[0].message_type == "answer" @pytest.mark.asyncio async def test_no_message_id_default(self, pattern): diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py index 05741cdc..63d87ba1 100644 --- a/tests/unit/test_agent/test_orchestrator_provenance_integration.py +++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py @@ -69,7 +69,7 @@ def collect_explain_events(respond_mock): events = [] for call in respond_mock.call_args_list: resp = call[0][0] - if isinstance(resp, AgentResponse) and resp.chunk_type == "explain": + if isinstance(resp, AgentResponse) and resp.message_type == "explain": events.append({ "explain_id": resp.explain_id, "explain_graph": resp.explain_graph, diff --git a/tests/unit/test_agent/test_parse_chunk_message_id.py b/tests/unit/test_agent/test_parse_chunk_message_id.py index 38942f1e..36d2220e 100644 --- a/tests/unit/test_agent/test_parse_chunk_message_id.py +++ b/tests/unit/test_agent/test_parse_chunk_message_id.py @@ -20,7 +20,7 @@ class TestParseChunkMessageId: def test_thought_message_id(self, client): resp = { - "chunk_type": "thought", + "message_type": "thought", "content": "thinking...", "end_of_message": False, "message_id": "urn:trustgraph:agent:sess/i1/thought", @@ -31,7 +31,7 @@ class TestParseChunkMessageId: def test_observation_message_id(self, client): resp = { - "chunk_type": "observation", + "message_type": "observation", "content": "result", "end_of_message": True, "message_id": "urn:trustgraph:agent:sess/i1/observation", @@ -42,7 +42,7 @@ class TestParseChunkMessageId: def test_answer_message_id(self, client): resp = { - "chunk_type": "answer", + "message_type": "answer", "content": "the answer", "end_of_message": False, "end_of_dialog": False, @@ -54,7 +54,7 @@ 
class TestParseChunkMessageId: def test_thought_missing_message_id(self, client): resp = { - "chunk_type": "thought", + "message_type": "thought", "content": "thinking...", "end_of_message": False, } @@ -64,7 +64,7 @@ class TestParseChunkMessageId: def test_answer_missing_message_id(self, client): resp = { - "chunk_type": "answer", + "message_type": "answer", "content": "answer", "end_of_message": True, "end_of_dialog": True, diff --git a/tests/unit/test_gateway/test_explain_triples.py b/tests/unit/test_gateway/test_explain_triples.py index 24e77410..42a2f4c5 100644 --- a/tests/unit/test_gateway/test_explain_triples.py +++ b/tests/unit/test_gateway/test_explain_triples.py @@ -158,7 +158,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id="urn:trustgraph:agent:session:abc123", explain_graph="urn:graph:retrieval", @@ -179,7 +179,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="thought", + message_type="thought", content="I need to think...", ) @@ -190,7 +190,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="explain", + message_type="explain", explain_id="urn:trustgraph:agent:session:abc123", explain_triples=sample_triples(), end_of_dialog=False, @@ -203,7 +203,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="answer", + message_type="answer", content="The answer is...", end_of_dialog=True, ) diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py new file mode 100644 index 00000000..184560f0 --- /dev/null +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -0,0 +1,592 @@ +""" +DAG structure tests for provenance chains. 
+ +Verifies that the wasDerivedFrom chain has the expected shape for each +service. These tests catch structural regressions when new entities are +inserted into the chain (e.g. PatternDecision between session and first +iteration). + +Expected chains: + + GraphRAG: question → grounding → exploration → focus → synthesis + DocumentRAG: question → grounding → exploration → synthesis + Agent React: session → pattern-decision → iteration → (observation → iteration)* → final + Agent Plan: session → pattern-decision → plan → step-result(s) → synthesis + Agent Super: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis +""" + +import json +import uuid +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.schema import ( + AgentRequest, AgentResponse, AgentStep, PlanStep, + Triple, Term, IRI, LITERAL, +) +from trustgraph.base import PromptResult + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL, + TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, + TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION, + TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION, + TG_OBSERVATION_TYPE, + TG_PATTERN, TG_TASK_TYPE, +) + +from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.retrieval.document_rag.document_rag import DocumentRag + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _collect_events(events): + """Build a dict of explain_id → {types, derived_from, triples}.""" + result = {} + for ev in events: + eid = ev["explain_id"] + triples = ev["triples"] + types = { + t.o.iri for t in triples + if t.s.iri == eid and t.p.iri == RDF_TYPE + } + parents = [ + t.o.iri for t in triples + if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM + ] + result[eid] = { + 
"types": types, + "derived_from": parents[0] if parents else None, + "triples": triples, + } + return result + + +def _find_by_type(dag, rdf_type): + """Find all event IDs that have the given rdf:type.""" + return [eid for eid, info in dag.items() if rdf_type in info["types"]] + + +def _assert_chain(dag, chain_types): + """Assert that a linear wasDerivedFrom chain exists through the given types.""" + for i in range(1, len(chain_types)): + parent_type = chain_types[i - 1] + child_type = chain_types[i] + parents = _find_by_type(dag, parent_type) + children = _find_by_type(dag, child_type) + assert parents, f"No entity with type {parent_type}" + assert children, f"No entity with type {child_type}" + # At least one child must derive from at least one parent + linked = False + for child_id in children: + derived = dag[child_id]["derived_from"] + if derived in parents: + linked = True + break + assert linked, ( + f"No {child_type} derives from {parent_type}. " + f"Children derive from: " + f"{[dag[c]['derived_from'] for c in children]}" + ) + + +# --------------------------------------------------------------------------- +# GraphRAG DAG structure +# --------------------------------------------------------------------------- + +class TestGraphRagDagStructure: + """Verify: question → grounding → exploration → focus → synthesis""" + + @pytest.fixture + def mock_clients(self): + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + graph_embeddings_client = AsyncMock() + triples_client = AsyncMock() + + embeddings_client.embed.return_value = [[0.1, 0.2]] + graph_embeddings_client.query.return_value = [ + MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")), + ] + triples_client.query_stream.return_value = [ + Triple( + s=Term(type=IRI, iri="http://example.com/e1"), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="value"), + ) + ] + triples_client.query.return_value = [] + + async def mock_prompt(template_id, variables=None, 
**kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="concept") + elif template_id == "kg-edge-scoring": + edges = variables.get("knowledge", []) + return PromptResult( + response_type="jsonl", + objects=[{"id": e["id"], "score": 10} for e in edges], + ) + elif template_id == "kg-edge-reasoning": + edges = variables.get("knowledge", []) + return PromptResult( + response_type="jsonl", + objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges], + ) + elif template_id == "kg-synthesis": + return PromptResult(response_type="text", text="Answer.") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + return prompt_client, embeddings_client, graph_embeddings_client, triples_client + + @pytest.mark.asyncio + async def test_dag_chain(self, mock_clients): + rag = GraphRag(*mock_clients) + events = [] + + async def explain_cb(triples, explain_id): + events.append({"explain_id": explain_id, "triples": triples}) + + await rag.query( + query="test", explain_callback=explain_cb, edge_score_limit=0, + ) + + dag = _collect_events(events) + assert len(dag) == 5, f"Expected 5 events, got {len(dag)}" + + _assert_chain(dag, [ + TG_GRAPH_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_FOCUS, + TG_SYNTHESIS, + ]) + + +# --------------------------------------------------------------------------- +# DocumentRAG DAG structure +# --------------------------------------------------------------------------- + +class TestDocumentRagDagStructure: + """Verify: question → grounding → exploration → synthesis""" + + @pytest.fixture + def mock_clients(self): + from trustgraph.schema import ChunkMatch + + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock(return_value="Chunk content.") + + embeddings_client.embed.return_value = [[0.1, 0.2]] + doc_embeddings_client.query.return_value = [ + 
ChunkMatch(chunk_id="doc/c1", score=0.9), + ] + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="concept") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + prompt_client.document_prompt.return_value = PromptResult( + response_type="text", text="Answer.", + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + @pytest.mark.asyncio + async def test_dag_chain(self, mock_clients): + rag = DocumentRag(*mock_clients) + events = [] + + async def explain_cb(triples, explain_id): + events.append({"explain_id": explain_id, "triples": triples}) + + await rag.query( + query="test", explain_callback=explain_cb, + ) + + dag = _collect_events(events) + assert len(dag) == 4, f"Expected 4 events, got {len(dag)}" + + _assert_chain(dag, [ + TG_DOC_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_SYNTHESIS, + ]) + + +# --------------------------------------------------------------------------- +# Agent DAG structure — tested via service.agent_request() +# --------------------------------------------------------------------------- + +def _make_processor(tools=None): + processor = MagicMock() + processor.max_iterations = 10 + processor.save_answer_content = AsyncMock() + + def mock_session_uri(sid): + return f"urn:trustgraph:agent:session:{sid}" + processor.provenance_session_uri.side_effect = mock_session_uri + + agent = MagicMock() + agent.tools = tools or {} + agent.additional_context = "" + processor.agent = agent + processor.aggregator = MagicMock() + + return processor + + +def _make_flow(): + producers = {} + + def factory(name): + if name not in producers: + producers[name] = AsyncMock() + return producers[name] + + flow = MagicMock(side_effect=factory) + return flow + + +def _collect_agent_events(respond_mock): + events = [] + for call in respond_mock.call_args_list: + resp = 
call[0][0] + if isinstance(resp, AgentResponse) and resp.message_type == "explain": + events.append({ + "explain_id": resp.explain_id, + "triples": resp.explain_triples, + }) + return events + + +class TestAgentReactDagStructure: + """ + Via service.agent_request(), full two-iteration react chain: + session → pattern-decision → iteration(1) → observation(1) → final + + Iteration 1: tool call → observation + Iteration 2: final answer + """ + + def _make_service(self): + from trustgraph.agent.orchestrator.service import Processor + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + mock_tool = MagicMock() + mock_tool.name = "lookup" + mock_tool.description = "Look things up" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="42") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = _make_processor(tools={"lookup": mock_tool}) + + service = Processor.__new__(Processor) + service.max_iterations = 10 + service.save_answer_content = AsyncMock() + service.provenance_session_uri = processor.provenance_session_uri + service.agent = processor.agent + service.aggregator = processor.aggregator + + service.react_pattern = ReactPattern(service) + service.plan_pattern = PlanThenExecutePattern(service) + service.supervisor_pattern = SupervisorPattern(service) + service.meta_router = None + + return service + + @pytest.mark.asyncio + async def test_dag_chain(self): + from trustgraph.agent.react.types import Action, Final + + service = self._make_service() + + respond = AsyncMock() + next_fn = AsyncMock() + flow = _make_flow() + session_id = str(uuid.uuid4()) + + # Iteration 1: tool call → returns Action, triggers on_action + tool exec + action = Action( + thought="I need to look this up", + name="lookup", 
+ arguments={"question": "6x7"}, + observation="", + ) + + with patch( + "trustgraph.agent.orchestrator.react_pattern.AgentManager" + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + async def mock_react_iter1(on_action=None, **kwargs): + if on_action: + await on_action(action) + action.observation = "42" + return action + + mock_am.react.side_effect = mock_react_iter1 + + request1 = AgentRequest( + question="What is 6x7?", + user="testuser", + collection="default", + streaming=False, + session_id=session_id, + pattern="react", + history=[], + ) + + await service.agent_request(request1, respond, next_fn, flow) + + # next_fn should have been called with updated history + assert next_fn.called + + # Iteration 2: final answer + final = Final(thought="The answer is 42", final="42") + next_request = next_fn.call_args[0][0] + + with patch( + "trustgraph.agent.orchestrator.react_pattern.AgentManager" + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + async def mock_react_iter2(**kwargs): + return final + + mock_am.react.side_effect = mock_react_iter2 + + await service.agent_request(next_request, respond, next_fn, flow) + + # Collect and verify DAG + events = _collect_agent_events(respond) + dag = _collect_events(events) + + session_ids = _find_by_type(dag, TG_AGENT_QUESTION) + pd_ids = _find_by_type(dag, TG_PATTERN_DECISION) + analysis_ids = _find_by_type(dag, TG_ANALYSIS) + observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE) + final_ids = _find_by_type(dag, TG_CONCLUSION) + + assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}" + assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}" + assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}" + assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}" + assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}" + + # Full chain: + # session → pattern-decision + 
assert dag[pd_ids[0]]["derived_from"] == session_ids[0] + + # pattern-decision → iteration(1) + assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0] + + # iteration(1) → observation(1) + assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0] + + # observation(1) → final + assert dag[final_ids[0]]["derived_from"] == observation_ids[0] + + +class TestAgentPlanDagStructure: + """ + Via service.agent_request(): + session → pattern-decision → plan → step-result → synthesis + """ + + @pytest.mark.asyncio + async def test_dag_chain(self): + from trustgraph.agent.orchestrator.service import Processor + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + # Mock tool + mock_tool = MagicMock() + mock_tool.name = "knowledge-query" + mock_tool.description = "Query KB" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="Found it") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = _make_processor(tools={"knowledge-query": mock_tool}) + + service = Processor.__new__(Processor) + service.max_iterations = 10 + service.save_answer_content = AsyncMock() + service.provenance_session_uri = processor.provenance_session_uri + service.agent = processor.agent + service.aggregator = processor.aggregator + + service.react_pattern = ReactPattern(service) + service.plan_pattern = PlanThenExecutePattern(service) + service.supervisor_pattern = SupervisorPattern(service) + service.meta_router = None + + respond = AsyncMock() + next_fn = AsyncMock() + flow = _make_flow() + + # Mock prompt client + mock_prompt_client = AsyncMock() + + call_count = 0 + + async def mock_prompt(id, variables=None, **kwargs): + nonlocal call_count + call_count += 1 + if id == "plan-create": + return PromptResult( + 
response_type="jsonl", + objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}], + ) + elif id == "plan-step-execute": + return PromptResult( + response_type="json", + object={"tool": "knowledge-query", "arguments": {"question": "test"}}, + ) + elif id == "plan-synthesise": + return PromptResult(response_type="text", text="Final answer.") + return PromptResult(response_type="text", text="") + + mock_prompt_client.prompt.side_effect = mock_prompt + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + session_id = str(uuid.uuid4()) + + # Iteration 1: planning + request1 = AgentRequest( + question="Test?", + user="testuser", + collection="default", + streaming=False, + session_id=session_id, + pattern="plan-then-execute", + history=[], + ) + await service.agent_request(request1, respond, next_fn, flow) + + # Iteration 2: execute step (next_fn was called with updated request) + assert next_fn.called + next_request = next_fn.call_args[0][0] + + # Iteration 3: all steps done → synthesis + # Simulate completed step in history + next_request.history[-1].plan[0].status = "completed" + next_request.history[-1].plan[0].result = "Found it" + + await service.agent_request(next_request, respond, next_fn, flow) + + events = _collect_agent_events(respond) + dag = _collect_events(events) + + session_ids = _find_by_type(dag, TG_AGENT_QUESTION) + pd_ids = _find_by_type(dag, TG_PATTERN_DECISION) + plan_ids = _find_by_type(dag, TG_PLAN_TYPE) + synthesis_ids = _find_by_type(dag, TG_SYNTHESIS) + + assert len(session_ids) == 1 + assert len(pd_ids) == 1 + assert len(plan_ids) == 1 + assert len(synthesis_ids) == 1 + + # Chain: session → pattern-decision → plan → ... 
→ synthesis + assert dag[pd_ids[0]]["derived_from"] == session_ids[0] + assert dag[plan_ids[0]]["derived_from"] == pd_ids[0] + + +class TestAgentSupervisorDagStructure: + """ + Via service.agent_request(): + session → pattern-decision → decomposition → (fan-out) + """ + + @pytest.mark.asyncio + async def test_dag_chain(self): + from trustgraph.agent.orchestrator.service import Processor + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + processor = _make_processor() + + service = Processor.__new__(Processor) + service.max_iterations = 10 + service.save_answer_content = AsyncMock() + service.provenance_session_uri = processor.provenance_session_uri + service.agent = processor.agent + service.aggregator = processor.aggregator + + service.react_pattern = ReactPattern(service) + service.plan_pattern = PlanThenExecutePattern(service) + service.supervisor_pattern = SupervisorPattern(service) + service.meta_router = None + + respond = AsyncMock() + next_fn = AsyncMock() + flow = _make_flow() + + mock_prompt_client = AsyncMock() + mock_prompt_client.prompt.return_value = PromptResult( + response_type="jsonl", + objects=["Goal A", "Goal B"], + ) + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + request = AgentRequest( + question="Research quantum computing", + user="testuser", + collection="default", + streaming=False, + session_id=str(uuid.uuid4()), + pattern="supervisor", + history=[], + ) + + await service.agent_request(request, respond, next_fn, flow) + + events = _collect_agent_events(respond) + dag = _collect_events(events) + + session_ids = _find_by_type(dag, TG_AGENT_QUESTION) + pd_ids = _find_by_type(dag, TG_PATTERN_DECISION) + decomp_ids = _find_by_type(dag, TG_DECOMPOSITION) + + assert 
len(session_ids) == 1 + assert len(pd_ids) == 1 + assert len(decomp_ids) == 1 + + # Chain: session → pattern-decision → decomposition + assert dag[pd_ids[0]]["derived_from"] == session_ids[0] + assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0] + + # Fan-out should have been called + assert next_fn.call_count == 2 # One per goal diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py index 792db028..f906a00d 100644 --- a/tests/unit/test_provenance/test_triples.py +++ b/tests/unit/test_provenance/test_triples.py @@ -223,7 +223,7 @@ class TestDerivedEntityTriples: assert has_type(triples, self.ENTITY_URI, PROV_ENTITY) assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE) - def test_chunk_entity_has_chunk_type(self): + def test_chunk_entity_has_chunk_type(self): triples = derived_entity_triples( self.ENTITY_URI, self.PARENT_URI, "chunker", "1.0", diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py index 80443a0c..0b6709fb 100644 --- a/tests/unit/test_python_api_client.py +++ b/tests/unit/test_python_api_client.py @@ -304,14 +304,14 @@ class TestStreamingTypes: assert chunk.content == "thinking..." assert chunk.end_of_message is False - assert chunk.chunk_type == "thought" + assert chunk.message_type == "thought" def test_agent_observation_creation(self): """Test creating AgentObservation chunk""" chunk = AgentObservation(content="observing...", end_of_message=False) assert chunk.content == "observing..." 
- assert chunk.chunk_type == "observation" + assert chunk.message_type == "observation" def test_agent_answer_creation(self): """Test creating AgentAnswer chunk""" @@ -324,7 +324,7 @@ class TestStreamingTypes: assert chunk.content == "answer" assert chunk.end_of_message is True assert chunk.end_of_dialog is True - assert chunk.chunk_type == "final-answer" + assert chunk.message_type == "final-answer" def test_rag_chunk_creation(self): """Test creating RAGChunk""" diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 985bcbf1..c8c676c9 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -9,7 +9,7 @@ from .streaming_assertions import ( assert_streaming_content_matches, assert_no_empty_chunks, assert_streaming_error_handled, - assert_chunk_types_valid, + assert_message_types_valid, assert_streaming_latency_acceptable, assert_callback_invoked, ) @@ -23,7 +23,7 @@ __all__ = [ "assert_streaming_content_matches", "assert_no_empty_chunks", "assert_streaming_error_handled", - "assert_chunk_types_valid", + "assert_message_types_valid", "assert_streaming_latency_acceptable", "assert_callback_invoked", ] diff --git a/tests/utils/streaming_assertions.py b/tests/utils/streaming_assertions.py index cc9164ed..945bb031 100644 --- a/tests/utils/streaming_assertions.py +++ b/tests/utils/streaming_assertions.py @@ -20,14 +20,14 @@ def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1): assert all(chunk is not None for chunk in chunks), "All chunks should be non-None" -def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"): +def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "message_type"): """ Assert that streaming chunks follow an expected sequence. 
Args: chunks: List of chunk dictionaries expected_sequence: Expected sequence of chunk types/values - key: Dictionary key to check (default: "chunk_type") + key: Dictionary key to check (default: "message_type") """ actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk] assert actual_sequence == expected_sequence, \ @@ -39,7 +39,7 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]): Assert that agent streaming chunks have valid structure. Validates: - - All chunks have chunk_type field + - All chunks have message_type field - All chunks have content field - All chunks have end_of_message field - All chunks have end_of_dialog field @@ -51,15 +51,15 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]): assert len(chunks) > 0, "Expected at least one chunk" for i, chunk in enumerate(chunks): - assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type" + assert "message_type" in chunk, f"Chunk {i} missing message_type" assert "content" in chunk, f"Chunk {i} missing content" assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message" assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog" - # Validate chunk_type values + # Validate message_type values valid_types = ["thought", "action", "observation", "final-answer"] - assert chunk["chunk_type"] in valid_types, \ - f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}" + assert chunk["message_type"] in valid_types, \ + f"Invalid message_type '{chunk['message_type']}' at index {i}" # Last chunk should signal end of dialog assert chunks[-1]["end_of_dialog"] is True, \ @@ -175,7 +175,7 @@ def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str "Error chunk should have completion flag set to True" -def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"): +def assert_message_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "message_type"): 
""" Assert that all chunk types are from a valid set. @@ -185,9 +185,9 @@ def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str type_key: Dictionary key for chunk type """ for i, chunk in enumerate(chunks): - chunk_type = chunk.get(type_key) - assert chunk_type in valid_types, \ - f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}" + message_type = chunk.get(type_key) + assert message_type in valid_types, \ + f"Chunk {i} has invalid type '{message_type}', expected one of {valid_types}" def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0): diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index e1007556..6e5064ab 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -178,24 +178,23 @@ class AsyncSocketClient: def _parse_chunk(self, resp: Dict[str, Any]): """Parse response chunk into appropriate type. 
Returns None for non-content messages.""" - chunk_type = resp.get("chunk_type") message_type = resp.get("message_type") # Handle new GraphRAG message format with message_type if message_type == "provenance": return None - if chunk_type == "thought": + if message_type == "thought": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) - elif chunk_type == "observation": + elif message_type == "observation": return AgentObservation( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) - elif chunk_type == "answer" or chunk_type == "final-answer": + elif message_type == "answer" or message_type == "final-answer": return AgentAnswer( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), @@ -204,7 +203,7 @@ class AsyncSocketClient: out_token=resp.get("out_token"), model=resp.get("model"), ) - elif chunk_type == "action": + elif message_type == "action": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index fc238e36..c590c9b4 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -360,34 +360,26 @@ class SocketClient: def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: """Parse response chunk into appropriate type. 
Returns None for non-content messages.""" - chunk_type = resp.get("chunk_type") message_type = resp.get("message_type") - # Handle GraphRAG/DocRAG message format with message_type if message_type == "explain": if include_provenance: return self._build_provenance_event(resp) return None - # Handle Agent message format with chunk_type="explain" - if chunk_type == "explain": - if include_provenance: - return self._build_provenance_event(resp) - return None - - if chunk_type == "thought": + if message_type == "thought": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), message_id=resp.get("message_id", ""), ) - elif chunk_type == "observation": + elif message_type == "observation": return AgentObservation( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), message_id=resp.get("message_id", ""), ) - elif chunk_type == "answer" or chunk_type == "final-answer": + elif message_type == "answer" or message_type == "final-answer": return AgentAnswer( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), @@ -397,7 +389,7 @@ class SocketClient: out_token=resp.get("out_token"), model=resp.get("model"), ) - elif chunk_type == "action": + elif message_type == "action": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 7b79c962..f5987b0e 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -149,10 +149,10 @@ class AgentThought(StreamingChunk): Attributes: content: Agent's thought text end_of_message: True if this completes the current thought - chunk_type: Always "thought" + message_type: Always "thought" message_id: Provenance URI of the entity being built """ - chunk_type: str = "thought" + message_type: str = "thought" message_id: str = "" @dataclasses.dataclass @@ 
-166,10 +166,10 @@ class AgentObservation(StreamingChunk): Attributes: content: Observation text describing tool results end_of_message: True if this completes the current observation - chunk_type: Always "observation" + message_type: Always "observation" message_id: Provenance URI of the entity being built """ - chunk_type: str = "observation" + message_type: str = "observation" message_id: str = "" @dataclasses.dataclass @@ -184,9 +184,9 @@ class AgentAnswer(StreamingChunk): content: Answer text end_of_message: True if this completes the current answer segment end_of_dialog: True if this completes the entire agent interaction - chunk_type: Always "final-answer" + message_type: Always "final-answer" """ - chunk_type: str = "final-answer" + message_type: str = "final-answer" end_of_dialog: bool = False message_id: str = "" in_token: Optional[int] = None @@ -208,9 +208,9 @@ class RAGChunk(StreamingChunk): in_token: Input token count (populated on the final chunk, 0 otherwise) out_token: Output token count (populated on the final chunk, 0 otherwise) model: Model identifier (populated on the final chunk, empty otherwise) - chunk_type: Always "rag" + message_type: Always "rag" """ - chunk_type: str = "rag" + message_type: str = "rag" end_of_stream: bool = False error: Optional[Dict[str, str]] = None in_token: Optional[int] = None diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py index d73d03b9..393864fa 100644 --- a/trustgraph-base/trustgraph/base/agent_client.py +++ b/trustgraph-base/trustgraph/base/agent_client.py @@ -30,19 +30,19 @@ class AgentClient(RequestResponse): raise RuntimeError(resp.error.message) # Handle thought chunks - if resp.chunk_type == 'thought': + if resp.message_type == 'thought': if think: await think(resp.content, resp.end_of_message) return False # Continue receiving # Handle observation chunks - if resp.chunk_type == 'observation': + if resp.message_type == 'observation': if 
observe: await observe(resp.content, resp.end_of_message) return False # Continue receiving # Handle answer chunks - if resp.chunk_type == 'answer': + if resp.message_type == 'answer': if resp.content: accumulated_answer.append(resp.content) if answer_callback: diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py index 1cadbdd5..d17ea37a 100644 --- a/trustgraph-base/trustgraph/clients/agent_client.py +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -58,23 +58,23 @@ class AgentClient(BaseClient): def inspect(x): # Handle errors - if x.chunk_type == 'error' or x.error: + if x.message_type == 'error' or x.error: if error_callback: error_callback(x.content or (x.error.message if x.error else "")) # Continue to check end_of_dialog # Handle thought chunks - elif x.chunk_type == 'thought': + elif x.message_type == 'thought': if think: think(x.content, x.end_of_message) # Handle observation chunks - elif x.chunk_type == 'observation': + elif x.message_type == 'observation': if observe: observe(x.content, x.end_of_message) # Handle answer chunks - elif x.chunk_type == 'answer': + elif x.message_type == 'answer': if x.content: accumulated_answer.append(x.content) if answer_callback: diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index b255ea2c..7df59907 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -60,8 +60,8 @@ class AgentResponseTranslator(MessageTranslator): def encode(self, obj: AgentResponse) -> Dict[str, Any]: result = {} - if obj.chunk_type: - result["chunk_type"] = obj.chunk_type + if obj.message_type: + result["message_type"] = obj.message_type if obj.content: result["content"] = obj.content result["end_of_message"] = getattr(obj, "end_of_message", False) diff --git a/trustgraph-base/trustgraph/provenance/__init__.py 
b/trustgraph-base/trustgraph/provenance/__init__.py index e6ce0a9e..051efc66 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -59,6 +59,7 @@ from . uris import ( agent_plan_uri, agent_step_result_uri, agent_synthesis_uri, + agent_pattern_decision_uri, # Document RAG provenance URIs docrag_question_uri, docrag_grounding_uri, @@ -102,6 +103,11 @@ from . namespaces import ( # Agent provenance predicates TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_TOOL_CANDIDATE, TG_TERMINATION_REASON, + TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE, + TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR, + TG_IN_TOKEN, TG_OUT_TOKEN, + TG_ERROR_TYPE, # Orchestrator entity types TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, # Document reference predicate @@ -141,6 +147,7 @@ from . agent import ( agent_plan_triples, agent_step_result_triples, agent_synthesis_triples, + agent_pattern_decision_triples, ) # Vocabulary bootstrap @@ -182,6 +189,7 @@ __all__ = [ "agent_plan_uri", "agent_step_result_uri", "agent_synthesis_uri", + "agent_pattern_decision_uri", # Document RAG provenance URIs "docrag_question_uri", "docrag_grounding_uri", @@ -218,6 +226,11 @@ __all__ = [ # Agent provenance predicates "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_SUBAGENT_GOAL", "TG_PLAN_STEP", + "TG_TOOL_CANDIDATE", "TG_TERMINATION_REASON", + "TG_STEP_NUMBER", "TG_PATTERN_DECISION", "TG_PATTERN", "TG_TASK_TYPE", + "TG_LLM_DURATION_MS", "TG_TOOL_DURATION_MS", "TG_TOOL_ERROR", + "TG_IN_TOKEN", "TG_OUT_TOKEN", + "TG_ERROR_TYPE", # Orchestrator entity types "TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT", # Document reference predicate @@ -249,6 +262,7 @@ __all__ = [ "agent_plan_triples", "agent_step_result_triples", "agent_synthesis_triples", + "agent_pattern_decision_triples", # Utility "set_graph", # Vocabulary diff --git 
a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index 7203174e..5c4f0b2e 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -29,6 +29,11 @@ from . namespaces import ( TG_AGENT_QUESTION, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_TOOL_CANDIDATE, TG_TERMINATION_REASON, + TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE, + TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR, + TG_ERROR_TYPE, + TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL, ) @@ -47,6 +52,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple: return Triple(s=_iri(s), p=_iri(p), o=o_term) +def _append_token_triples(triples, uri, in_token=None, out_token=None, + model=None): + """Append in_token/out_token/model triples when values are present.""" + if in_token is not None: + triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token)))) + if out_token is not None: + triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token)))) + if model is not None: + triples.append(_triple(uri, TG_LLM_MODEL, _literal(model))) + + def agent_session_triples( session_uri: str, query: str, @@ -90,6 +106,43 @@ def agent_session_triples( return triples +def agent_pattern_decision_triples( + uri: str, + session_uri: str, + pattern: str, + task_type: str = "", +) -> List[Triple]: + """ + Build triples for a meta-router pattern decision. + + Creates: + - Entity declaration with tg:PatternDecision type + - wasDerivedFrom link to session + - Pattern and task type predicates + + Args: + uri: URI of this decision (from agent_pattern_decision_uri) + session_uri: URI of the parent session + pattern: Selected execution pattern (e.g. "react", "plan-then-execute") + task_type: Identified task type (e.g. 
"general", "research") + + Returns: + List of Triple objects + """ + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_PATTERN_DECISION)), + _triple(uri, RDFS_LABEL, _literal(f"Pattern: {pattern}")), + _triple(uri, TG_PATTERN, _literal(pattern)), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)), + ] + + if task_type: + triples.append(_triple(uri, TG_TASK_TYPE, _literal(task_type))) + + return triples + + def agent_iteration_triples( iteration_uri: str, question_uri: Optional[str] = None, @@ -98,6 +151,12 @@ def agent_iteration_triples( arguments: Dict[str, Any] = None, thought_uri: Optional[str] = None, thought_document_id: Optional[str] = None, + tool_candidates: Optional[List[str]] = None, + step_number: Optional[int] = None, + llm_duration_ms: Optional[int] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for one agent iteration (Analysis+ToolUse). 
@@ -106,6 +165,7 @@ def agent_iteration_triples( - Entity declaration with tg:Analysis and tg:ToolUse types - wasDerivedFrom link to question (if first iteration) or previous - Action and arguments metadata + - Tool candidates (names of tools visible to the LLM) - Thought sub-entity (tg:Reflection, tg:Thought) with librarian document Args: @@ -116,6 +176,7 @@ def agent_iteration_triples( arguments: Arguments passed to the tool (will be JSON-encoded) thought_uri: URI for the thought sub-entity thought_document_id: Document URI for thought in librarian + tool_candidates: List of tool names available to the LLM Returns: List of Triple objects @@ -132,6 +193,23 @@ def agent_iteration_triples( _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), ] + if tool_candidates: + for name in tool_candidates: + triples.append( + _triple(iteration_uri, TG_TOOL_CANDIDATE, _literal(name)) + ) + + if step_number is not None: + triples.append( + _triple(iteration_uri, TG_STEP_NUMBER, _literal(str(step_number))) + ) + + if llm_duration_ms is not None: + triples.append( + _triple(iteration_uri, TG_LLM_DURATION_MS, + _literal(str(llm_duration_ms))) + ) + if question_uri: triples.append( _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri)) @@ -155,6 +233,8 @@ def agent_iteration_triples( _triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id)) ) + _append_token_triples(triples, iteration_uri, in_token, out_token, model) + return triples @@ -162,6 +242,8 @@ def agent_observation_triples( observation_uri: str, iteration_uri: str, document_id: Optional[str] = None, + tool_duration_ms: Optional[int] = None, + tool_error: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent observation (standalone entity). 
@@ -170,11 +252,15 @@ def agent_observation_triples( - Entity declaration with prov:Entity and tg:Observation types - wasDerivedFrom link to the iteration (Analysis+ToolUse) - Document reference to librarian (if provided) + - Tool execution duration (if provided) + - Tool error message (if the tool failed) Args: observation_uri: URI of the observation entity iteration_uri: URI of the iteration this observation derives from document_id: Librarian document ID for the observation content + tool_duration_ms: Tool execution time in milliseconds + tool_error: Error message if the tool failed Returns: List of Triple objects @@ -191,6 +277,20 @@ def agent_observation_triples( _triple(observation_uri, TG_DOCUMENT, _iri(document_id)) ) + if tool_duration_ms is not None: + triples.append( + _triple(observation_uri, TG_TOOL_DURATION_MS, + _literal(str(tool_duration_ms))) + ) + + if tool_error: + triples.append( + _triple(observation_uri, TG_TOOL_ERROR, _literal(tool_error)) + ) + triples.append( + _triple(observation_uri, RDF_TYPE, _iri(TG_ERROR_TYPE)) + ) + return triples @@ -199,6 +299,10 @@ def agent_final_triples( question_uri: Optional[str] = None, previous_uri: Optional[str] = None, document_id: Optional[str] = None, + termination_reason: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent final answer (Conclusion). @@ -208,12 +312,15 @@ def agent_final_triples( - wasGeneratedBy link to question (if no iterations) - wasDerivedFrom link to last iteration (if iterations exist) - Document reference to librarian + - Termination reason (why the agent loop stopped) Args: final_uri: URI of the final answer (from agent_final_uri) question_uri: URI of the question activity (if no iterations) previous_uri: URI of the last iteration (if iterations exist) document_id: Librarian document ID for the answer content + termination_reason: Why the loop stopped, e.g. 
"final-answer", + "max-iterations", "error" Returns: List of Triple objects @@ -237,6 +344,14 @@ def agent_final_triples( if document_id: triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) + if termination_reason: + triples.append( + _triple(final_uri, TG_TERMINATION_REASON, + _literal(termination_reason)) + ) + + _append_token_triples(triples, final_uri, in_token, out_token, model) + return triples @@ -244,6 +359,9 @@ def agent_decomposition_triples( uri: str, session_uri: str, goals: List[str], + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a supervisor decomposition step.""" triples = [ @@ -255,6 +373,7 @@ def agent_decomposition_triples( ] for goal in goals: triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal))) + _append_token_triples(triples, uri, in_token, out_token, model) return triples @@ -282,6 +401,9 @@ def agent_plan_triples( uri: str, session_uri: str, steps: List[str], + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a plan-then-execute plan.""" triples = [ @@ -293,6 +415,7 @@ def agent_plan_triples( ] for step in steps: triples.append(_triple(uri, TG_PLAN_STEP, _literal(step))) + _append_token_triples(triples, uri, in_token, out_token, model) return triples @@ -301,6 +424,9 @@ def agent_step_result_triples( plan_uri: str, goal: str, document_id: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a plan step result.""" triples = [ @@ -313,6 +439,7 @@ def agent_step_result_triples( ] if document_id: triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + _append_token_triples(triples, uri, in_token, out_token, model) return triples @@ -320,6 +447,10 @@ def agent_synthesis_triples( uri: str, previous_uris, document_id: Optional[str] 
= None, + termination_reason: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a synthesis answer. @@ -327,6 +458,8 @@ def agent_synthesis_triples( uri: URI of the synthesis entity previous_uris: Single URI string or list of URIs to derive from document_id: Librarian document ID for the answer content + termination_reason: Why the agent loop stopped + in_token/out_token/model: Token usage for the synthesis LLM call """ triples = [ _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), @@ -342,4 +475,12 @@ def agent_synthesis_triples( if document_id: triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + + if termination_reason: + triples.append( + _triple(uri, TG_TERMINATION_REASON, _literal(termination_reason)) + ) + + _append_token_triples(triples, uri, in_token, out_token, model) + return triples diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 9e7fbb2d..0b14f1b9 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -119,6 +119,18 @@ TG_ARGUMENTS = TG + "arguments" TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity TG_SUBAGENT_GOAL = TG + "subagentGoal" # Goal string on Decomposition/Finding TG_PLAN_STEP = TG + "planStep" # Step goal string on Plan/StepResult +TG_TOOL_CANDIDATE = TG + "toolCandidate" # Tool name on Analysis events +TG_TERMINATION_REASON = TG + "terminationReason" # Why the agent loop stopped +TG_STEP_NUMBER = TG + "stepNumber" # Explicit step counter on iteration events +TG_PATTERN_DECISION = TG + "PatternDecision" # Meta-router routing decision entity type +TG_PATTERN = TG + "pattern" # Selected execution pattern +TG_TASK_TYPE = TG + "taskType" # Identified task type +TG_LLM_DURATION_MS = TG + "llmDurationMs" # LLM call duration in milliseconds 
+TG_TOOL_DURATION_MS = TG + "toolDurationMs" # Tool execution duration in milliseconds +TG_TOOL_ERROR = TG + "toolError" # Error message from a failed tool execution +TG_ERROR_TYPE = TG + "Error" # Mixin type for failure events +TG_IN_TOKEN = TG + "inToken" # Input token count for an LLM call +TG_OUT_TOKEN = TG + "outToken" # Output token count for an LLM call # Named graph URIs for RDF datasets # These separate different types of data while keeping them in the same collection diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 920a3482..8bdfc2cb 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -34,6 +34,8 @@ from . namespaces import ( TG_ANSWER_TYPE, # Question subtypes TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, + # Token usage + TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL, ) from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri @@ -74,6 +76,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple: return Triple(s=_iri(s), p=_iri(p), o=o_term) +def _append_token_triples(triples, uri, in_token=None, out_token=None, + model=None): + """Append in_token/out_token/model triples when values are present.""" + if in_token is not None: + triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token)))) + if out_token is not None: + triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token)))) + if model is not None: + triples.append(_triple(uri, TG_LLM_MODEL, _literal(model))) + + def document_triples( doc_uri: str, title: Optional[str] = None, @@ -396,6 +409,9 @@ def grounding_triples( grounding_uri: str, question_uri: str, concepts: List[str], + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a grounding entity (concept decomposition of query). 
@@ -423,6 +439,8 @@ def grounding_triples( for concept in concepts: triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept))) + _append_token_triples(triples, grounding_uri, in_token, out_token, model) + return triples @@ -485,6 +503,9 @@ def focus_triples( exploration_uri: str, selected_edges_with_reasoning: List[dict], session_id: str = "", + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a focus entity (selected edges with reasoning). @@ -543,6 +564,8 @@ def focus_triples( _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) ) + _append_token_triples(triples, focus_uri, in_token, out_token, model) + return triples @@ -550,6 +573,9 @@ def synthesis_triples( synthesis_uri: str, focus_uri: str, document_id: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a synthesis entity (final answer). @@ -578,6 +604,8 @@ def synthesis_triples( if document_id: triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) + _append_token_triples(triples, synthesis_uri, in_token, out_token, model) + return triples @@ -674,6 +702,9 @@ def docrag_synthesis_triples( synthesis_uri: str, exploration_uri: str, document_id: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a document RAG synthesis entity (final answer). 
@@ -702,4 +733,6 @@ def docrag_synthesis_triples( if document_id: triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) + _append_token_triples(triples, synthesis_uri, in_token, out_token, model) + return triples diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index a3aadef6..a26ac867 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -259,6 +259,11 @@ def agent_synthesis_uri(session_id: str) -> str: return f"urn:trustgraph:agent:{session_id}/synthesis" +def agent_pattern_decision_uri(session_id: str) -> str: + """Generate URI for a meta-router pattern decision.""" + return f"urn:trustgraph:agent:{session_id}/pattern-decision" + + # Document RAG provenance URIs # These URIs use the urn:trustgraph:docrag: namespace to distinguish # document RAG provenance from graph RAG provenance diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 3b3a6d01..cd4a2b45 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -51,8 +51,8 @@ class AgentRequest: @dataclass class AgentResponse: # Streaming-first design - chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" - content: str = "" # The actual content (interpretation depends on chunk_type) + message_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" + content: str = "" # The actual content (interpretation depends on message_type) end_of_message: bool = False # Current chunk type (thought/action/etc.) 
is complete end_of_dialog: bool = False # Entire agent dialog is complete diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index ddaef4ca..b379c2df 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -126,7 +126,7 @@ def question_explainable( try: # Track last chunk type for formatting - last_chunk_type = None + last_message_type = None current_outputter = None # Stream agent with explainability - process events as they arrive @@ -138,7 +138,7 @@ def question_explainable( group=group, ): if isinstance(item, AgentThought): - if last_chunk_type != "thought": + if last_message_type != "thought": if current_outputter: current_outputter.__exit__(None, None, None) current_outputter = None @@ -146,7 +146,7 @@ def question_explainable( if verbose: current_outputter = Outputter(width=78, prefix="\U0001f914 ") current_outputter.__enter__() - last_chunk_type = "thought" + last_message_type = "thought" if current_outputter: current_outputter.output(item.content) if current_outputter.word_buffer: @@ -155,7 +155,7 @@ def question_explainable( current_outputter.word_buffer = "" elif isinstance(item, AgentObservation): - if last_chunk_type != "observation": + if last_message_type != "observation": if current_outputter: current_outputter.__exit__(None, None, None) current_outputter = None @@ -163,7 +163,7 @@ def question_explainable( if verbose: current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") current_outputter.__enter__() - last_chunk_type = "observation" + last_message_type = "observation" if current_outputter: current_outputter.output(item.content) if current_outputter.word_buffer: @@ -172,12 +172,12 @@ def question_explainable( current_outputter.word_buffer = "" elif isinstance(item, AgentAnswer): - if last_chunk_type != "answer": + if last_message_type != "answer": if current_outputter: current_outputter.__exit__(None, None, None) 
current_outputter = None print() - last_chunk_type = "answer" + last_message_type = "answer" # Print answer content directly print(item.content, end="", flush=True) @@ -261,7 +261,7 @@ def question_explainable( current_outputter = None # Final newline if we ended with answer - if last_chunk_type == "answer": + if last_message_type == "answer": print() finally: @@ -322,16 +322,16 @@ def question( # Handle streaming response if streaming: # Track last chunk type and current outputter for streaming - last_chunk_type = None + last_message_type = None current_outputter = None last_answer_chunk = None for chunk in response: - chunk_type = chunk.chunk_type + message_type = chunk.message_type content = chunk.content # Check if we're switching to a new message type - if last_chunk_type != chunk_type: + if last_message_type != message_type: # Close previous outputter if exists if current_outputter: current_outputter.__exit__(None, None, None) @@ -339,15 +339,15 @@ def question( print() # Blank line between message types # Create new outputter for new message type - if chunk_type == "thought" and verbose: + if message_type == "thought" and verbose: current_outputter = Outputter(width=78, prefix="\U0001f914 ") current_outputter.__enter__() - elif chunk_type == "observation" and verbose: + elif message_type == "observation" and verbose: current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") current_outputter.__enter__() # For answer, don't use Outputter - just print as-is - last_chunk_type = chunk_type + last_message_type = message_type # Output the chunk if current_outputter: @@ -357,7 +357,7 @@ def question( print(current_outputter.word_buffer, end="", flush=True) current_outputter.column += len(current_outputter.word_buffer) current_outputter.word_buffer = "" - elif chunk_type == "final-answer": + elif message_type == "final-answer": print(content, end="", flush=True) last_answer_chunk = chunk @@ -366,7 +366,7 @@ def question( current_outputter.__exit__(None, None, 
None) current_outputter = None # Add final newline if we were outputting answer - elif last_chunk_type == "final-answer": + elif last_message_type == "final-answer": print() if show_usage and last_answer_chunk: @@ -382,17 +382,17 @@ def question( # so we iterate through the chunks (which are complete messages, not text chunks) for chunk in response: # Display thoughts if verbose - if chunk.chunk_type == "thought" and verbose: + if chunk.message_type == "thought" and verbose: output(wrap(chunk.content), "\U0001f914 ") print() # Display observations if verbose - elif chunk.chunk_type == "observation" and verbose: + elif chunk.message_type == "observation" and verbose: output(wrap(chunk.content), "\U0001f4a1 ") print() # Display answer - elif chunk.chunk_type == "final-answer" or chunk.chunk_type == "answer": + elif chunk.message_type == "final-answer" or chunk.message_type == "answer": print(chunk.content) finally: diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index 689d57e6..88d4ee72 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -25,6 +25,7 @@ from trustgraph.provenance import ( agent_plan_uri, agent_step_result_uri, agent_synthesis_uri, + agent_pattern_decision_uri, agent_session_triples, agent_iteration_triples, agent_observation_triples, @@ -34,6 +35,7 @@ from trustgraph.provenance import ( agent_plan_triples, agent_step_result_triples, agent_synthesis_triples, + agent_pattern_decision_triples, set_graph, GRAPH_RETRIEVAL, ) @@ -182,7 +184,7 @@ class PatternBase: logger.debug(f"Think: {x} (is_final={is_final})") if streaming: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=is_final, end_of_dialog=False, @@ -190,7 +192,7 @@ class PatternBase: ) else: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, 
end_of_message=True, end_of_dialog=False, @@ -205,7 +207,7 @@ class PatternBase: logger.debug(f"Observe: {x} (is_final={is_final})") if streaming: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=is_final, end_of_dialog=False, @@ -213,7 +215,7 @@ class PatternBase: ) else: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=True, end_of_dialog=False, @@ -228,7 +230,7 @@ class PatternBase: logger.debug(f"Answer: {x}") if streaming: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=False, end_of_dialog=False, @@ -236,7 +238,7 @@ class PatternBase: ) else: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=True, end_of_dialog=False, @@ -270,16 +272,43 @@ class PatternBase: logger.debug(f"Emitted session triples for {session_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=session_uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) + async def emit_pattern_decision_triples( + self, flow, session_id, session_uri, pattern, task_type, + user, collection, respond, + ): + """Emit provenance triples for a meta-router pattern decision.""" + uri = agent_pattern_decision_uri(session_id) + triples = set_graph( + agent_pattern_decision_triples( + uri, session_uri, pattern, task_type, + ), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + await respond(AgentResponse( + message_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, + )) + return uri + async def emit_iteration_triples(self, flow, session_id, iteration_num, session_uri, act, request, respond, - streaming): + streaming, tool_candidates=None, + step_number=None, + llm_duration_ms=None, + in_token=None, 
out_token=None, + model=None): """Emit provenance triples for an iteration (Analysis+ToolUse).""" iteration_uri = agent_iteration_uri(session_id, iteration_num) @@ -319,6 +348,12 @@ class PatternBase: arguments=act.arguments, thought_uri=thought_entity_uri if thought_doc_id else None, thought_document_id=thought_doc_id, + tool_candidates=tool_candidates, + step_number=step_number, + llm_duration_ms=llm_duration_ms, + in_token=in_token, + out_token=out_token, + model=model, ), GRAPH_RETRIEVAL, ) @@ -333,7 +368,7 @@ class PatternBase: logger.debug(f"Emitted iteration triples for {iteration_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=iteration_uri, explain_graph=GRAPH_RETRIEVAL, @@ -342,7 +377,9 @@ class PatternBase: async def emit_observation_triples(self, flow, session_id, iteration_num, observation_text, request, respond, - context=None): + context=None, + tool_duration_ms=None, + tool_error=None): """Emit provenance triples for a standalone Observation entity.""" iteration_uri = agent_iteration_uri(session_id, iteration_num) observation_entity_uri = agent_observation_uri(session_id, iteration_num) @@ -375,6 +412,8 @@ class PatternBase: observation_entity_uri, parent_uri, document_id=observation_doc_id, + tool_duration_ms=tool_duration_ms, + tool_error=tool_error, ), GRAPH_RETRIEVAL, ) @@ -389,7 +428,7 @@ class PatternBase: logger.debug(f"Emitted observation triples for {observation_entity_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=observation_entity_uri, explain_graph=GRAPH_RETRIEVAL, @@ -398,7 +437,7 @@ class PatternBase: async def emit_final_triples(self, flow, session_id, iteration_num, session_uri, answer_text, request, respond, - streaming): + streaming, termination_reason=None): """Emit provenance triples for the final answer and save to librarian.""" final_uri = agent_final_uri(session_id) @@ -432,6 +471,7 @@ class PatternBase: 
question_uri=final_question_uri, previous_uri=final_previous_uri, document_id=answer_doc_id, + termination_reason=termination_reason, ), GRAPH_RETRIEVAL, ) @@ -446,7 +486,7 @@ class PatternBase: logger.debug(f"Emitted final triples for {final_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=final_uri, explain_graph=GRAPH_RETRIEVAL, @@ -470,7 +510,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -509,7 +549,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -529,7 +569,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -562,14 +602,14 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) async def emit_synthesis_triples( self, flow, session_id, previous_uris, answer_text, user, collection, - respond, streaming, + respond, streaming, termination_reason=None, ): """Emit provenance for a synthesis answer.""" uri = agent_synthesis_uri(session_id) @@ -586,7 +626,10 @@ class PatternBase: doc_id = None triples = set_graph( - agent_synthesis_triples(uri, previous_uris, doc_id), + agent_synthesis_triples( + uri, previous_uris, doc_id, + termination_reason=termination_reason, + ), GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( @@ -594,7 +637,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - 
chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -616,7 +659,7 @@ class PatternBase: if text: accumulated.append(text) await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content=text, end_of_message=False, end_of_dialog=False, @@ -666,7 +709,7 @@ class PatternBase: # Answer wasn't streamed yet — send it as a chunk first if answer_text: await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content=answer_text, end_of_message=False, end_of_dialog=False, @@ -675,7 +718,7 @@ class PatternBase: if streaming: # End-of-dialog marker with usage await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content="", end_of_message=True, end_of_dialog=True, @@ -684,7 +727,7 @@ class PatternBase: )) else: await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content=answer_text, end_of_message=True, end_of_dialog=True, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 8f5cdcdf..1de31a92 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -35,7 +35,8 @@ class PlanThenExecutePattern(PatternBase): Subsequent calls execute the next pending plan step via ReACT. 
""" - async def iterate(self, request, respond, next, flow, usage=None): + async def iterate(self, request, respond, next, flow, usage=None, + pattern_decision_uri=None): if usage is None: usage = UsageTracker() @@ -66,16 +67,18 @@ class PlanThenExecutePattern(PatternBase): # Determine current phase by checking history for a plan step plan = self._extract_plan(request.history) + derive_from_uri = pattern_decision_uri or session_uri + if plan is None: await self._planning_iteration( request, respond, next, flow, - session_id, collection, streaming, session_uri, + session_id, collection, streaming, derive_from_uri, iteration_num, usage=usage, ) else: await self._execution_iteration( request, respond, next, flow, - session_id, collection, streaming, session_uri, + session_id, collection, streaming, derive_from_uri, iteration_num, plan, usage=usage, ) @@ -385,6 +388,7 @@ class PlanThenExecutePattern(PatternBase): await self.emit_synthesis_triples( flow, session_id, last_step_uri, response_text, request.user, collection, respond, streaming, + termination_reason="plan-complete", ) if self.is_subagent(request): diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 777f99c5..25264c26 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -37,7 +37,8 @@ class ReactPattern(PatternBase): result is appended to history and a next-request is emitted. 
""" - async def iterate(self, request, respond, next, flow, usage=None): + async def iterate(self, request, respond, next, flow, usage=None, + pattern_decision_uri=None): if usage is None: usage = UsageTracker() @@ -108,11 +109,23 @@ class ReactPattern(PatternBase): session_id, iteration_num, ) + # Tool names available to the LLM for this iteration + tool_candidates = [t.name for t in filtered_tools.values()] + + # Use pattern decision as derivation source if available + derive_from_uri = pattern_decision_uri or session_uri + # Callback: emit Analysis+ToolUse triples before tool executes async def on_action(act): await self.emit_iteration_triples( - flow, session_id, iteration_num, session_uri, + flow, session_id, iteration_num, derive_from_uri, act, request, respond, streaming, + tool_candidates=tool_candidates, + step_number=iteration_num, + llm_duration_ms=getattr(act, 'llm_duration_ms', None), + in_token=getattr(act, 'in_token', None), + out_token=getattr(act, 'out_token', None), + model=getattr(act, 'llm_model', None), ) act = await temp_agent.react( @@ -138,8 +151,9 @@ class ReactPattern(PatternBase): # Emit final provenance await self.emit_final_triples( - flow, session_id, iteration_num, session_uri, + flow, session_id, iteration_num, derive_from_uri, f, request, respond, streaming, + termination_reason="final-answer", ) if self.is_subagent(request): @@ -157,6 +171,8 @@ class ReactPattern(PatternBase): flow, session_id, iteration_num, act.observation, request, respond, context=context, + tool_duration_ms=getattr(act, 'tool_duration_ms', None), + tool_error=getattr(act, 'tool_error', None), ) history.append(act) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 9a3584da..3d08154d 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -23,7 +23,7 @@ from ... 
base import Consumer, Producer from ... base import ConsumerMetrics, ProducerMetrics from ... schema import AgentRequest, AgentResponse, AgentStep, Error -from ..orchestrator.pattern_base import UsageTracker +from ..orchestrator.pattern_base import UsageTracker, PatternBase from ... schema import Triples, Metadata from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue @@ -537,19 +537,31 @@ class Processor(AgentService): ) # Dispatch to the selected pattern + selected = self.react_pattern if pattern == "plan-then-execute": - await self.plan_pattern.iterate( - request, respond, next, flow, usage=usage, - ) + selected = self.plan_pattern elif pattern == "supervisor": - await self.supervisor_pattern.iterate( - request, respond, next, flow, usage=usage, - ) - else: - # Default to react - await self.react_pattern.iterate( - request, respond, next, flow, usage=usage, - ) + selected = self.supervisor_pattern + + # Emit pattern decision provenance on first iteration + pattern_decision_uri = None + if not request.history and pattern: + session_id = getattr(request, 'session_id', '') + if session_id: + session_uri = self.provenance_session_uri(session_id) + pattern_decision_uri = \ + await selected.emit_pattern_decision_triples( + flow, session_id, session_uri, + pattern, getattr(request, 'task_type', ''), + request.user, + getattr(request, 'collection', 'default'), + respond, + ) + + await selected.iterate( + request, respond, next, flow, usage=usage, + pattern_decision_uri=pattern_decision_uri, + ) except Exception as e: @@ -565,7 +577,7 @@ class Processor(AgentService): ) r = AgentResponse( - chunk_type="error", + message_type="error", content=str(e), end_of_message=True, end_of_dialog=True, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index 4b62e767..973a9966 100644 --- 
a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -38,7 +38,8 @@ class SupervisorPattern(PatternBase): - "synthesise": triggered by aggregator with results in subagent_results """ - async def iterate(self, request, respond, next, flow, usage=None): + async def iterate(self, request, respond, next, flow, usage=None, + pattern_decision_uri=None): if usage is None: usage = UsageTracker() @@ -70,18 +71,20 @@ class SupervisorPattern(PatternBase): ) ) + derive_from_uri = pattern_decision_uri or session_uri + if has_results: await self._synthesise( request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, + derive_from_uri, iteration_num, usage=usage, ) else: await self._decompose_and_fanout( request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, + derive_from_uri, iteration_num, usage=usage, ) @@ -235,6 +238,7 @@ class SupervisorPattern(PatternBase): await self.emit_synthesis_triples( flow, session_id, finding_uris, response_text, request.user, collection, respond, streaming, + termination_reason="subagents-complete", ) await self.send_final_response( diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index 73686f21..82a8f905 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -3,6 +3,7 @@ import logging import json import re import asyncio +import time from . 
types import Action, Final @@ -260,6 +261,7 @@ class AgentManager: streaming=True, chunk_callback=on_chunk ) + self._last_prompt_result = prompt_result if usage: usage.track(prompt_result) @@ -269,7 +271,13 @@ class AgentManager: # Get result result = parser.get_result() if result is None: - raise RuntimeError("Parser failed to produce a result") + return Action( + thought="", + name="__parse_error__", + arguments={}, + observation="", + tool_error="LLM response could not be parsed (streaming)", + ) return result @@ -281,6 +289,7 @@ class AgentManager: variables=variables, streaming=False ) + self._last_prompt_result = prompt_result if usage: usage.track(prompt_result) response_text = prompt_result.text @@ -294,12 +303,19 @@ class AgentManager: except ValueError as e: logger.error(f"Failed to parse response: {e}") logger.error(f"Response was: {response_text}") - raise RuntimeError(f"Failed to parse agent response: {e}") + return Action( + thought="", + name="__parse_error__", + arguments={}, + observation="", + tool_error=f"LLM parse error: {e}", + ) async def react(self, question, history, think, observe, context, streaming=False, answer=None, on_action=None, usage=None): + t0 = time.monotonic() act = await self.reason( question = question, history = history, @@ -310,6 +326,12 @@ class AgentManager: answer = answer, usage = usage, ) + act.llm_duration_ms = int((time.monotonic() - t0) * 1000) + pr = getattr(self, '_last_prompt_result', None) + if pr: + act.in_token = pr.in_token + act.out_token = pr.out_token + act.llm_model = pr.model if isinstance(act, Final): @@ -328,24 +350,43 @@ class AgentManager: logger.debug(f"ACTION: {act.name}") + # Notify caller before tool execution (for provenance) + if on_action: + await on_action(act) + + # Handle parse errors — skip tool execution + if act.name == "__parse_error__": + resp = f"Error: {act.tool_error}" + act.tool_duration_ms = 0 + await observe(resp, is_final=True) + act.observation = resp + return act + if act.name 
in self.tools: action = self.tools[act.name] else: raise RuntimeError(f"No action for {act.name}!") - # Notify caller before tool execution (for provenance) - if on_action: - await on_action(act) + t0 = time.monotonic() + try: + resp = await action.implementation(context).invoke( + **act.arguments + ) - resp = await action.implementation(context).invoke( - **act.arguments - ) + if isinstance(resp, str): + resp = resp.strip() + else: + resp = str(resp) + resp = resp.strip() - if isinstance(resp, str): - resp = resp.strip() - else: - resp = str(resp) - resp = resp.strip() + act.tool_error = None + + except Exception as e: + logger.error(f"Tool execution error ({act.name}): {e}") + resp = f"Error: {e}" + act.tool_error = str(e) + + act.tool_duration_ms = int((time.monotonic() - t0) * 1000) await observe(resp, is_final=True) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 2c7423d8..00432181 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -469,7 +469,7 @@ class Processor(AgentService): # Send explain event for session await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=session_uri, explain_graph=GRAPH_RETRIEVAL, @@ -492,7 +492,7 @@ class Processor(AgentService): if streaming: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=is_final, end_of_dialog=False, @@ -500,7 +500,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=True, end_of_dialog=False, @@ -515,7 +515,7 @@ class Processor(AgentService): if streaming: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=is_final, end_of_dialog=False, @@ -523,7 +523,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="observation", 
+ message_type="observation", content=x, end_of_message=True, end_of_dialog=False, @@ -540,7 +540,7 @@ class Processor(AgentService): if streaming: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=False, end_of_dialog=False, @@ -548,7 +548,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=True, end_of_dialog=False, @@ -637,7 +637,7 @@ class Processor(AgentService): logger.debug(f"Emitted iteration triples for {iter_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=iter_uri, explain_graph=GRAPH_RETRIEVAL, @@ -715,7 +715,7 @@ class Processor(AgentService): # Send explain event for conclusion await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=final_uri, explain_graph=GRAPH_RETRIEVAL, @@ -725,7 +725,7 @@ class Processor(AgentService): if streaming: # End-of-dialog marker — answer chunks already sent via callback r = AgentResponse( - chunk_type="answer", + message_type="answer", content="", end_of_message=True, end_of_dialog=True, @@ -733,7 +733,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=f, end_of_message=True, end_of_dialog=True, @@ -792,7 +792,7 @@ class Processor(AgentService): # Send explain event for observation await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=observation_entity_uri, explain_graph=GRAPH_RETRIEVAL, @@ -847,7 +847,7 @@ class Processor(AgentService): streaming = getattr(request, 'streaming', False) if 'request' in locals() else False r = AgentResponse( - chunk_type="error", + message_type="error", content=str(e), end_of_message=True, end_of_dialog=True, diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index c474f740..6674c999 
100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -42,7 +42,7 @@ class KnowledgeQueryImpl: async def explain_callback(explain_id, explain_graph, explain_triples=None): self.context.last_sub_explain_uri = explain_id await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=explain_id, explain_graph=explain_graph, diff --git a/trustgraph-flow/trustgraph/agent/react/types.py b/trustgraph-flow/trustgraph/agent/react/types.py index 7180db3e..ee0a677f 100644 --- a/trustgraph-flow/trustgraph/agent/react/types.py +++ b/trustgraph-flow/trustgraph/agent/react/types.py @@ -22,9 +22,19 @@ class Action: name : str arguments : dict observation : str - + llm_duration_ms : int = None + tool_duration_ms : int = None + tool_error : str = None + in_token : int = None + out_token : int = None + llm_model : str = None + @dataclasses.dataclass class Final: thought : str final : str + llm_duration_ms : int = None + in_token : int = None + out_token : int = None + llm_model : str = None diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index a2480862..625b1386 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -56,6 +56,8 @@ class Query: if not concepts: concepts = [query] + self.concepts_usage = result + if self.verbose: logger.debug(f"Extracted concepts: {concepts}") @@ -217,8 +219,14 @@ class DocumentRag: # Emit grounding explainability after concept extraction if explain_callback: + cu = getattr(q, 'concepts_usage', None) gnd_triples = set_graph( - grounding_triples(gnd_uri, q_uri, concepts), + grounding_triples( + gnd_uri, q_uri, concepts, + in_token=cu.in_token if cu else None, + out_token=cu.out_token if cu else None, + model=cu.model if cu else None, + ), GRAPH_RETRIEVAL ) 
await explain_callback(gnd_triples, gnd_uri) @@ -286,6 +294,9 @@ class DocumentRag: docrag_synthesis_triples( syn_uri, exp_uri, document_id=synthesis_doc_id, + in_token=synthesis_result.in_token if synthesis_result else None, + out_token=synthesis_result.out_token if synthesis_result else None, + model=synthesis_result.model if synthesis_result else None, ), GRAPH_RETRIEVAL ) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 07654c64..cf9f5c4e 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -152,6 +152,8 @@ class Query: if self.verbose: logger.debug(f"Extracted concepts: {concepts}") + self.concepts_usage = result + # Fall back to raw query if extraction returns nothing return concepts if concepts else [query] @@ -667,8 +669,14 @@ class GraphRag: # Emit grounding explain after concept extraction if explain_callback: + cu = getattr(q, 'concepts_usage', None) gnd_triples = set_graph( - grounding_triples(gnd_uri, q_uri, concepts), + grounding_triples( + gnd_uri, q_uri, concepts, + in_token=cu.in_token if cu else None, + out_token=cu.out_token if cu else None, + model=cu.model if cu else None, + ), GRAPH_RETRIEVAL ) await explain_callback(gnd_triples, gnd_uri) @@ -883,9 +891,25 @@ class GraphRag: # Emit focus explain after edge selection completes if explain_callback: + # Sum scoring + reasoning token usage for focus event + focus_in = 0 + focus_out = 0 + focus_model = None + for r in [scoring_result, reasoning_result]: + if r is not None: + if r.in_token is not None: + focus_in += r.in_token + if r.out_token is not None: + focus_out += r.out_token + if r.model is not None: + focus_model = r.model + foc_triples = set_graph( focus_triples( - foc_uri, exp_uri, selected_edges_with_reasoning, session_id + foc_uri, exp_uri, selected_edges_with_reasoning, session_id, + in_token=focus_in or 
None, + out_token=focus_out or None, + model=focus_model, ), GRAPH_RETRIEVAL ) @@ -956,6 +980,9 @@ class GraphRag: synthesis_triples( syn_uri, foc_uri, document_id=synthesis_doc_id, + in_token=synthesis_result.in_token if synthesis_result else None, + out_token=synthesis_result.out_token if synthesis_result else None, + model=synthesis_result.model if synthesis_result else None, ), GRAPH_RETRIEVAL )