Expose LLM token usage (in_token, out_token, model) across all service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", distinguishing from a real zero count.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields only when they are not None (checked with
  `is not None` rather than truthiness, so a genuine zero count is still encoded)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
Cyber MacGeddon 2026-04-11 20:22:35 +01:00
parent ffe310af7c
commit 56d700f301
60 changed files with 1252 additions and 577 deletions

View file

@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
from trustgraph.agent.react.types import Action, Final, Tool, Argument
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
# Mock prompt client
prompt_client = AsyncMock()
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for information about machine learning
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""
)
# Mock graph RAG client
graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
# Mock text completion client
text_completion_client = AsyncMock()
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
text_completion_client.question.return_value = PromptResult(
response_type="text",
text="Machine learning involves algorithms that improve through experience."
)
# Mock MCP tool client
mcp_tool_client = AsyncMock()
@ -147,8 +154,11 @@ Args: {
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
"""Test agent manager returning final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have enough information to answer the question
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
)
question = "What is machine learning?"
history = []
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
"""Test ReAct cycle ending with final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide a direct answer
Final Answer: Machine learning is a branch of artificial intelligence."""
)
question = "What is machine learning?"
history = []
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
for tool_name, expected_service in tool_scenarios:
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: I need to use {tool_name}
Action: {tool_name}
Args: {{
"question": "test question"
}}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -284,11 +300,14 @@ Args: {{
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
"""Test agent manager error handling for unknown tool"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to use an unknown tool
Action: unknown_tool
Args: {
"param": "value"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -321,11 +340,14 @@ Args: {
question = "Find information about AI and summarize it"
# Mock multi-step reasoning
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for AI information first
Action: knowledge_query
Args: {
"question": "What is artificial intelligence?"
}"""
)
# Act
action = await agent_manager.reason(question, [], mock_flow_context)
@ -372,9 +394,12 @@ Args: {
# Format arguments as JSON
import json
args_json = json.dumps(test_case['arguments'], indent=4)
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: Using {test_case['action']}
Action: {test_case['action']}
Args: {args_json}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -507,7 +532,10 @@ Args: {
]
for test_case in test_cases:
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=test_case["response"]
)
if test_case["error_contains"]:
# Should raise an error
@ -527,13 +555,16 @@ Args: {
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
"""Test edge cases in text parsing"""
# Test response with markdown code blocks
mock_flow_context("prompt-request").agent_react.return_value = """```
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""```
Thought: I need to search for information
Action: knowledge_query
Args: {
"question": "What is AI?"
}
```"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -541,15 +572,18 @@ Args: {
assert action.name == "knowledge_query"
# Test response with extra whitespace
mock_flow_context("prompt-request").agent_react.return_value = """
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""
Thought: I need to think about this
Action: knowledge_query
Thought: I need to think about this
Action: knowledge_query
Args: {
"question": "test"
}
"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -560,7 +594,9 @@ Args: {
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
"""Test handling of multi-line thoughts and final answers"""
# Multi-line thought
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to consider multiple factors:
1. The user's question is complex
2. I should search for comprehensive information
3. This requires using the knowledge query tool
@ -568,6 +604,7 @@ Action: knowledge_query
Args: {
"question": "complex query"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -575,13 +612,16 @@ Args: {
assert "knowledge query tool" in action.thought
# Multi-line final answer
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have gathered enough information
Final Answer: Here is a comprehensive answer:
1. First point about the topic
2. Second point with details
3. Final conclusion
This covers all aspects of the question."""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -593,13 +633,16 @@ This covers all aspects of the question."""
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
"""Test JSON arguments with special characters and edge cases"""
# Test with special characters in JSON (properly escaped)
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Processing special characters
Action: knowledge_query
Args: {
"question": "What about \\"quotes\\" and 'apostrophes'?",
"context": "Line 1\\nLine 2\\tTabbed",
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -608,7 +651,9 @@ Args: {
assert "@#$%^&*" in action.arguments["special"]
# Test with nested JSON
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Complex arguments
Action: web_search
Args: {
"query": "test",
@ -621,6 +666,7 @@ Args: {
}
}
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -632,7 +678,9 @@ Args: {
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
"""Test final answers that contain JSON-like content"""
# Final answer with JSON content
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide the data in JSON format
Final Answer: {
"result": "success",
"data": {
@ -642,6 +690,7 @@ Final Answer: {
},
"confidence": 0.95
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -792,11 +841,14 @@ Final Answer: {
agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search in the research papers
Action: knowledge_query_custom
Args: {
"question": "Latest AI research?"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks,
assert_streaming_chunks_valid,
@ -51,10 +52,10 @@ Args: {
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.agent_react.side_effect = agent_react_streaming
return client
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk + " ", is_final)
return response
return response
return PromptResult(response_type="text", text=response)
return PromptResult(response_type="text", text=response)
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Error
)
from trustgraph.agent.react.service import Processor
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from New York using structured query
Action: structured-query
Args: {
"question": "Find all customers from New York"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -173,11 +177,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query for a table that might not exist
Action: structured-query
Args: {
"question": "Find data from a table that doesn't exist"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -250,11 +257,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from California first
Action: structured-query
Args: {
"question": "Find all customers from California"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -339,11 +349,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query the sales data
Action: structured-query
Args: {
"question": "Query the sales data for recent transactions"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -447,11 +460,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to get customer information
Action: structured-query
Args: {
"question": "Get customer information and format it nicely"
}"""
)
# Set up flow context routing
def flow_context(service_name):

View file

@ -10,6 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
client.document_prompt.return_value = (
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
client.document_prompt.return_value = PromptResult(
response_type="text",
text=(
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
)
)
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
)
# Verify final response
result, usage = result
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
mock_prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="I couldn't find any relevant documents for your query."
)
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
result_text, usage = result
assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
collector.verify_streaming_protocol()
# Verify full response matches concatenated chunks
result_text, usage = result
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
assert result_text == full_from_chunks
# Verify content is reasonable
assert len(result) > 0
assert len(result_text) > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
non_streaming_text, _ = non_streaming_result
streaming_text, _ = streaming_result
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
)
# Assert
result_text, usage = result
assert callback.call_count > 0
assert result is not None
assert result_text is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
result_text, usage = result
assert isinstance(result_text, str)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
)
# Assert - Should still produce streamed response
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
@pytest.mark.asyncio
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
)
# Assert
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
# Verify doc_limit was passed correctly

View file

@ -12,6 +12,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
return PromptResult(
response_type="text",
text=(
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
)
return ""
return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt
return client
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
response, usage = response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
return full_text
return ""
return PromptResult(response_type="text", text=full_text)
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
)
# Assert
response, usage = response
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_response == non_streaming_response
non_streaming_text, _ = non_streaming_response
streaming_text, _ = streaming_response
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_response
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
# Assert - Should complete without error
assert response is not None
assert isinstance(response, str)
response_text, usage = response
assert isinstance(response_text, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,

View file

@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
# Mock prompt client for definitions extraction
prompt_client = AsyncMock()
prompt_client.extract_definitions.return_value = [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
prompt_client.extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
)
# Mock prompt client for relationships extraction
prompt_client.extract_relationships.return_value = [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
prompt_client.extract_relationships.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
)
# Mock producers for output streams
triples_producer = AsyncMock()
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of empty extraction results"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = []
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[]
)
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of invalid extraction response format"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="text",
text="invalid format"
) # Should be jsonl with objects list
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
"""Test entity filtering and validation in extraction"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = [
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
)
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
)
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
if schema_name == "customer_records":
if "john" in text.lower():
return [
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
)
elif "jane" in text.lower():
return [
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
)
else:
return []
return PromptResult(response_type="jsonl", objects=[])
elif schema_name == "product_catalog":
if "laptop" in text.lower():
return [
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
)
elif "book" in text.lower():
return [
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
)
else:
return []
return []
return PromptResult(response_type="jsonl", objects=[])
return PromptResult(response_type="jsonl", objects=[])
prompt_client.extract_objects.side_effect = mock_extract_objects

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from trustgraph.base.text_completion_client import TextCompletionResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -27,34 +28,52 @@ class TestPromptStreaming:
# Mock text completion client with streaming
text_completion_client = AsyncMock()
async def streaming_request(request, recipient=None, timeout=600):
"""Simulate streaming text completion"""
if request.streaming and recipient:
# Simulate streaming chunks
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
# Streaming chunks to send
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
for i, chunk_text in enumerate(chunks):
is_final = (i == len(chunks) - 1)
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=is_final
)
final = await recipient(response)
if final:
break
# Final empty chunk
await recipient(TextCompletionResponse(
response="",
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
"""Simulate streaming text completion via text_completion_stream"""
for i, chunk_text in enumerate(chunks):
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=True
))
end_of_stream=False
)
await handler(response)
text_completion_client.request = streaming_request
# Send final empty chunk with end_of_stream
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
return TextCompletionResult(
text=None,
in_token=10,
out_token=9,
model="test-model",
)
async def non_streaming_text_completion(system, prompt, timeout=600):
"""Simulate non-streaming text completion"""
full_text = "Machine learning is a field of artificial intelligence."
return TextCompletionResult(
text=full_text,
in_token=10,
out_token=9,
model="test-model",
)
text_completion_client.text_completion_stream = AsyncMock(
side_effect=streaming_text_completion_stream
)
text_completion_client.text_completion = AsyncMock(
side_effect=non_streaming_text_completion
)
# Mock response producer
response_producer = AsyncMock()
@ -156,14 +175,6 @@ class TestPromptStreaming:
consumer = MagicMock()
# Mock non-streaming text completion
text_completion_client = mock_flow_context_streaming("text-completion-request")
async def non_streaming_text_completion(system, prompt, streaming=False):
return "AI is the simulation of human intelligence in machines."
text_completion_client.text_completion = non_streaming_text_completion
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
@ -218,17 +229,12 @@ class TestPromptStreaming:
# Mock text completion client that raises an error
text_completion_client = AsyncMock()
async def failing_request(request, recipient=None, timeout=600):
if recipient:
# Send error response with proper Error schema
error_response = TextCompletionResponse(
response="",
error=Error(message="Text completion error", type="processing_error"),
end_of_stream=True
)
await recipient(error_response)
async def failing_stream(system, prompt, handler, timeout=600):
raise RuntimeError("Text completion error")
text_completion_client.request = failing_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=failing_stream
)
# Mock response producer to capture error response
response_producer = AsyncMock()
@ -255,22 +261,15 @@ class TestPromptStreaming:
consumer = MagicMock()
# Act - The service catches errors and sends error responses, doesn't raise
# Act - The service catches errors and sends an error PromptResponse
await prompt_processor_streaming.on_request(message, consumer, context)
# Assert - Verify error response was sent
assert response_producer.send.call_count > 0
# Check that at least one response contains an error
error_sent = False
for call in response_producer.send.call_args_list:
response = call.args[0]
if hasattr(response, 'error') and response.error:
error_sent = True
assert "Text completion error" in response.error.message
break
assert error_sent, "Expected error response to be sent"
# Assert - error response was sent
calls = response_producer.send.call_args_list
assert len(calls) > 0
error_response = calls[-1].args[0]
assert error_response.error is not None
assert "Text completion error" in error_response.error.message
@pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@ -315,21 +314,22 @@ class TestPromptStreaming:
# Mock text completion that sends empty chunks
text_completion_client = AsyncMock()
async def empty_streaming_request(request, recipient=None, timeout=600):
if request.streaming and recipient:
# Send empty chunk followed by final marker
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
async def empty_streaming(system, prompt, handler, timeout=600):
# Send empty chunk followed by final marker
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
text_completion_client.request = empty_streaming_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=empty_streaming
)
response_producer = AsyncMock()
def context_router(service_name):
@ -401,4 +401,4 @@ class TestPromptStreaming:
# Verify chunks concatenate to expected result
full_text = "".join(chunk_texts)
assert full_text == "Machine learning is a field of artificial intelligence"
assert full_text == "Machine learning is a field of artificial intelligence."

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
from trustgraph.base import PromptResult
class TestGraphRagStreamingProtocol:
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected)
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
return PromptResult(response_type="text", text="")
else:
return "The answer is here."
return ""
return PromptResult(response_type="text", text="The answer is here.")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
await chunk_callback("Document", False)
await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk
return ""
return PromptResult(response_type="text", text="")
else:
return "Document summary."
return PromptResult(response_type="text", text="Document summary.")
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
return PromptResult(response_type="text", text="")
else:
return "textmore"
return ""
return PromptResult(response_type="text", text="textmore")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_with_empties

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None):
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
client.prompt = AsyncMock(
return_value=PromptResult(response_type="text", text=prompt_response)
)
def context(service_name):
return client
@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count
call_count += 1
if call_count == 1:
return "research" # task type
return "plan-then-execute" # pattern
return PromptResult(response_type="text", text="research")
return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt
context = lambda name: client

View file

@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
)
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final
if on_action:
await on_action(Action(
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(action)
return action
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(Action(
thought="done", name="final",
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
],
)
def flow_factory(name):
if name == "prompt-request":
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = {
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
}
mock_prompt_client.prompt.return_value = PromptResult(
response_type="json",
object={
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
},
)
def flow_factory(name):
if name == "prompt-request":
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name):
if name == "prompt-request":
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
"What is quantum computing?",
"What are qubits?",
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
"What is quantum computing?",
"What are qubits?",
],
)
def flow_factory(name):
if name == "prompt-request":
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name):
if name == "prompt-request":
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B", "Goal C"],
)
def flow_factory(name):
if name == "prompt-request":

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.base import PromptResult
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
if isinstance(prompt_result, list):
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
else:
wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
return_value=wrapped
)
def flow(name):

View file

@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
# ---------------------------------------------------------------------------
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
return_value=PromptResult(
response_type="jsonl",
objects=prompt_result,
)
)
def flow(name):

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
from trustgraph.base import PromptResult
# Sample chunk content mapping for tests
@ -132,7 +133,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
query = Query(
rag=mock_rag,
@ -157,7 +158,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@ -258,7 +259,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -273,7 +274,7 @@ class TestQuery:
expected_response = "This is the document RAG response"
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -315,7 +316,8 @@ class TestQuery:
assert "Relevant document content" in docs
assert "Another document" in docs
assert result == expected_response
result_text, usage = result
assert result_text == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
@ -325,7 +327,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
@ -333,7 +335,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -352,7 +354,8 @@ class TestQuery:
collection="default" # Default collection
)
assert result == "Default response"
result_text, usage = result
assert result_text == "Default response"
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
@ -401,7 +404,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
@ -409,7 +412,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -428,7 +431,8 @@ class TestQuery:
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response"
result_text, usage = result
assert result_text == "Verbose RAG response"
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
@ -469,11 +473,11 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -490,7 +494,8 @@ class TestQuery:
documents=[]
)
assert result == "No documents found response"
result_text, usage = result
assert result_text == "No documents found response"
@pytest.mark.asyncio
async def test_get_vectors_with_verbose(self):
@ -525,7 +530,7 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
@ -541,7 +546,7 @@ class TestQuery:
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -584,7 +589,8 @@ class TestQuery:
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
assert result == final_response
result_text, usage = result
assert result_text == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -89,8 +90,8 @@ def build_mock_clients():
# 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return "return policy\nrefund"
return ""
return PromptResult(response_type="text", text="return policy\nrefund")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -113,8 +114,9 @@ def build_mock_clients():
fetch_chunk.side_effect = mock_fetch
# 5. Synthesis
prompt_client.document_prompt.return_value = (
"Items can be returned within 30 days for a full refund."
prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="Items can be returned within 30 days for a full refund.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
assert result == "Items can be returned within 30 days for a full refund."
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund."
result_text, usage = await rag.query(query="What is the return policy?")
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -34,7 +34,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "test response"
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection
msg = MagicMock()
@ -97,7 +97,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A document about cats."
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
# Setup message with non-streaming request
msg = MagicMock()
@ -130,4 +130,5 @@ class TestDocumentRagService:
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.end_of_session is True
assert sent_response.error is None

View file

@ -7,6 +7,7 @@ import unittest.mock
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
from trustgraph.base import PromptResult
class TestGraphRag:
@ -172,7 +173,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
query = Query(
rag=mock_rag,
@ -196,7 +197,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@ -220,7 +221,7 @@ class TestQuery:
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -565,14 +566,14 @@ class TestQuery:
# Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
elif prompt_name == "kg-synthesis":
return expected_response
return ""
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt = mock_prompt
@ -607,7 +608,8 @@ class TestQuery:
explain_callback=collect_provenance
)
assert response == expected_response
response_text, usage = response
assert response_text == expected_response
# 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5

View file

@ -13,6 +13,7 @@ from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -136,24 +137,36 @@ def build_mock_clients():
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return prompt_responses["extract-concepts"]
return PromptResult(
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return synthesis_answer
return ""
return PromptResult(
response_type="text",
text=synthesis_answer,
)
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_parent_uri_links_question_to_parent(self):
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
clients = build_mock_clients()
rag = GraphRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -44,7 +44,7 @@ class TestGraphRagService:
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test")
return "A small domesticated mammal."
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -79,8 +79,8 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
assert mock_response_producer.send.call_count == 6
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
assert mock_response_producer.send.call_count == 5
# First 4 messages are explain (emitted in real-time during query)
for i in range(4):
@ -88,17 +88,12 @@ class TestGraphRagService:
assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None
# 5th message is chunk with response
# 5th message is chunk with response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True
# 6th message is empty chunk with end_of_session=True
close_msg = mock_response_producer.send.call_args_list[5][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4
@ -187,7 +182,7 @@ class TestGraphRagService:
async def mock_query(**kwargs):
# Don't call explain_callback
return "Response text"
return "Response text", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -218,17 +213,12 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 2 messages (chunk + empty chunk to close)
assert mock_response_producer.send.call_count == 2
# Verify: 1 combined message (chunk with end_of_session)
assert mock_response_producer.send.call_count == 1
# First is the response chunk
# Single message has response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True
# Second is empty chunk to close session
close_msg = mock_response_producer.send.call_args_list[1][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True

View file

@ -107,6 +107,7 @@ from .types import (
AgentObservation,
AgentAnswer,
RAGChunk,
TextCompletionResult,
ProvenanceEvent,
)
@ -185,6 +186,7 @@ __all__ = [
"AgentObservation",
"AgentAnswer",
"RAGChunk",
"TextCompletionResult",
"ProvenanceEvent",
# Exceptions

View file

@ -14,6 +14,8 @@ import aiohttp
import json
from typing import Optional, Dict, Any, List
from . types import TextCompletionResult
from . exceptions import ProtocolException, ApplicationException
@ -434,12 +436,11 @@ class AsyncFlowInstance:
return await self.request("agent", request_data)
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult:
"""
Generate text completion (non-streaming).
Generates a text response from an LLM given a system prompt and user prompt.
Returns the complete response text.
Note: This method does not support streaming. For streaming text generation,
use AsyncSocketFlowInstance.text_completion() instead.
@ -450,19 +451,19 @@ class AsyncFlowInstance:
**kwargs: Additional service-specific parameters
Returns:
str: Complete generated text response
TextCompletionResult: Result with text, in_token, out_token, model
Example:
```python
async_flow = await api.async_flow()
flow = async_flow.id("default")
# Generate text
response = await flow.text_completion(
result = await flow.text_completion(
system="You are a helpful assistant.",
prompt="Explain quantum computing in simple terms."
)
print(response)
print(result.text)
print(f"Tokens: {result.in_token} in, {result.out_token} out")
```
"""
request_data = {
@ -473,7 +474,12 @@ class AsyncFlowInstance:
request_data.update(kwargs)
result = await self.request("text-completion", request_data)
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,

View file

@ -4,7 +4,7 @@ import asyncio
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult
from . exceptions import ProtocolException, ApplicationException
@ -199,7 +199,10 @@ class AsyncSocketClient:
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False)
end_of_dialog=resp.get("end_of_dialog", False),
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
elif chunk_type == "action":
return AgentThought(
@ -211,7 +214,10 @@ class AsyncSocketClient:
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None
error=None,
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
async def aclose(self):
@ -269,7 +275,11 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("agent", self.flow_id, request)
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
"""Text completion with optional streaming"""
"""Text completion with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"system": system,
"prompt": prompt,
@ -281,13 +291,18 @@ class AsyncSocketFlowInstance:
return self._text_completion_streaming(request)
else:
result = await self.client._send_request("text-completion", self.flow_id, request)
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
async def _text_completion_streaming(self, request):
"""Helper for streaming text completion"""
"""Helper for streaming text completion. Yields RAGChunk objects."""
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
if hasattr(chunk, 'content'):
yield chunk.content
if isinstance(chunk, RAGChunk):
yield chunk
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,

View file

@ -11,7 +11,7 @@ import base64
from .. knowledge import hash, Uri, Literal, QuotedTriple
from .. schema import IRI, LITERAL, TRIPLE
from . types import Triple
from . types import Triple, TextCompletionResult
from . exceptions import ProtocolException
@ -360,16 +360,17 @@ class FlowInstance:
prompt: User prompt/question
Returns:
str: Generated response text
TextCompletionResult: Result with text, in_token, out_token, model
Example:
```python
flow = api.flow().id("default")
response = flow.text_completion(
result = flow.text_completion(
system="You are a helpful assistant",
prompt="What is quantum computing?"
)
print(response)
print(result.text)
print(f"Tokens: {result.in_token} in, {result.out_token} out")
```
"""
@ -379,10 +380,17 @@ class FlowInstance:
"prompt": prompt
}
return self.request(
result = self.request(
"service/text-completion",
input
)["response"]
)
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def agent(self, question, user="trustgraph", state=None, group=None, history=None):
"""
@ -498,10 +506,17 @@ class FlowInstance:
"edge-limit": edge_limit,
}
return self.request(
result = self.request(
"service/graph-rag",
input
)["response"]
)
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def document_rag(
self, query, user="trustgraph", collection="default",
@ -543,10 +558,17 @@ class FlowInstance:
"doc-limit": doc_limit,
}
return self.request(
result = self.request(
"service/document-rag",
input
)["response"]
)
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def embeddings(self, texts):
"""

View file

@ -14,7 +14,7 @@ import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
from . exceptions import ProtocolException, raise_from_error_dict
@ -393,6 +393,9 @@ class SocketClient:
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False),
message_id=resp.get("message_id", ""),
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
elif chunk_type == "action":
return AgentThought(
@ -404,7 +407,10 @@ class SocketClient:
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None
error=None,
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
@ -543,8 +549,12 @@ class SocketFlowInstance:
streaming=True, include_provenance=True
)
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
"""Execute text completion with optional streaming."""
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute text completion with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"system": system,
"prompt": prompt,
@ -557,12 +567,17 @@ class SocketFlowInstance:
if streaming:
return self._text_completion_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result:
if hasattr(chunk, 'content'):
yield chunk.content
if isinstance(chunk, RAGChunk):
yield chunk
def graph_rag(
self,
@ -577,8 +592,12 @@ class SocketFlowInstance:
edge_limit: int = 25,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute graph-based RAG query with optional streaming."""
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute graph-based RAG query with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"query": query,
"user": user,
@ -598,7 +617,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def graph_rag_explain(
self,
@ -642,8 +666,12 @@ class SocketFlowInstance:
doc_limit: int = 10,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute document-based RAG query with optional streaming."""
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute document-based RAG query with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"query": query,
"user": user,
@ -658,7 +686,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def document_rag_explain(
self,
@ -684,10 +717,10 @@ class SocketFlowInstance:
streaming=True, include_provenance=True
)
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result:
if hasattr(chunk, 'content'):
yield chunk.content
if isinstance(chunk, RAGChunk):
yield chunk
def prompt(
self,
@ -695,8 +728,12 @@ class SocketFlowInstance:
variables: Dict[str, str],
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute a prompt template with optional streaming."""
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute a prompt template with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"id": id,
"variables": variables,
@ -709,7 +746,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("text", result.get("response", "")),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def graph_embeddings_query(
self,

View file

@ -189,6 +189,9 @@ class AgentAnswer(StreamingChunk):
chunk_type: str = "final-answer"
end_of_dialog: bool = False
message_id: str = ""
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass
class RAGChunk(StreamingChunk):
@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk):
content: Generated text content
end_of_stream: True if this is the final chunk of the stream
error: Optional error information if an error occurred
in_token: Input token count (populated on the final chunk, None otherwise)
out_token: Output token count (populated on the final chunk, None otherwise)
model: Model identifier (populated on the final chunk, None otherwise)
chunk_type: Always "rag"
"""
chunk_type: str = "rag"
end_of_stream: bool = False
error: Optional[Dict[str, str]] = None
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass
class TextCompletionResult:
    """
    Result from a text completion request.

    Returned by text_completion() in both streaming and non-streaming modes.
    In streaming mode, text is None (chunks are delivered via the iterator).
    In non-streaming mode, text contains the complete response.

    Token fields are Optional: None means "not available", which is
    distinct from a real zero count.

    Attributes:
        text: Complete response text (None in streaming mode)
        in_token: Input token count (None if not available)
        out_token: Output token count (None if not available)
        model: Model identifier (None if not available)
    """
    text: Optional[str]
    in_token: Optional[int] = None
    out_token: Optional[int] = None
    model: Optional[str] = None
@dataclasses.dataclass
class ProvenanceEvent:

View file

@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec
from . prompt_client import PromptClientSpec
from . text_completion_client import (
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
)
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService

View file

@ -1,10 +1,22 @@
import json
import asyncio
from dataclasses import dataclass
from typing import Optional, Any
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
@dataclass
class PromptResult:
    """
    Result of a prompt-template invocation.

    Exactly one of text / object / objects is populated, as indicated by
    response_type.  Token-usage fields are None when usage information is
    not available (None is distinct from a real zero count).
    """
    response_type: str  # "text", "json", or "jsonl"
    text: Optional[str] = None  # populated for "text"
    object: Any = None  # populated for "json"
    objects: Optional[list] = None  # populated for "jsonl"
    in_token: Optional[int] = None  # input token count, None if unknown
    out_token: Optional[int] = None  # output token count, None if unknown
    model: Optional[str] = None  # model identifier, None if unknown
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
if resp.text:
return PromptResult(
response_type="text",
text=resp.text,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return json.loads(resp.object)
parsed = json.loads(resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
else:
last_text = ""
last_object = None
last_resp = None
async def forward_chunks(resp):
nonlocal last_text, last_object
nonlocal last_resp
if resp.error:
raise RuntimeError(resp.error.message)
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
end_stream = getattr(resp, 'end_of_stream', False)
if resp.text is not None:
last_text = resp.text
if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream)
else:
chunk_callback(resp.text, end_stream)
elif resp.object:
last_object = resp.object
last_resp = resp
return end_stream
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
timeout=timeout
)
if last_text:
return last_text
if last_resp is None:
return PromptResult(response_type="text")
return json.loads(last_object) if last_object else None
if last_resp.object:
parsed = json.loads(last_resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="text",
text=last_resp.text,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
response_schema = PromptResponse,
impl = PromptClient,
)

View file

@ -1,47 +1,71 @@
from dataclasses import dataclass
from typing import Optional
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
@dataclass
class TextCompletionResult:
    """
    Result of a text completion call.

    text is None in streaming mode (chunks are delivered via a handler);
    otherwise it holds the complete response.  Token-usage fields are None
    when usage information is not available (distinct from a real zero).
    """
    text: Optional[str]
    in_token: Optional[int] = None  # input token count, None if unknown
    out_token: Optional[int] = None  # output token count, None if unknown
    model: Optional[str] = None  # model identifier, None if unknown
class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, streaming=False, timeout=600):
# If not streaming, use original behavior
if not streaming:
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = False
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
async def text_completion(self, system, prompt, timeout=600):
return resp.response
# For streaming: collect all chunks and return complete response
full_response = ""
async def collect_chunks(resp):
nonlocal full_response
if resp.error:
raise RuntimeError(resp.error.message)
if resp.response:
full_response += resp.response
# Return True when end_of_stream is reached
return getattr(resp, 'end_of_stream', False)
await self.request(
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = True
system = system, prompt = prompt, streaming = False
),
recipient=collect_chunks,
timeout=timeout
)
return full_response
if resp.error:
raise RuntimeError(resp.error.message)
return TextCompletionResult(
text = resp.response,
in_token = resp.in_token,
out_token = resp.out_token,
model = resp.model,
)
async def text_completion_stream(
self, system, prompt, handler, timeout=600,
):
"""
Streaming text completion. `handler` is an async callable invoked
once per chunk with the chunk's TextCompletionResponse. Returns a
TextCompletionResult with text=None and token counts / model taken
from the end_of_stream message.
"""
async def on_chunk(resp):
if resp.error:
raise RuntimeError(resp.error.message)
await handler(resp)
return getattr(resp, "end_of_stream", False)
final = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = True
),
recipient=on_chunk,
timeout=timeout,
)
return TextCompletionResult(
text = None,
in_token = final.in_token,
out_token = final.out_token,
model = final.model,
)
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)

View file

@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator):
# Always include end_of_stream flag for streaming support
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator):
def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]:
result = {"response": obj.response}
if obj.in_token:
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token:
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model:
if obj.model is not None:
result["model"] = obj.model
# Always include end_of_stream flag for streaming support

View file

@ -66,5 +66,10 @@ class AgentResponse:
error: Error | None = None
# Token usage (populated on end_of_dialog message)
in_token: int | None = None
out_token: int | None = None
model: str | None = None
############################################################################

View file

@ -17,9 +17,9 @@ class TextCompletionRequest:
class TextCompletionResponse:
error: Error | None = None
response: str = ""
in_token: int = 0
out_token: int = 0
model: str = ""
in_token: int | None = None
out_token: int | None = None
model: str | None = None
end_of_stream: bool = False # Indicates final message in stream
############################################################################

View file

@ -41,4 +41,9 @@ class PromptResponse:
# Indicates final message in stream
end_of_stream: bool = False
# Token usage from the underlying text completion
in_token: int | None = None
out_token: int | None = None
model: str | None = None
############################################################################

View file

@ -29,6 +29,9 @@ class GraphRagResponse:
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete
in_token: int | None = None
out_token: int | None = None
model: str | None = None
############################################################################
@ -52,3 +55,6 @@ class DocumentRagResponse:
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete
in_token: int | None = None
out_token: int | None = None
model: str | None = None

View file

@ -272,7 +272,8 @@ def question(
url, question, flow_id, user, collection,
plan=None, state=None, group=None, pattern=None,
verbose=False, streaming=True,
token=None, explainable=False, debug=False
token=None, explainable=False, debug=False,
show_usage=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
@ -323,6 +324,7 @@ def question(
# Track last chunk type and current outputter for streaming
last_chunk_type = None
current_outputter = None
last_answer_chunk = None
for chunk in response:
chunk_type = chunk.chunk_type
@ -357,6 +359,7 @@ def question(
current_outputter.word_buffer = ""
elif chunk_type == "final-answer":
print(content, end="", flush=True)
last_answer_chunk = chunk
# Close any remaining outputter
if current_outputter:
@ -366,6 +369,14 @@ def question(
elif last_chunk_type == "final-answer":
print()
if show_usage and last_answer_chunk:
print(
f"Input tokens: {last_answer_chunk.in_token} "
f"Output tokens: {last_answer_chunk.out_token} "
f"Model: {last_answer_chunk.model}",
file=sys.stderr,
)
else:
# Non-streaming response - but agents use multipart messaging
# so we iterate through the chunks (which are complete messages, not text chunks)
@ -477,6 +488,12 @@ def main():
help='Show debug output for troubleshooting'
)
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args()
try:
@ -496,6 +513,7 @@ def main():
token = args.token,
explainable = args.explainable,
debug = args.debug,
show_usage = args.show_usage,
)
except Exception as e:

View file

@ -99,7 +99,8 @@ def question_explainable(
def question(
url, flow_id, question_text, user, collection, doc_limit,
streaming=True, token=None, explainable=False, debug=False
streaming=True, token=None, explainable=False, debug=False,
show_usage=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
@ -133,22 +134,40 @@ def question(
)
# Stream output
last_chunk = None
for chunk in response:
print(chunk, end="", flush=True)
print(chunk.content, end="", flush=True)
last_chunk = chunk
print() # Final newline
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=sys.stderr,
)
finally:
socket.close()
else:
# Use REST API for non-streaming
flow = api.flow().id(flow_id)
resp = flow.document_rag(
result = flow.document_rag(
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
)
print(resp)
print(result.text)
if show_usage:
print(
f"Input tokens: {result.in_token} "
f"Output tokens: {result.out_token} "
f"Model: {result.model}",
file=sys.stderr,
)
def main():
@ -219,6 +238,12 @@ def main():
help='Show debug output for troubleshooting'
)
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args()
try:
@ -234,6 +259,7 @@ def main():
token=args.token,
explainable=args.explainable,
debug=args.debug,
show_usage=args.show_usage,
)
except Exception as e:

View file

@ -753,7 +753,7 @@ def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None,
explainable=False, debug=False
explainable=False, debug=False, show_usage=False
):
# Explainable mode uses the API to capture and process provenance events
@ -798,16 +798,26 @@ def question(
)
# Stream output
last_chunk = None
for chunk in response:
print(chunk, end="", flush=True)
print(chunk.content, end="", flush=True)
last_chunk = chunk
print() # Final newline
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=sys.stderr,
)
finally:
socket.close()
else:
# Use REST API for non-streaming
flow = api.flow().id(flow_id)
resp = flow.graph_rag(
result = flow.graph_rag(
query=question,
user=user,
collection=collection,
@ -818,7 +828,15 @@ def question(
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
)
print(resp)
print(result.text)
if show_usage:
print(
f"Input tokens: {result.in_token} "
f"Output tokens: {result.out_token} "
f"Model: {result.model}",
file=sys.stderr,
)
def main():
@ -923,6 +941,12 @@ def main():
help='Show debug output for troubleshooting'
)
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args()
try:
@ -943,6 +967,7 @@ def main():
token=args.token,
explainable=args.explainable,
debug=args.debug,
show_usage=args.show_usage,
)
except Exception as e:

View file

@ -10,7 +10,8 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def query(url, flow_id, system, prompt, streaming=True, token=None):
def query(url, flow_id, system, prompt, streaming=True, token=None,
show_usage=False):
# Create API client
api = Api(url=url, token=token)
@ -26,14 +27,29 @@ def query(url, flow_id, system, prompt, streaming=True, token=None):
)
if streaming:
# Stream output to stdout without newline
last_chunk = None
for chunk in response:
print(chunk, end="", flush=True)
# Add final newline after streaming
print(chunk.content, end="", flush=True)
last_chunk = chunk
print()
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=__import__('sys').stderr,
)
else:
# Non-streaming: print complete response
print(response)
print(response.text)
if show_usage:
print(
f"Input tokens: {response.in_token} "
f"Output tokens: {response.out_token} "
f"Model: {response.model}",
file=__import__('sys').stderr,
)
finally:
# Clean up socket connection
@ -82,6 +98,12 @@ def main():
help='Disable streaming (default: streaming enabled)'
)
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args()
try:
@ -93,6 +115,7 @@ def main():
prompt=args.prompt[0],
streaming=not args.no_streaming,
token=args.token,
show_usage=args.show_usage,
)
except Exception as e:

View file

@ -15,7 +15,8 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def query(url, flow_id, template_id, variables, streaming=True, token=None):
def query(url, flow_id, template_id, variables, streaming=True, token=None,
show_usage=False):
# Create API client
api = Api(url=url, token=token)
@ -31,16 +32,30 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None):
)
if streaming:
# Stream output (prompt yields strings directly)
last_chunk = None
for chunk in response:
if chunk:
print(chunk, end="", flush=True)
# Add final newline after streaming
if chunk.content:
print(chunk.content, end="", flush=True)
last_chunk = chunk
print()
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=__import__('sys').stderr,
)
else:
# Non-streaming: print complete response
print(response)
print(response.text)
if show_usage:
print(
f"Input tokens: {response.in_token} "
f"Output tokens: {response.out_token} "
f"Model: {response.model}",
file=__import__('sys').stderr,
)
finally:
# Clean up socket connection
@ -92,6 +107,12 @@ specified multiple times''',
help='Disable streaming (default: streaming enabled for text responses)'
)
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args()
variables = {}
@ -113,6 +134,7 @@ specified multiple times''',
variables=variables,
streaming=not args.no_streaming,
token=args.token,
show_usage=args.show_usage,
)
except Exception as e:

View file

@ -53,7 +53,7 @@ class MetaRouter:
"general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
}
async def identify_task_type(self, question, context):
async def identify_task_type(self, question, context, usage=None):
"""
Use the LLM to classify the question into one of the known task types.
@ -71,7 +71,7 @@ class MetaRouter:
try:
client = context("prompt-request")
response = await client.prompt(
result = await client.prompt(
id="task-type-classify",
variables={
"question": question,
@ -81,7 +81,9 @@ class MetaRouter:
],
},
)
selected = response.strip().lower().replace('"', '').replace("'", "")
if usage:
usage.track(result)
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in self.task_types:
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
@ -100,7 +102,7 @@ class MetaRouter:
)
return DEFAULT_TASK_TYPE, framing
async def select_pattern(self, question, task_type, context):
async def select_pattern(self, question, task_type, context, usage=None):
"""
Use the LLM to select the best execution pattern for this task type.
@ -120,7 +122,7 @@ class MetaRouter:
try:
client = context("prompt-request")
response = await client.prompt(
result = await client.prompt(
id="pattern-select",
variables={
"question": question,
@ -133,7 +135,9 @@ class MetaRouter:
],
},
)
selected = response.strip().lower().replace('"', '').replace("'", "")
if usage:
usage.track(result)
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in valid_patterns:
logger.info(f"MetaRouter: selected pattern '{selected}'")
@ -148,19 +152,20 @@ class MetaRouter:
logger.warning(f"MetaRouter: pattern selection failed: {e}")
return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN
async def route(self, question, context):
async def route(self, question, context, usage=None):
"""
Full routing pipeline: identify task type, then select pattern.
Args:
question: The user's query.
context: UserAwareContext (flow wrapper).
usage: Optional UsageTracker for token counting.
Returns:
(pattern, task_type, framing) tuple.
"""
task_type, framing = await self.identify_task_type(question, context)
pattern = await self.select_pattern(question, task_type, context)
task_type, framing = await self.identify_task_type(question, context, usage=usage)
pattern = await self.select_pattern(question, task_type, context, usage=usage)
logger.info(
f"MetaRouter: route result — "
f"pattern={pattern}, task_type={task_type}, framing={framing!r}"

View file

@ -65,6 +65,37 @@ class UserAwareContext:
return client
class UsageTracker:
    """Accumulates LLM token usage across multiple prompt calls.

    Feed each result object (anything exposing ``in_token`` /
    ``out_token`` / ``model`` attributes, e.g. a PromptResult) to
    ``track()``; read the accumulated values back via the ``in_token``,
    ``out_token`` and ``model`` properties.  A property returns ``None``
    only when no tracked result ever reported that field, so callers can
    distinguish "no usage data available" from a genuine zero token
    count.
    """

    def __init__(self):
        # Running totals across all tracked results.
        self.total_in = 0
        self.total_out = 0
        # Model name from the most recent result that reported one.
        self.last_model = None
        # Whether any result actually reported the corresponding field —
        # lets a real zero count surface as 0 instead of collapsing to None.
        self._seen_in = False
        self._seen_out = False

    def track(self, result):
        """Accumulate usage from a single result; ``None`` is ignored."""
        if result is None:
            return
        in_tok = getattr(result, "in_token", None)
        if in_tok is not None:
            self.total_in += in_tok
            self._seen_in = True
        out_tok = getattr(result, "out_token", None)
        if out_tok is not None:
            self.total_out += out_tok
            self._seen_out = True
        model = getattr(result, "model", None)
        if model is not None:
            self.last_model = model

    @property
    def in_token(self):
        # None means "never reported", not "zero tokens".
        return self.total_in if self._seen_in else None

    @property
    def out_token(self):
        # None means "never reported", not "zero tokens".
        return self.total_out if self._seen_out else None

    @property
    def model(self):
        # Model of the last result that reported one, or None.
        return self.last_model
class PatternBase:
"""
Shared infrastructure for all agent patterns.
@ -571,7 +602,8 @@ class PatternBase:
# ---- Response helpers ---------------------------------------------------
async def prompt_as_answer(self, client, prompt_id, variables,
respond, streaming, message_id=""):
respond, streaming, message_id="",
usage=None):
"""Call a prompt template, forwarding chunks as answer
AgentResponse messages when streaming is enabled.
@ -591,22 +623,28 @@ class PatternBase:
message_id=message_id,
))
await client.prompt(
result = await client.prompt(
id=prompt_id,
variables=variables,
streaming=True,
chunk_callback=on_chunk,
)
if usage:
usage.track(result)
return "".join(accumulated)
else:
return await client.prompt(
result = await client.prompt(
id=prompt_id,
variables=variables,
)
if usage:
usage.track(result)
return result.text
async def send_final_response(self, respond, streaming, answer_text,
already_streamed=False, message_id=""):
already_streamed=False, message_id="",
usage=None):
"""Send the answer content and end-of-dialog marker.
Args:
@ -614,7 +652,16 @@ class PatternBase:
via streaming callbacks (e.g. ReactPattern). Only the
end-of-dialog marker is emitted.
message_id: Provenance URI for the answer entity.
usage: UsageTracker with accumulated token counts.
"""
usage_kwargs = {}
if usage:
usage_kwargs = {
"in_token": usage.in_token,
"out_token": usage.out_token,
"model": usage.model,
}
if streaming and not already_streamed:
# Answer wasn't streamed yet — send it as a chunk first
if answer_text:
@ -626,13 +673,14 @@ class PatternBase:
message_id=message_id,
))
if streaming:
# End-of-dialog marker
# End-of-dialog marker with usage
await respond(AgentResponse(
chunk_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,
message_id=message_id,
**usage_kwargs,
))
else:
await respond(AgentResponse(
@ -641,6 +689,7 @@ class PatternBase:
end_of_message=True,
end_of_dialog=True,
message_id=message_id,
**usage_kwargs,
))
def build_next_request(self, request, history, session_id, collection,

View file

@ -18,7 +18,7 @@ from trustgraph.provenance import (
agent_synthesis_uri,
)
from . pattern_base import PatternBase
from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@ -35,7 +35,10 @@ class PlanThenExecutePattern(PatternBase):
Subsequent calls execute the next pending plan step via ReACT.
"""
async def iterate(self, request, respond, next, flow):
async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -67,13 +70,13 @@ class PlanThenExecutePattern(PatternBase):
await self._planning_iteration(
request, respond, next, flow,
session_id, collection, streaming, session_uri,
iteration_num,
iteration_num, usage=usage,
)
else:
await self._execution_iteration(
request, respond, next, flow,
session_id, collection, streaming, session_uri,
iteration_num, plan,
iteration_num, plan, usage=usage,
)
def _extract_plan(self, history):
@ -98,7 +101,7 @@ class PlanThenExecutePattern(PatternBase):
async def _planning_iteration(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num):
session_uri, iteration_num, usage=None):
"""Ask the LLM to produce a structured plan."""
think = self.make_think_callback(respond, streaming)
@ -113,7 +116,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Use the plan-create prompt template
plan_steps = await client.prompt(
result = await client.prompt(
id="plan-create",
variables={
"question": request.question,
@ -124,7 +127,10 @@ class PlanThenExecutePattern(PatternBase):
],
},
)
if usage:
usage.track(result)
plan_steps = result.objects
# Validate we got a list
if not isinstance(plan_steps, list) or not plan_steps:
logger.warning("plan-create returned invalid result, falling back to single step")
@ -187,7 +193,8 @@ class PlanThenExecutePattern(PatternBase):
async def _execution_iteration(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan):
session_uri, iteration_num, plan,
usage=None):
"""Execute the next pending plan step via single-shot tool call."""
pending_idx = self._find_next_pending_step(plan)
@ -198,6 +205,7 @@ class PlanThenExecutePattern(PatternBase):
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan,
usage=usage,
)
return
@ -240,7 +248,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Single-shot: ask LLM which tool + arguments to use for this goal
tool_call = await client.prompt(
result = await client.prompt(
id="plan-step-execute",
variables={
"goal": goal,
@ -258,7 +266,10 @@ class PlanThenExecutePattern(PatternBase):
],
},
)
if usage:
usage.track(result)
tool_call = result.object
tool_name = tool_call.get("tool", "")
tool_arguments = tool_call.get("arguments", {})
@ -330,7 +341,8 @@ class PlanThenExecutePattern(PatternBase):
async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan):
session_uri, iteration_num, plan,
usage=None):
"""Synthesise a final answer from all completed plan step results."""
think = self.make_think_callback(respond, streaming)
@ -365,6 +377,7 @@ class PlanThenExecutePattern(PatternBase):
respond=respond,
streaming=streaming,
message_id=synthesis_msg_id,
usage=usage,
)
# Emit synthesis provenance (links back to last step result)
@ -380,4 +393,5 @@ class PlanThenExecutePattern(PatternBase):
await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id,
usage=usage,
)

View file

@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager
from ..react.types import Action, Final
from ..tool_filter import get_next_state
from . pattern_base import PatternBase
from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@ -37,7 +37,10 @@ class ReactPattern(PatternBase):
result is appended to history and a next-request is emitted.
"""
async def iterate(self, request, respond, next, flow):
async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -121,6 +124,7 @@ class ReactPattern(PatternBase):
context=context,
streaming=streaming,
on_action=on_action,
usage=usage,
)
logger.debug(f"Action: {act}")
@ -144,6 +148,7 @@ class ReactPattern(PatternBase):
await self.send_final_response(
respond, streaming, f, already_streamed=streaming,
message_id=answer_msg_id,
usage=usage,
)
return

View file

@ -23,6 +23,7 @@ from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ..orchestrator.pattern_base import UsageTracker
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
@ -493,6 +494,8 @@ class Processor(AgentService):
async def agent_request(self, request, respond, next, flow):
usage = UsageTracker()
try:
# Intercept subagent completion messages
@ -516,7 +519,7 @@ class Processor(AgentService):
if self.meta_router:
pattern, task_type, framing = await self.meta_router.route(
request.question, context,
request.question, context, usage=usage,
)
else:
pattern = "react"
@ -536,16 +539,16 @@ class Processor(AgentService):
# Dispatch to the selected pattern
if pattern == "plan-then-execute":
await self.plan_pattern.iterate(
request, respond, next, flow,
request, respond, next, flow, usage=usage,
)
elif pattern == "supervisor":
await self.supervisor_pattern.iterate(
request, respond, next, flow,
request, respond, next, flow, usage=usage,
)
else:
# Default to react
await self.react_pattern.iterate(
request, respond, next, flow,
request, respond, next, flow, usage=usage,
)
except Exception as e:

View file

@ -22,7 +22,7 @@ from trustgraph.provenance import (
agent_synthesis_uri,
)
from . pattern_base import PatternBase
from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@ -38,7 +38,10 @@ class SupervisorPattern(PatternBase):
- "synthesise": triggered by aggregator with results in subagent_results
"""
async def iterate(self, request, respond, next, flow):
async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -72,17 +75,19 @@ class SupervisorPattern(PatternBase):
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num,
usage=usage,
)
else:
await self._decompose_and_fanout(
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num,
usage=usage,
)
async def _decompose_and_fanout(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num):
session_uri, iteration_num, usage=None):
"""Decompose the question into sub-goals and fan out subagents."""
decompose_msg_id = agent_decomposition_uri(session_id)
@ -100,7 +105,7 @@ class SupervisorPattern(PatternBase):
client = context("prompt-request")
# Use the supervisor-decompose prompt template
goals = await client.prompt(
result = await client.prompt(
id="supervisor-decompose",
variables={
"question": request.question,
@ -112,7 +117,10 @@ class SupervisorPattern(PatternBase):
],
},
)
if usage:
usage.track(result)
goals = result.objects
# Validate result
if not isinstance(goals, list):
goals = []
@ -175,7 +183,7 @@ class SupervisorPattern(PatternBase):
async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num):
session_uri, iteration_num, usage=None):
"""Synthesise final answer from subagent results."""
synthesis_msg_id = agent_synthesis_uri(session_id)
@ -216,6 +224,7 @@ class SupervisorPattern(PatternBase):
respond=respond,
streaming=streaming,
message_id=synthesis_msg_id,
usage=usage,
)
# Emit synthesis provenance (links back to all findings)
@ -231,4 +240,5 @@ class SupervisorPattern(PatternBase):
await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id,
usage=usage,
)

View file

@ -170,7 +170,7 @@ class AgentManager:
raise ValueError(f"Could not parse response: {text}")
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None):
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None):
logger.debug(f"calling reason: {question}")
@ -255,11 +255,13 @@ class AgentManager:
client = context("prompt-request")
# Get streaming response
response_text = await client.agent_react(
prompt_result = await client.agent_react(
variables=variables,
streaming=True,
chunk_callback=on_chunk
)
if usage:
usage.track(prompt_result)
# Finalize parser
parser.finalize()
@ -275,10 +277,13 @@ class AgentManager:
# Non-streaming path - get complete text and parse
client = context("prompt-request")
response_text = await client.agent_react(
prompt_result = await client.agent_react(
variables=variables,
streaming=False
)
if usage:
usage.track(prompt_result)
response_text = prompt_result.text
logger.debug(f"Response text:\n{response_text}")
@ -292,7 +297,8 @@ class AgentManager:
raise RuntimeError(f"Failed to parse agent response: {e}")
async def react(self, question, history, think, observe, context,
streaming=False, answer=None, on_action=None):
streaming=False, answer=None, on_action=None,
usage=None):
act = await self.reason(
question = question,
@ -302,6 +308,7 @@ class AgentManager:
think = think,
observe = observe,
answer = answer,
usage = usage,
)
if isinstance(act, Final):

View file

@ -78,9 +78,10 @@ class TextCompletionImpl:
async def invoke(self, **arguments):
    """Ask the prompt service the supplied question.

    Args:
        arguments: Tool arguments; only the "question" key is used.

    Returns:
        str: The LLM's answer text.  Token usage carried on the result
        is not surfaced here — callers needing it use the result object.
    """
    client = self.context("prompt-request")
    logger.debug("Prompt question...")
    # client.question() now returns a result object rather than a bare
    # string; the tool contract is still "return the answer text".
    result = await client.question(
        arguments.get("question")
    )
    return result.text
# This tool implementation knows how to do MCP tool invocation. This uses
# the mcp-tool service.
@ -227,10 +228,11 @@ class PromptImpl:
async def invoke(self, **arguments):
    """Invoke the configured prompt template with the given arguments.

    Args:
        arguments: Variables passed through to the prompt template.

    Returns:
        str: The rendered completion text from the prompt service.
    """
    client = self.context("prompt-request")
    logger.debug(f"Prompt template invocation: {self.template_id}...")
    # client.prompt() now returns a result object rather than a bare
    # string; the tool contract is still "return the completion text".
    result = await client.prompt(
        id=self.template_id,
        variables=arguments
    )
    return result.text
# This tool implementation invokes a dynamically configured tool service

View file

@ -117,10 +117,11 @@ class Processor(FlowProcessor):
try:
defs = await flow("prompt-request").extract_definitions(
result = await flow("prompt-request").extract_definitions(
text = chunk
)
defs = result.objects
logger.debug(f"Definitions response: {defs}")
if type(defs) != list:

View file

@ -376,10 +376,11 @@ class Processor(FlowProcessor):
"""
try:
# Call prompt service with simplified format prompt
extraction_response = await flow("prompt-request").prompt(
result = await flow("prompt-request").prompt(
id="extract-with-ontologies",
variables=prompt_variables
)
extraction_response = result.object
logger.debug(f"Simplified extraction response: {extraction_response}")
# Parse response into structured format

View file

@ -100,10 +100,11 @@ class Processor(FlowProcessor):
try:
rels = await flow("prompt-request").extract_relationships(
result = await flow("prompt-request").extract_relationships(
text = chunk
)
rels = result.objects
logger.debug(f"Prompt response: {rels}")
if type(rels) != list:

View file

@ -148,11 +148,12 @@ class Processor(FlowProcessor):
schema_dict = row_schema_translator.encode(schema)
# Use prompt client to extract rows based on schema
objects = await flow("prompt-request").extract_objects(
result = await flow("prompt-request").extract_objects(
schema=schema_dict,
text=text
)
objects = result.objects
if not isinstance(objects, list):
return []

View file

@ -11,7 +11,6 @@ import logging
from ...schema import Definition, Relationship, Triple
from ...schema import Topic
from ...schema import PromptRequest, PromptResponse, Error
from ...schema import TextCompletionRequest, TextCompletionResponse
from ...base import FlowProcessor
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
@ -124,35 +123,26 @@ class Processor(FlowProcessor):
logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}")
# Use the text completion client with recipient handler
client = flow("text-completion-request")
async def forward_chunks(resp):
if resp.error:
raise RuntimeError(resp.error.message)
is_final = getattr(resp, 'end_of_stream', False)
# Always send a message if there's content OR if it's the final message
if resp.response or is_final:
# Forward each chunk immediately
r = PromptResponse(
text=resp.response if resp.response else "",
object=None,
error=None,
end_of_stream=is_final,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
await flow("response").send(r, properties={"id": id})
# Return True when end_of_stream
return is_final
await client.request(
TextCompletionRequest(
system=system, prompt=prompt, streaming=True
),
recipient=forward_chunks,
timeout=600
await flow("text-completion-request").text_completion_stream(
system=system, prompt=prompt,
handler=forward_chunks,
timeout=600,
)
# Return empty string since we already sent all chunks
@ -167,17 +157,21 @@ class Processor(FlowProcessor):
return
# Non-streaming path (original behavior)
usage = {}
async def llm(system, prompt):
logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}")
resp = await flow("text-completion-request").text_completion(
system = system, prompt = prompt, streaming = False,
)
try:
return resp
result = await flow("text-completion-request").text_completion(
system = system, prompt = prompt,
)
usage["in_token"] = result.in_token
usage["out_token"] = result.out_token
usage["model"] = result.model
return result.text
except Exception as e:
logger.error(f"LLM Exception: {e}", exc_info=True)
return None
@ -199,6 +193,9 @@ class Processor(FlowProcessor):
object=None,
error=None,
end_of_stream=True,
in_token=usage.get("in_token", 0),
out_token=usage.get("out_token", 0),
model=usage.get("model", ""),
)
await flow("response").send(r, properties={"id": id})
@ -215,6 +212,9 @@ class Processor(FlowProcessor):
object=json.dumps(resp),
error=None,
end_of_stream=True,
in_token=usage.get("in_token", 0),
out_token=usage.get("out_token", 0),
model=usage.get("model", ""),
)
await flow("response").send(r, properties={"id": id})

View file

@ -27,24 +27,27 @@ class Query:
def __init__(
self, rag, user, collection, verbose,
doc_limit=20
doc_limit=20, track_usage=None,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
if self.track_usage:
self.track_usage(result)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
if result.text:
for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@ -167,8 +170,23 @@ class DocumentRag:
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns:
str: The synthesized answer text
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
total_in = 0
total_out = 0
last_model = None
def track_usage(result):
nonlocal total_in, total_out, last_model
if result is not None:
if result.in_token is not None:
total_in += result.in_token
if result.out_token is not None:
total_out += result.out_token
if result.model is not None:
last_model = result.model
if self.verbose:
logger.debug("Constructing prompt...")
@ -191,7 +209,7 @@ class DocumentRag:
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
doc_limit=doc_limit, track_usage=track_usage,
)
# Extract concepts from query (grounding step)
@ -228,19 +246,22 @@ class DocumentRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
resp = await self.prompt_client.document_prompt(
synthesis_result = await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.document_prompt(
synthesis_result = await self.prompt_client.document_prompt(
query=query,
documents=docs
)
track_usage(synthesis_result)
resp = synthesis_result.text
if self.verbose:
logger.debug("Query processing complete")
@ -273,5 +294,11 @@ class DocumentRag:
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp
usage = {
"in_token": total_in if total_in > 0 else None,
"out_token": total_out if total_out > 0 else None,
"model": last_model,
}
return resp, usage

View file

@ -200,7 +200,7 @@ class Processor(FlowProcessor):
# Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback
await self.rag.query(
response, usage = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
@ -217,12 +217,15 @@ class Processor(FlowProcessor):
response=None,
end_of_session=True,
message_type="end",
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
response = await self.rag.query(
# Non-streaming path - single response with answer and token usage
response, usage = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
@ -233,11 +236,15 @@ class Processor(FlowProcessor):
await flow("response").send(
DocumentRagResponse(
response = response,
end_of_stream = True,
error = None
response=response,
end_of_stream=True,
end_of_session=True,
error=None,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties = {"id": id}
properties={"id": id}
)
logger.info("Request processing complete")

View file

@ -121,7 +121,7 @@ class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2,
max_path_length=2, track_usage=None,
):
self.rag = rag
self.user = user
@ -131,17 +131,20 @@ class Query:
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
self.track_usage = track_usage
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
if self.track_usage:
self.track_usage(result)
concepts = []
if isinstance(response, str):
for line in response.strip().split('\n'):
if result.text:
for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@ -609,8 +612,24 @@ class GraphRag:
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
str: The synthesized answer text
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
# Accumulate token usage across all prompt calls
total_in = 0
total_out = 0
last_model = None
def track_usage(result):
nonlocal total_in, total_out, last_model
if result is not None:
if result.in_token is not None:
total_in += result.in_token
if result.out_token is not None:
total_out += result.out_token
if result.model is not None:
last_model = result.model
if self.verbose:
logger.debug("Constructing prompt...")
@ -641,6 +660,7 @@ class GraphRag:
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
track_usage = track_usage,
)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
@ -751,21 +771,22 @@ class GraphRag:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt(
scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
track_usage(scoring_result)
if self.verbose:
logger.debug(f"Edge scoring response: {scoring_response}")
logger.debug(f"Edge scoring result: {scoring_result}")
# Parse scoring response to get edge IDs with scores
# Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = []
def parse_scored_edge(obj):
for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
@ -773,21 +794,6 @@ class GraphRag:
score = 0
scored_edges.append({"id": obj["id"], "score": score})
if isinstance(scoring_response, list):
for obj in scoring_response:
parse_scored_edge(obj)
elif isinstance(scoring_response, str):
for line in scoring_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_scored_edge(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
@ -821,25 +827,30 @@ class GraphRag:
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
async def _get_reasoning():
result = await self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
track_usage(result)
return result
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
if isinstance(reasoning_response, Exception):
if isinstance(reasoning_result, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_response}"
f"Edge reasoning failed: {reasoning_result}"
)
reasoning_response = ""
reasoning_result = None
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
@ -848,29 +859,15 @@ class GraphRag:
if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}")
logger.debug(f"Edge reasoning result: {reasoning_result}")
# Parse reasoning response and build explainability data
# Parse reasoning response (jsonl) and build explainability data
reasoning_map = {}
def parse_reasoning(obj):
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
if reasoning_result is not None:
for obj in reasoning_result.objects or []:
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
selected_edges_with_reasoning = []
for eid in selected_ids:
@ -919,19 +916,22 @@ class GraphRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
await self.prompt_client.prompt(
synthesis_result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
streaming=True,
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.prompt(
synthesis_result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
)
track_usage(synthesis_result)
resp = synthesis_result.text
if self.verbose:
logger.debug("Query processing complete")
@ -964,5 +964,11 @@ class GraphRag:
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp
usage = {
"in_token": total_in if total_in > 0 else None,
"out_token": total_out if total_out > 0 else None,
"model": last_model,
}
return resp, usage

View file

@ -332,7 +332,7 @@ class Processor(FlowProcessor):
)
# Query with streaming and real-time explain
response = await rag.query(
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
@ -348,7 +348,7 @@ class Processor(FlowProcessor):
else:
# Non-streaming path with real-time explain
response = await rag.query(
response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
@ -360,23 +360,30 @@ class Processor(FlowProcessor):
parent_uri = v.parent_uri,
)
# Send chunk with response
# Send single response with answer and token usage
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response=response,
end_of_stream=True,
error=None,
end_of_session=True,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties={"id": id}
)
return
# Send final message to close session
# Streaming: send final message to close session with token usage
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response="",
end_of_session=True,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
),
properties={"id": id}
)