Expose LLM token usage across all service layers (#782)

Expose LLM token usage (in_token, out_token, model) across all
service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", distinguishing from a real zero count.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields whenever they are not None (using an
  `is not None` check rather than a truthiness check, so a real zero
  count is still encoded)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
cybermaggedon 2026-04-13 14:38:34 +01:00 committed by GitHub
parent 67cfa80836
commit 14e49d83c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 1252 additions and 577 deletions

View file

@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
from trustgraph.agent.react.types import Action, Final, Tool, Argument from trustgraph.agent.react.types import Action, Final, Tool, Argument
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
from trustgraph.base import PromptResult
@pytest.mark.integration @pytest.mark.integration
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
# Mock prompt client # Mock prompt client
prompt_client = AsyncMock() prompt_client = AsyncMock()
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for information about machine learning
Action: knowledge_query Action: knowledge_query
Args: { Args: {
"question": "What is machine learning?" "question": "What is machine learning?"
}""" }"""
)
# Mock graph RAG client # Mock graph RAG client
graph_rag_client = AsyncMock() graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data." graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
# Mock text completion client # Mock text completion client
text_completion_client = AsyncMock() text_completion_client = AsyncMock()
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience." text_completion_client.question.return_value = PromptResult(
response_type="text",
text="Machine learning involves algorithms that improve through experience."
)
# Mock MCP tool client # Mock MCP tool client
mcp_tool_client = AsyncMock() mcp_tool_client = AsyncMock()
@ -147,8 +154,11 @@ Args: {
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context): async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
"""Test agent manager returning final answer""" """Test agent manager returning final answer"""
# Arrange # Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have enough information to answer the question
Final Answer: Machine learning is a field of AI that enables computers to learn from data.""" Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
)
question = "What is machine learning?" question = "What is machine learning?"
history = [] history = []
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context): async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
"""Test ReAct cycle ending with final answer""" """Test ReAct cycle ending with final answer"""
# Arrange # Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide a direct answer
Final Answer: Machine learning is a branch of artificial intelligence.""" Final Answer: Machine learning is a branch of artificial intelligence."""
)
question = "What is machine learning?" question = "What is machine learning?"
history = [] history = []
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
for tool_name, expected_service in tool_scenarios: for tool_name, expected_service in tool_scenarios:
# Arrange # Arrange
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name} mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: I need to use {tool_name}
Action: {tool_name} Action: {tool_name}
Args: {{ Args: {{
"question": "test question" "question": "test question"
}}""" }}"""
)
think_callback = AsyncMock() think_callback = AsyncMock()
observe_callback = AsyncMock() observe_callback = AsyncMock()
@ -284,11 +300,14 @@ Args: {{
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context): async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
"""Test agent manager error handling for unknown tool""" """Test agent manager error handling for unknown tool"""
# Arrange # Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to use an unknown tool
Action: unknown_tool Action: unknown_tool
Args: { Args: {
"param": "value" "param": "value"
}""" }"""
)
think_callback = AsyncMock() think_callback = AsyncMock()
observe_callback = AsyncMock() observe_callback = AsyncMock()
@ -321,11 +340,14 @@ Args: {
question = "Find information about AI and summarize it" question = "Find information about AI and summarize it"
# Mock multi-step reasoning # Mock multi-step reasoning
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for AI information first
Action: knowledge_query Action: knowledge_query
Args: { Args: {
"question": "What is artificial intelligence?" "question": "What is artificial intelligence?"
}""" }"""
)
# Act # Act
action = await agent_manager.reason(question, [], mock_flow_context) action = await agent_manager.reason(question, [], mock_flow_context)
@ -372,9 +394,12 @@ Args: {
# Format arguments as JSON # Format arguments as JSON
import json import json
args_json = json.dumps(test_case['arguments'], indent=4) args_json = json.dumps(test_case['arguments'], indent=4)
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']} mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: Using {test_case['action']}
Action: {test_case['action']} Action: {test_case['action']}
Args: {args_json}""" Args: {args_json}"""
)
think_callback = AsyncMock() think_callback = AsyncMock()
observe_callback = AsyncMock() observe_callback = AsyncMock()
@ -507,7 +532,10 @@ Args: {
] ]
for test_case in test_cases: for test_case in test_cases:
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"] mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=test_case["response"]
)
if test_case["error_contains"]: if test_case["error_contains"]:
# Should raise an error # Should raise an error
@ -527,13 +555,16 @@ Args: {
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context): async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
"""Test edge cases in text parsing""" """Test edge cases in text parsing"""
# Test response with markdown code blocks # Test response with markdown code blocks
mock_flow_context("prompt-request").agent_react.return_value = """``` mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""```
Thought: I need to search for information Thought: I need to search for information
Action: knowledge_query Action: knowledge_query
Args: { Args: {
"question": "What is AI?" "question": "What is AI?"
} }
```""" ```"""
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action) assert isinstance(action, Action)
@ -541,15 +572,18 @@ Args: {
assert action.name == "knowledge_query" assert action.name == "knowledge_query"
# Test response with extra whitespace # Test response with extra whitespace
mock_flow_context("prompt-request").agent_react.return_value = """ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""
Thought: I need to think about this Thought: I need to think about this
Action: knowledge_query Action: knowledge_query
Args: { Args: {
"question": "test" "question": "test"
} }
""" """
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action) assert isinstance(action, Action)
@ -560,7 +594,9 @@ Args: {
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context): async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
"""Test handling of multi-line thoughts and final answers""" """Test handling of multi-line thoughts and final answers"""
# Multi-line thought # Multi-line thought
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors: mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to consider multiple factors:
1. The user's question is complex 1. The user's question is complex
2. I should search for comprehensive information 2. I should search for comprehensive information
3. This requires using the knowledge query tool 3. This requires using the knowledge query tool
@ -568,6 +604,7 @@ Action: knowledge_query
Args: { Args: {
"question": "complex query" "question": "complex query"
}""" }"""
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action) assert isinstance(action, Action)
@ -575,13 +612,16 @@ Args: {
assert "knowledge query tool" in action.thought assert "knowledge query tool" in action.thought
# Multi-line final answer # Multi-line final answer
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have gathered enough information
Final Answer: Here is a comprehensive answer: Final Answer: Here is a comprehensive answer:
1. First point about the topic 1. First point about the topic
2. Second point with details 2. Second point with details
3. Final conclusion 3. Final conclusion
This covers all aspects of the question.""" This covers all aspects of the question."""
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final) assert isinstance(action, Final)
@ -593,13 +633,16 @@ This covers all aspects of the question."""
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context): async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
"""Test JSON arguments with special characters and edge cases""" """Test JSON arguments with special characters and edge cases"""
# Test with special characters in JSON (properly escaped) # Test with special characters in JSON (properly escaped)
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Processing special characters
Action: knowledge_query Action: knowledge_query
Args: { Args: {
"question": "What about \\"quotes\\" and 'apostrophes'?", "question": "What about \\"quotes\\" and 'apostrophes'?",
"context": "Line 1\\nLine 2\\tTabbed", "context": "Line 1\\nLine 2\\tTabbed",
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?" "special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
}""" }"""
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action) assert isinstance(action, Action)
@ -608,7 +651,9 @@ Args: {
assert "@#$%^&*" in action.arguments["special"] assert "@#$%^&*" in action.arguments["special"]
# Test with nested JSON # Test with nested JSON
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Complex arguments
Action: web_search Action: web_search
Args: { Args: {
"query": "test", "query": "test",
@ -621,6 +666,7 @@ Args: {
} }
} }
}""" }"""
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action) assert isinstance(action, Action)
@ -632,7 +678,9 @@ Args: {
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context): async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
"""Test final answers that contain JSON-like content""" """Test final answers that contain JSON-like content"""
# Final answer with JSON content # Final answer with JSON content
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide the data in JSON format
Final Answer: { Final Answer: {
"result": "success", "result": "success",
"data": { "data": {
@ -642,6 +690,7 @@ Final Answer: {
}, },
"confidence": 0.95 "confidence": 0.95
}""" }"""
)
action = await agent_manager.reason("test", [], mock_flow_context) action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final) assert isinstance(action, Final)
@ -792,11 +841,14 @@ Final Answer: {
agent = AgentManager(tools=custom_tools, additional_context="") agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query # Mock response for custom collection query
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search in the research papers
Action: knowledge_query_custom Action: knowledge_query_custom
Args: { Args: {
"question": "Latest AI research?" "question": "Latest AI research?"
}""" }"""
)
think_callback = AsyncMock() think_callback = AsyncMock()
observe_callback = AsyncMock() observe_callback = AsyncMock()

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument from trustgraph.agent.react.types import Tool, Argument
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import ( from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks, assert_agent_streaming_chunks,
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
@ -51,10 +52,10 @@ Args: {
is_final = (i == len(chunks) - 1) is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final) await chunk_callback(chunk, is_final)
return full_text return PromptResult(response_type="text", text=full_text)
else: else:
# Non-streaming response - same text # Non-streaming response - same text
return full_text return PromptResult(response_type="text", text=full_text)
client.agent_react.side_effect = agent_react_streaming client.agent_react.side_effect = agent_react_streaming
return client return client
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1) is_final = (i == len(chunks) - 1)
await chunk_callback(chunk + " ", is_final) await chunk_callback(chunk + " ", is_final)
return response return PromptResult(response_type="text", text=response)
return response return PromptResult(response_type="text", text=response)
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Error Error
) )
from trustgraph.agent.react.service import Processor from trustgraph.agent.react.service import Processor
from trustgraph.base import PromptResult
@pytest.mark.integration @pytest.mark.integration
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
# Mock the prompt client that agent calls for reasoning # Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from New York using structured query
Action: structured-query Action: structured-query
Args: { Args: {
"question": "Find all customers from New York" "question": "Find all customers from New York"
}""" }"""
)
# Set up flow context routing # Set up flow context routing
def flow_context(service_name): def flow_context(service_name):
@ -173,11 +177,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning # Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query for a table that might not exist
Action: structured-query Action: structured-query
Args: { Args: {
"question": "Find data from a table that doesn't exist" "question": "Find data from a table that doesn't exist"
}""" }"""
)
# Set up flow context routing # Set up flow context routing
def flow_context(service_name): def flow_context(service_name):
@ -250,11 +257,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning # Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from California first
Action: structured-query Action: structured-query
Args: { Args: {
"question": "Find all customers from California" "question": "Find all customers from California"
}""" }"""
)
# Set up flow context routing # Set up flow context routing
def flow_context(service_name): def flow_context(service_name):
@ -339,11 +349,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning # Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query the sales data
Action: structured-query Action: structured-query
Args: { Args: {
"question": "Query the sales data for recent transactions" "question": "Query the sales data for recent transactions"
}""" }"""
)
# Set up flow context routing # Set up flow context routing
def flow_context(service_name): def flow_context(service_name):
@ -447,11 +460,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning # Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to get customer information
Action: structured-query Action: structured-query
Args: { Args: {
"question": "Get customer information and format it nicely" "question": "Get customer information and format it nicely"
}""" }"""
)
# Set up flow context routing # Set up flow context routing
def flow_context(service_name): def flow_context(service_name):

View file

@ -10,6 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content # Sample chunk content for testing - maps chunk_id to content
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
def mock_prompt_client(self): def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses""" """Mock prompt client that generates realistic responses"""
client = AsyncMock() client = AsyncMock()
client.document_prompt.return_value = ( client.document_prompt.return_value = PromptResult(
"Machine learning is a field of artificial intelligence that enables computers to learn " response_type="text",
"and improve from experience without being explicitly programmed. It uses algorithms " text=(
"to find patterns in data and make predictions or decisions." "Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
)
) )
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client return client
@pytest.fixture @pytest.fixture
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
) )
# Verify final response # Verify final response
result, usage = result
assert result is not None assert result is not None
assert isinstance(result, str) assert isinstance(result, str)
assert "machine learning" in result.lower() assert "machine learning" in result.lower()
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
"""Test DocumentRAG behavior when no documents are retrieved""" """Test DocumentRAG behavior when no documents are retrieved"""
# Arrange # Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query." mock_prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="I couldn't find any relevant documents for your query."
)
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag( document_rag = DocumentRag(
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
documents=[] documents=[]
) )
assert result == "I couldn't find any relevant documents for your query." result_text, usage = result
assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import ( from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
assert_callback_invoked, assert_callback_invoked,
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
is_final = (i == len(chunks) - 1) is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final) await chunk_callback(chunk, is_final)
return full_text return PromptResult(response_type="text", text=full_text)
else: else:
# Non-streaming response - same text # Non-streaming response - same text
return full_text return PromptResult(response_type="text", text=full_text)
client.document_prompt.side_effect = document_prompt_side_effect client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client return client
@pytest.fixture @pytest.fixture
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
collector.verify_streaming_protocol() collector.verify_streaming_protocol()
# Verify full response matches concatenated chunks # Verify full response matches concatenated chunks
result_text, usage = result
full_from_chunks = collector.get_full_text() full_from_chunks = collector.get_full_text()
assert result == full_from_chunks assert result_text == full_from_chunks
# Verify content is reasonable # Verify content is reasonable
assert len(result) > 0 assert len(result_text) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming): async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
) )
# Assert - Results should be equivalent # Assert - Results should be equivalent
assert streaming_result == non_streaming_result non_streaming_text, _ = non_streaming_result
streaming_text, _ = streaming_result
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0 assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming): async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
) )
# Assert # Assert
result_text, usage = result
assert callback.call_count > 0 assert callback.call_count > 0
assert result is not None assert result_text is not None
# Verify all callback invocations had string arguments # Verify all callback invocations had string arguments
for call in callback.call_args_list: for call in callback.call_args_list:
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
# Assert - Should complete without error # Assert - Should complete without error
assert result is not None assert result is not None
assert isinstance(result, str) result_text, usage = result
assert isinstance(result_text, str)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming, async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
) )
# Assert - Should still produce streamed response # Assert - Should still produce streamed response
assert result is not None result_text, usage = result
assert result_text is not None
assert callback.call_count > 0 assert callback.call_count > 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
) )
# Assert # Assert
assert result is not None result_text, usage = result
assert result_text is not None
assert callback.call_count > 0 assert callback.call_count > 0
# Verify doc_limit was passed correctly # Verify doc_limit was passed correctly

View file

@ -12,6 +12,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
@pytest.mark.integration @pytest.mark.integration
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
# 4. kg-synthesis returns the final answer # 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts": if prompt_name == "extract-concepts":
return "" # Falls back to raw query return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring": elif prompt_name == "kg-edge-scoring":
return "" # No edges scored return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning": elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis": elif prompt_name == "kg-synthesis":
return ( return PromptResult(
"Machine learning is a subset of artificial intelligence that enables computers " response_type="text",
"to learn from data without being explicitly programmed. It uses algorithms " text=(
"and statistical models to find patterns in data." "Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
) )
return "" return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt client.prompt.side_effect = mock_prompt
return client return client
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
assert mock_prompt_client.prompt.call_count == 4 assert mock_prompt_client.prompt.call_count == 4
# Verify final response # Verify final response
response, usage = response
assert response is not None assert response is not None
assert isinstance(response, str) assert isinstance(response, str)
assert "machine learning" in response.lower() assert "machine learning" in response.lower()

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import ( from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
assert_rag_streaming_chunks, assert_rag_streaming_chunks,
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts": if prompt_id == "extract-concepts":
return "" # Falls back to raw query return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring": elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores # Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n' return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning": elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis": elif prompt_id == "kg-synthesis":
if streaming and chunk_callback: if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags # Simulate streaming chunks with end_of_stream flags
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
is_final = (i == len(chunks) - 1) is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final) await chunk_callback(chunk, is_final)
return full_text return PromptResult(response_type="text", text=full_text)
else: else:
return full_text return PromptResult(response_type="text", text=full_text)
return "" return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect client.prompt.side_effect = prompt_side_effect
return client return client
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
) )
# Assert # Assert
response, usage = response
assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
) )
# Assert - Results should be equivalent # Assert - Results should be equivalent
assert streaming_response == non_streaming_response non_streaming_text, _ = non_streaming_response
streaming_text, _ = streaming_response
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0 assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_response assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming): async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
# Assert - Should complete without error # Assert - Should complete without error
assert response is not None assert response is not None
assert isinstance(response, str) response_text, usage = response
assert isinstance(response_text, str)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming, async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,

View file

@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.base import PromptResult
@pytest.mark.integration @pytest.mark.integration
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
# Mock prompt client for definitions extraction # Mock prompt client for definitions extraction
prompt_client = AsyncMock() prompt_client = AsyncMock()
prompt_client.extract_definitions.return_value = [ prompt_client.extract_definitions.return_value = PromptResult(
{ response_type="jsonl",
"entity": "Machine Learning", objects=[
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." {
}, "entity": "Machine Learning",
{ "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
"entity": "Neural Networks", },
"definition": "Computing systems inspired by biological neural networks that process information." {
} "entity": "Neural Networks",
] "definition": "Computing systems inspired by biological neural networks that process information."
}
]
)
# Mock prompt client for relationships extraction # Mock prompt client for relationships extraction
prompt_client.extract_relationships.return_value = [ prompt_client.extract_relationships.return_value = PromptResult(
{ response_type="jsonl",
"subject": "Machine Learning", objects=[
"predicate": "is_subset_of", {
"object": "Artificial Intelligence", "subject": "Machine Learning",
"object-entity": True "predicate": "is_subset_of",
}, "object": "Artificial Intelligence",
{ "object-entity": True
"subject": "Neural Networks", },
"predicate": "is_used_in", {
"object": "Machine Learning", "subject": "Neural Networks",
"object-entity": True "predicate": "is_used_in",
} "object": "Machine Learning",
] "object-entity": True
}
]
)
# Mock producers for output streams # Mock producers for output streams
triples_producer = AsyncMock() triples_producer = AsyncMock()
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk): async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of empty extraction results""" """Test handling of empty extraction results"""
# Arrange # Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = [] mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[]
)
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk mock_msg.value.return_value = sample_chunk
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk): async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of invalid extraction response format""" """Test handling of invalid extraction response format"""
# Arrange # Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="text",
text="invalid format"
) # Should be jsonl with objects list
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk mock_msg.value.return_value = sample_chunk
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context): async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
"""Test entity filtering and validation in extraction""" """Test entity filtering and validation in extraction"""
# Arrange # Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = [ mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
{"entity": "Valid Entity", "definition": "Valid definition"}, response_type="jsonl",
{"entity": "", "definition": "Empty entity"}, # Should be filtered objects=[
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered {"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": None, "definition": "None entity"}, # Should be filtered {"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
] {"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
)
sample_chunk = Chunk( sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"), metadata=Metadata(id="test", user="user", collection="collection"),

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field, Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse PromptRequest, PromptResponse
) )
from trustgraph.base import PromptResult
@pytest.mark.integration @pytest.mark.integration
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
if schema_name == "customer_records": if schema_name == "customer_records":
if "john" in text.lower(): if "john" in text.lower():
return [ return PromptResult(
{ response_type="jsonl",
"customer_id": "CUST001", objects=[
"name": "John Smith", {
"email": "john.smith@email.com", "customer_id": "CUST001",
"phone": "555-0123" "name": "John Smith",
} "email": "john.smith@email.com",
] "phone": "555-0123"
}
]
)
elif "jane" in text.lower(): elif "jane" in text.lower():
return [ return PromptResult(
{ response_type="jsonl",
"customer_id": "CUST002", objects=[
"name": "Jane Doe", {
"email": "jane.doe@email.com", "customer_id": "CUST002",
"phone": "" "name": "Jane Doe",
} "email": "jane.doe@email.com",
] "phone": ""
}
]
)
else: else:
return [] return PromptResult(response_type="jsonl", objects=[])
elif schema_name == "product_catalog": elif schema_name == "product_catalog":
if "laptop" in text.lower(): if "laptop" in text.lower():
return [ return PromptResult(
{ response_type="jsonl",
"product_id": "PROD001", objects=[
"name": "Gaming Laptop", {
"price": "1299.99", "product_id": "PROD001",
"category": "electronics" "name": "Gaming Laptop",
} "price": "1299.99",
] "category": "electronics"
}
]
)
elif "book" in text.lower(): elif "book" in text.lower():
return [ return PromptResult(
{ response_type="jsonl",
"product_id": "PROD002", objects=[
"name": "Python Programming Guide", {
"price": "49.99", "product_id": "PROD002",
"category": "books" "name": "Python Programming Guide",
} "price": "49.99",
] "category": "books"
}
]
)
else: else:
return [] return PromptResult(response_type="jsonl", objects=[])
return [] return PromptResult(response_type="jsonl", objects=[])
prompt_client.extract_objects.side_effect = mock_extract_objects prompt_client.extract_objects.side_effect = mock_extract_objects

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from trustgraph.base.text_completion_client import TextCompletionResult
from tests.utils.streaming_assertions import ( from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
assert_callback_invoked, assert_callback_invoked,
@ -27,34 +28,52 @@ class TestPromptStreaming:
# Mock text completion client with streaming # Mock text completion client with streaming
text_completion_client = AsyncMock() text_completion_client = AsyncMock()
async def streaming_request(request, recipient=None, timeout=600): # Streaming chunks to send
"""Simulate streaming text completion""" chunks = [
if request.streaming and recipient: "Machine", " learning", " is", " a", " field",
# Simulate streaming chunks " of", " artificial", " intelligence", "."
chunks = [ ]
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
for i, chunk_text in enumerate(chunks): async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
is_final = (i == len(chunks) - 1) """Simulate streaming text completion via text_completion_stream"""
response = TextCompletionResponse( for i, chunk_text in enumerate(chunks):
response=chunk_text, response = TextCompletionResponse(
error=None, response=chunk_text,
end_of_stream=is_final
)
final = await recipient(response)
if final:
break
# Final empty chunk
await recipient(TextCompletionResponse(
response="",
error=None, error=None,
end_of_stream=True end_of_stream=False
)) )
await handler(response)
text_completion_client.request = streaming_request # Send final empty chunk with end_of_stream
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
return TextCompletionResult(
text=None,
in_token=10,
out_token=9,
model="test-model",
)
async def non_streaming_text_completion(system, prompt, timeout=600):
"""Simulate non-streaming text completion"""
full_text = "Machine learning is a field of artificial intelligence."
return TextCompletionResult(
text=full_text,
in_token=10,
out_token=9,
model="test-model",
)
text_completion_client.text_completion_stream = AsyncMock(
side_effect=streaming_text_completion_stream
)
text_completion_client.text_completion = AsyncMock(
side_effect=non_streaming_text_completion
)
# Mock response producer # Mock response producer
response_producer = AsyncMock() response_producer = AsyncMock()
@ -156,14 +175,6 @@ class TestPromptStreaming:
consumer = MagicMock() consumer = MagicMock()
# Mock non-streaming text completion
text_completion_client = mock_flow_context_streaming("text-completion-request")
async def non_streaming_text_completion(system, prompt, streaming=False):
return "AI is the simulation of human intelligence in machines."
text_completion_client.text_completion = non_streaming_text_completion
# Act # Act
await prompt_processor_streaming.on_request( await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming message, consumer, mock_flow_context_streaming
@ -218,17 +229,12 @@ class TestPromptStreaming:
# Mock text completion client that raises an error # Mock text completion client that raises an error
text_completion_client = AsyncMock() text_completion_client = AsyncMock()
async def failing_request(request, recipient=None, timeout=600): async def failing_stream(system, prompt, handler, timeout=600):
if recipient: raise RuntimeError("Text completion error")
# Send error response with proper Error schema
error_response = TextCompletionResponse(
response="",
error=Error(message="Text completion error", type="processing_error"),
end_of_stream=True
)
await recipient(error_response)
text_completion_client.request = failing_request text_completion_client.text_completion_stream = AsyncMock(
side_effect=failing_stream
)
# Mock response producer to capture error response # Mock response producer to capture error response
response_producer = AsyncMock() response_producer = AsyncMock()
@ -255,22 +261,15 @@ class TestPromptStreaming:
consumer = MagicMock() consumer = MagicMock()
# Act - The service catches errors and sends error responses, doesn't raise # Act - The service catches errors and sends an error PromptResponse
await prompt_processor_streaming.on_request(message, consumer, context) await prompt_processor_streaming.on_request(message, consumer, context)
# Assert - Verify error response was sent # Assert - error response was sent
assert response_producer.send.call_count > 0 calls = response_producer.send.call_args_list
assert len(calls) > 0
# Check that at least one response contains an error error_response = calls[-1].args[0]
error_sent = False assert error_response.error is not None
for call in response_producer.send.call_args_list: assert "Text completion error" in error_response.error.message
response = call.args[0]
if hasattr(response, 'error') and response.error:
error_sent = True
assert "Text completion error" in response.error.message
break
assert error_sent, "Expected error response to be sent"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming, async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@ -315,21 +314,22 @@ class TestPromptStreaming:
# Mock text completion that sends empty chunks # Mock text completion that sends empty chunks
text_completion_client = AsyncMock() text_completion_client = AsyncMock()
async def empty_streaming_request(request, recipient=None, timeout=600): async def empty_streaming(system, prompt, handler, timeout=600):
if request.streaming and recipient: # Send empty chunk followed by final marker
# Send empty chunk followed by final marker await handler(TextCompletionResponse(
await recipient(TextCompletionResponse( response="",
response="", error=None,
error=None, end_of_stream=False
end_of_stream=False ))
)) await handler(TextCompletionResponse(
await recipient(TextCompletionResponse( response="",
response="", error=None,
error=None, end_of_stream=True
end_of_stream=True ))
))
text_completion_client.request = empty_streaming_request text_completion_client.text_completion_stream = AsyncMock(
side_effect=empty_streaming
)
response_producer = AsyncMock() response_producer = AsyncMock()
def context_router(service_name): def context_router(service_name):
@ -401,4 +401,4 @@ class TestPromptStreaming:
# Verify chunks concatenate to expected result # Verify chunks concatenate to expected result
full_text = "".join(chunk_texts) full_text = "".join(chunk_texts)
assert full_text == "Machine learning is a field of artificial intelligence" assert full_text == "Machine learning is a field of artificial intelligence."

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
from trustgraph.base import PromptResult
class TestGraphRagStreamingProtocol: class TestGraphRagStreamingProtocol:
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection": if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected) return PromptResult(response_type="text", text="")
return ""
elif prompt_name == "kg-synthesis": elif prompt_name == "kg-synthesis":
if streaming and chunk_callback: if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
await chunk_callback(" answer", False) await chunk_callback(" answer", False)
await chunk_callback(" is here.", False) await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything return PromptResult(response_type="text", text="")
else: else:
return "The answer is here." return PromptResult(response_type="text", text="The answer is here.")
return "" return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect client.prompt.side_effect = prompt_side_effect
return client return client
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
await chunk_callback("Document", False) await chunk_callback("Document", False)
await chunk_callback(" summary", False) await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk await chunk_callback(".", True) # Non-empty final chunk
return "" return PromptResult(response_type="text", text="")
else: else:
return "Document summary." return PromptResult(response_type="text", text="Document summary.")
client.document_prompt.side_effect = document_prompt_side_effect client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client return client
@pytest.fixture @pytest.fixture
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection": if prompt_name == "kg-edge-selection":
return "" return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis": elif prompt_name == "kg-synthesis":
if streaming and chunk_callback: if streaming and chunk_callback:
await chunk_callback("text", False) await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False) await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final await chunk_callback("", True) # Empty and final
return "" return PromptResult(response_type="text", text="")
else: else:
return "textmore" return PromptResult(response_type="text", text="textmore")
return "" return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_with_empties client.prompt.side_effect = prompt_with_empties

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import ( from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE, MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
) )
from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None): def _make_config(patterns=None, task_types=None):
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response): def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client.""" """Build a mock context that returns a mock prompt client."""
client = AsyncMock() client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response) client.prompt = AsyncMock(
return_value=PromptResult(response_type="text", text=prompt_response)
)
def context(service_name): def context(service_name):
return client return client
@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
if call_count == 1: if call_count == 1:
return "research" # task type return PromptResult(response_type="text", text="research")
return "plan-then-execute" # pattern return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt client.prompt = mock_prompt
context = lambda name: client context = lambda name: client

View file

@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import ( from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep, AgentRequest, AgentResponse, AgentStep, PlanStep,
) )
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import ( from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
) )
async def mock_react(question, history, think, observe, answer, async def mock_react(question, history, think, observe, answer,
context, streaming, on_action): context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final # Simulate the on_action callback before returning Final
if on_action: if on_action:
await on_action(Action( await on_action(Action(
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer, async def mock_react(question, history, think, observe, answer,
context, streaming, on_action): context, streaming, on_action, **kwargs):
if on_action: if on_action:
await on_action(action) await on_action(action)
return action return action
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer, async def mock_react(question, history, think, observe, answer,
context, streaming, on_action): context, streaming, on_action, **kwargs):
if on_action: if on_action:
await on_action(Action( await on_action(Action(
thought="done", name="final", thought="done", name="final",
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation # Mock prompt client for plan creation
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [ mock_prompt_client.prompt.return_value = PromptResult(
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []}, response_type="jsonl",
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]}, objects=[
] {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
],
)
def flow_factory(name): def flow_factory(name):
if name == "prompt-request": if name == "prompt-request":
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution # Mock prompt for step execution
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = { mock_prompt_client.prompt.return_value = PromptResult(
"tool": "knowledge-query", response_type="json",
"arguments": {"question": "quantum computing"}, object={
} "tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
},
)
def flow_factory(name): def flow_factory(name):
if name == "prompt-request": if name == "prompt-request":
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis # Mock prompt for synthesis
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer." mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name): def flow_factory(name):
if name == "prompt-request": if name == "prompt-request":
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition # Mock prompt for decomposition
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [ mock_prompt_client.prompt.return_value = PromptResult(
"What is quantum computing?", response_type="jsonl",
"What are qubits?", objects=[
] "What is quantum computing?",
"What are qubits?",
],
)
def flow_factory(name): def flow_factory(name):
if name == "prompt-request": if name == "prompt-request":
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis # Mock prompt for synthesis
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer." mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name): def flow_factory(name):
if name == "prompt-request": if name == "prompt-request":
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow() flow = make_mock_flow()
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"] mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B", "Goal C"],
)
def flow_factory(name): def flow_factory(name):
if name == "prompt-request": if name == "prompt-request":

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import ( from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size, Processor, default_triples_batch_size, default_entity_batch_size,
) )
from trustgraph.base import PromptResult
from trustgraph.schema import ( from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL, Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
) )
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock() mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock() mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
if isinstance(prompt_result, list):
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
else:
wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock( mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result return_value=wrapped
) )
def flow(name): def flow(name):

View file

@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import ( from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL, Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
) )
from trustgraph.base import PromptResult
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock() mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock( mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result return_value=PromptResult(
response_type="jsonl",
objects=prompt_result,
)
) )
def flow(name): def flow(name):

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
from trustgraph.base import PromptResult
# Sample chunk content mapping for tests # Sample chunk content mapping for tests
@ -132,7 +133,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client mock_rag.prompt_client = mock_prompt_client
# Mock the prompt response with concept lines # Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
@ -157,7 +158,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client mock_rag.prompt_client = mock_prompt_client
# Mock empty response # Mock empty response
mock_prompt_client.prompt.return_value = "" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
@ -258,7 +259,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction # Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
# Mock embeddings - one vector per concept # Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
@ -273,7 +274,7 @@ class TestQuery:
expected_response = "This is the document RAG response" expected_response = "This is the document RAG response"
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -315,7 +316,8 @@ class TestQuery:
assert "Relevant document content" in docs assert "Relevant document content" in docs
assert "Another document" in docs assert "Another document" in docs
assert result == expected_response result_text, usage = result
assert result_text == expected_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
@ -325,7 +327,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction fallback (empty → raw query) # Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = "" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# Mock responses # Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
@ -333,7 +335,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c5" mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9 mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match] mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response" mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -352,7 +354,8 @@ class TestQuery:
collection="default" # Default collection collection="default" # Default collection
) )
assert result == "Default response" result_text, usage = result
assert result_text == "Default response"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self): async def test_get_docs_with_verbose_output(self):
@ -401,7 +404,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction # Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
# Mock responses # Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
@ -409,7 +412,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c7" mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92 mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match] mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response" mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -428,7 +431,8 @@ class TestQuery:
assert call_args.kwargs["query"] == "verbose query test" assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"] assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response" result_text, usage = result
assert result_text == "Verbose RAG response"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_docs_with_empty_results(self): async def test_get_docs_with_empty_results(self):
@ -469,11 +473,11 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction # Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = [] mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response" mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -490,7 +494,8 @@ class TestQuery:
documents=[] documents=[]
) )
assert result == "No documents found response" result_text, usage = result
assert result_text == "No documents found response"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_vectors_with_verbose(self): async def test_get_vectors_with_verbose(self):
@ -525,7 +530,7 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
# Mock concept extraction # Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
# Mock embeddings - one vector per concept # Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]] query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
@ -541,7 +546,7 @@ class TestQuery:
MagicMock(chunk_id="doc/ml3", score=0.82), MagicMock(chunk_id="doc/ml3", score=0.82),
] ]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2] mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
@ -584,7 +589,8 @@ class TestQuery:
assert "Common ML techniques include supervised and unsupervised learning..." in docs assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated assert len(docs) == 3 # doc/ml2 deduplicated
assert result == final_response result_text, usage = result
assert result_text == final_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self): async def test_get_docs_deduplicates_across_concepts(self):

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
from dataclasses import dataclass from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import ( from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -89,8 +90,8 @@ def build_mock_clients():
# 1. Concept extraction # 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs): async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts": if template_id == "extract-concepts":
return "return policy\nrefund" return PromptResult(response_type="text", text="return policy\nrefund")
return "" return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt prompt_client.prompt.side_effect = mock_prompt
@ -113,8 +114,9 @@ def build_mock_clients():
fetch_chunk.side_effect = mock_fetch fetch_chunk.side_effect = mock_fetch
# 5. Synthesis # 5. Synthesis
prompt_client.document_prompt.return_value = ( prompt_client.document_prompt.return_value = PromptResult(
"Items can be returned within 30 days for a full refund." response_type="text",
text="Items can be returned within 30 days for a full refund.",
) )
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients() clients = build_mock_clients()
rag = DocumentRag(*clients) rag = DocumentRag(*clients)
result = await rag.query( result_text, usage = await rag.query(
query="What is the return policy?", query="What is the return policy?",
explain_callback=AsyncMock(), explain_callback=AsyncMock(),
) )
assert result == "Items can be returned within 30 days for a full refund." assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_explain_callback_still_works(self): async def test_no_explain_callback_still_works(self):
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients() clients = build_mock_clients()
rag = DocumentRag(*clients) rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?") result_text, usage = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund." assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self): async def test_all_triples_in_retrieval_graph(self):

View file

@ -34,7 +34,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance # Setup mock DocumentRag instance
mock_rag_instance = AsyncMock() mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "test response" mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection # Setup message with custom user/collection
msg = MagicMock() msg = MagicMock()
@ -97,7 +97,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance # Setup mock DocumentRag instance
mock_rag_instance = AsyncMock() mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A document about cats." mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
# Setup message with non-streaming request # Setup message with non-streaming request
msg = MagicMock() msg = MagicMock()
@ -130,4 +130,5 @@ class TestDocumentRagService:
assert isinstance(sent_response, DocumentRagResponse) assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats." assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.end_of_session is True
assert sent_response.error is None assert sent_response.error is None

View file

@ -7,6 +7,7 @@ import unittest.mock
from unittest.mock import MagicMock, AsyncMock from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
from trustgraph.base import PromptResult
class TestGraphRag: class TestGraphRag:
@ -172,7 +173,7 @@ class TestQuery:
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
@ -196,7 +197,7 @@ class TestQuery:
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query( query = Query(
rag=mock_rag, rag=mock_rag,
@ -220,7 +221,7 @@ class TestQuery:
mock_rag.graph_embeddings_client = mock_graph_embeddings_client mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# extract_concepts returns empty -> falls back to [query] # extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = "" mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# embed returns one vector set for the single concept # embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
@ -565,14 +566,14 @@ class TestQuery:
# Mock prompt responses for the multi-step process # Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts": if prompt_name == "extract-concepts":
return "" # Falls back to raw query return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring": elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9}) return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning": elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"}) return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
elif prompt_name == "kg-synthesis": elif prompt_name == "kg-synthesis":
return expected_response return PromptResult(response_type="text", text=expected_response)
return "" return PromptResult(response_type="text", text="")
mock_prompt_client.prompt = mock_prompt mock_prompt_client.prompt = mock_prompt
@ -607,7 +608,8 @@ class TestQuery:
explain_callback=collect_provenance explain_callback=collect_provenance
) )
assert response == expected_response response_text, usage = response
assert response_text == expected_response
# 5 events: question, grounding, exploration, focus, synthesis # 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5 assert len(provenance_events) == 5

View file

@ -13,6 +13,7 @@ from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import ( from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -136,24 +137,36 @@ def build_mock_clients():
async def mock_prompt(template_id, variables=None, **kwargs): async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts": if template_id == "extract-concepts":
return prompt_responses["extract-concepts"] return PromptResult(
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring": elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed # Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", []) edges = variables.get("knowledge", [])
return [ return PromptResult(
{"id": e["id"], "score": 10 - i} response_type="jsonl",
for i, e in enumerate(edges) objects=[
] {"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning": elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge # Provide reasoning for each edge
edges = variables.get("knowledge", []) edges = variables.get("knowledge", [])
return [ return PromptResult(
{"id": e["id"], "reasoning": f"Relevant edge {i}"} response_type="jsonl",
for i, e in enumerate(edges) objects=[
] {"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis": elif template_id == "kg-synthesis":
return synthesis_answer return PromptResult(
return "" response_type="text",
text=synthesis_answer,
)
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt prompt_client.prompt.side_effect = mock_prompt
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
async def explain_callback(triples, explain_id): async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id}) events.append({"triples": triples, "explain_id": explain_id})
result = await rag.query( result_text, usage = await rag.query(
query="What is quantum computing?", query="What is quantum computing?",
explain_callback=explain_callback, explain_callback=explain_callback,
edge_score_limit=0, edge_score_limit=0,
) )
assert result == "Quantum computing applies physics principles to computation." assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parent_uri_links_question_to_parent(self): async def test_parent_uri_links_question_to_parent(self):
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
clients = build_mock_clients() clients = build_mock_clients()
rag = GraphRag(*clients) rag = GraphRag(*clients)
result = await rag.query( result_text, usage = await rag.query(
query="What is quantum computing?", query="What is quantum computing?",
edge_score_limit=0, edge_score_limit=0,
) )
assert result == "Quantum computing applies physics principles to computation." assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self): async def test_all_triples_in_retrieval_graph(self):

View file

@ -44,7 +44,7 @@ class TestGraphRagService:
await explain_callback([], "urn:trustgraph:prov:retrieval:test") await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test") await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test") await explain_callback([], "urn:trustgraph:prov:answer:test")
return "A small domesticated mammal." return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query mock_rag_instance.query.side_effect = mock_query
@ -79,8 +79,8 @@ class TestGraphRagService:
# Execute # Execute
await processor.on_request(msg, consumer, flow) await processor.on_request(msg, consumer, flow)
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session) # Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
assert mock_response_producer.send.call_count == 6 assert mock_response_producer.send.call_count == 5
# First 4 messages are explain (emitted in real-time during query) # First 4 messages are explain (emitted in real-time during query)
for i in range(4): for i in range(4):
@ -88,17 +88,12 @@ class TestGraphRagService:
assert prov_msg.message_type == "explain" assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None assert prov_msg.explain_id is not None
# 5th message is chunk with response # 5th message is chunk with response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[4][0][0] chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk" assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal." assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True assert chunk_msg.end_of_stream is True
assert chunk_msg.end_of_session is True
# 6th message is empty chunk with end_of_session=True
close_msg = mock_response_producer.send.call_args_list[5][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue # Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4 assert mock_provenance_producer.send.call_count == 4
@ -187,7 +182,7 @@ class TestGraphRagService:
async def mock_query(**kwargs): async def mock_query(**kwargs):
# Don't call explain_callback # Don't call explain_callback
return "Response text" return "Response text", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query mock_rag_instance.query.side_effect = mock_query
@ -218,17 +213,12 @@ class TestGraphRagService:
# Execute # Execute
await processor.on_request(msg, consumer, flow) await processor.on_request(msg, consumer, flow)
# Verify: 2 messages (chunk + empty chunk to close) # Verify: 1 combined message (chunk with end_of_session)
assert mock_response_producer.send.call_count == 2 assert mock_response_producer.send.call_count == 1
# First is the response chunk # Single message has response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[0][0][0] chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk" assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text" assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True assert chunk_msg.end_of_stream is True
assert chunk_msg.end_of_session is True
# Second is empty chunk to close session
close_msg = mock_response_producer.send.call_args_list[1][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True

View file

@ -107,6 +107,7 @@ from .types import (
AgentObservation, AgentObservation,
AgentAnswer, AgentAnswer,
RAGChunk, RAGChunk,
TextCompletionResult,
ProvenanceEvent, ProvenanceEvent,
) )
@ -185,6 +186,7 @@ __all__ = [
"AgentObservation", "AgentObservation",
"AgentAnswer", "AgentAnswer",
"RAGChunk", "RAGChunk",
"TextCompletionResult",
"ProvenanceEvent", "ProvenanceEvent",
# Exceptions # Exceptions

View file

@ -14,6 +14,8 @@ import aiohttp
import json import json
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from . types import TextCompletionResult
from . exceptions import ProtocolException, ApplicationException from . exceptions import ProtocolException, ApplicationException
@ -434,12 +436,11 @@ class AsyncFlowInstance:
return await self.request("agent", request_data) return await self.request("agent", request_data)
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str: async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult:
""" """
Generate text completion (non-streaming). Generate text completion (non-streaming).
Generates a text response from an LLM given a system prompt and user prompt. Generates a text response from an LLM given a system prompt and user prompt.
Returns the complete response text.
Note: This method does not support streaming. For streaming text generation, Note: This method does not support streaming. For streaming text generation,
use AsyncSocketFlowInstance.text_completion() instead. use AsyncSocketFlowInstance.text_completion() instead.
@ -450,19 +451,19 @@ class AsyncFlowInstance:
**kwargs: Additional service-specific parameters **kwargs: Additional service-specific parameters
Returns: Returns:
str: Complete generated text response TextCompletionResult: Result with text, in_token, out_token, model
Example: Example:
```python ```python
async_flow = await api.async_flow() async_flow = await api.async_flow()
flow = async_flow.id("default") flow = async_flow.id("default")
# Generate text result = await flow.text_completion(
response = await flow.text_completion(
system="You are a helpful assistant.", system="You are a helpful assistant.",
prompt="Explain quantum computing in simple terms." prompt="Explain quantum computing in simple terms."
) )
print(response) print(result.text)
print(f"Tokens: {result.in_token} in, {result.out_token} out")
``` ```
""" """
request_data = { request_data = {
@ -473,7 +474,12 @@ class AsyncFlowInstance:
request_data.update(kwargs) request_data.update(kwargs)
result = await self.request("text-completion", request_data) result = await self.request("text-completion", request_data)
return result.get("response", "") return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
async def graph_rag(self, query: str, user: str, collection: str, async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_subgraph_size: int = 1000, max_subgraph_count: int = 5,

View file

@ -4,7 +4,7 @@ import asyncio
import websockets import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union from typing import Optional, Dict, Any, AsyncIterator, Union
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult
from . exceptions import ProtocolException, ApplicationException from . exceptions import ProtocolException, ApplicationException
@ -199,7 +199,10 @@ class AsyncSocketClient:
return AgentAnswer( return AgentAnswer(
content=resp.get("content", ""), content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False) end_of_dialog=resp.get("end_of_dialog", False),
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
) )
elif chunk_type == "action": elif chunk_type == "action":
return AgentThought( return AgentThought(
@ -211,7 +214,10 @@ class AsyncSocketClient:
return RAGChunk( return RAGChunk(
content=content, content=content,
end_of_stream=resp.get("end_of_stream", False), end_of_stream=resp.get("end_of_stream", False),
error=None error=None,
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
) )
async def aclose(self): async def aclose(self):
@ -269,7 +275,11 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("agent", self.flow_id, request) return await self.client._send_request("agent", self.flow_id, request)
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs): async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
"""Text completion with optional streaming""" """Text completion with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk).
"""
request = { request = {
"system": system, "system": system,
"prompt": prompt, "prompt": prompt,
@ -281,13 +291,18 @@ class AsyncSocketFlowInstance:
return self._text_completion_streaming(request) return self._text_completion_streaming(request)
else: else:
result = await self.client._send_request("text-completion", self.flow_id, request) result = await self.client._send_request("text-completion", self.flow_id, request)
return result.get("response", "") return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
async def _text_completion_streaming(self, request): async def _text_completion_streaming(self, request):
"""Helper for streaming text completion""" """Helper for streaming text completion. Yields RAGChunk objects."""
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request): async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
if hasattr(chunk, 'content'): if isinstance(chunk, RAGChunk):
yield chunk.content yield chunk
async def graph_rag(self, query: str, user: str, collection: str, async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5, max_subgraph_size: int = 1000, max_subgraph_count: int = 5,

View file

@ -11,7 +11,7 @@ import base64
from .. knowledge import hash, Uri, Literal, QuotedTriple from .. knowledge import hash, Uri, Literal, QuotedTriple
from .. schema import IRI, LITERAL, TRIPLE from .. schema import IRI, LITERAL, TRIPLE
from . types import Triple from . types import Triple, TextCompletionResult
from . exceptions import ProtocolException from . exceptions import ProtocolException
@ -360,16 +360,17 @@ class FlowInstance:
prompt: User prompt/question prompt: User prompt/question
Returns: Returns:
str: Generated response text TextCompletionResult: Result with text, in_token, out_token, model
Example: Example:
```python ```python
flow = api.flow().id("default") flow = api.flow().id("default")
response = flow.text_completion( result = flow.text_completion(
system="You are a helpful assistant", system="You are a helpful assistant",
prompt="What is quantum computing?" prompt="What is quantum computing?"
) )
print(response) print(result.text)
print(f"Tokens: {result.in_token} in, {result.out_token} out")
``` ```
""" """
@ -379,10 +380,17 @@ class FlowInstance:
"prompt": prompt "prompt": prompt
} }
return self.request( result = self.request(
"service/text-completion", "service/text-completion",
input input
)["response"] )
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def agent(self, question, user="trustgraph", state=None, group=None, history=None): def agent(self, question, user="trustgraph", state=None, group=None, history=None):
""" """
@ -498,10 +506,17 @@ class FlowInstance:
"edge-limit": edge_limit, "edge-limit": edge_limit,
} }
return self.request( result = self.request(
"service/graph-rag", "service/graph-rag",
input input
)["response"] )
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def document_rag( def document_rag(
self, query, user="trustgraph", collection="default", self, query, user="trustgraph", collection="default",
@ -543,10 +558,17 @@ class FlowInstance:
"doc-limit": doc_limit, "doc-limit": doc_limit,
} }
return self.request( result = self.request(
"service/document-rag", "service/document-rag",
input input
)["response"] )
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def embeddings(self, texts): def embeddings(self, texts):
""" """

View file

@ -14,7 +14,7 @@ import websockets
from typing import Optional, Dict, Any, Iterator, Union, List from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock from threading import Lock
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
from . exceptions import ProtocolException, raise_from_error_dict from . exceptions import ProtocolException, raise_from_error_dict
@ -393,6 +393,9 @@ class SocketClient:
end_of_message=resp.get("end_of_message", False), end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False), end_of_dialog=resp.get("end_of_dialog", False),
message_id=resp.get("message_id", ""), message_id=resp.get("message_id", ""),
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
) )
elif chunk_type == "action": elif chunk_type == "action":
return AgentThought( return AgentThought(
@ -404,7 +407,10 @@ class SocketClient:
return RAGChunk( return RAGChunk(
content=content, content=content,
end_of_stream=resp.get("end_of_stream", False), end_of_stream=resp.get("end_of_stream", False),
error=None error=None,
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
) )
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent: def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
@ -543,8 +549,12 @@ class SocketFlowInstance:
streaming=True, include_provenance=True streaming=True, include_provenance=True
) )
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute text completion with optional streaming.""" """Execute text completion with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = { request = {
"system": system, "system": system,
"prompt": prompt, "prompt": prompt,
@ -557,12 +567,17 @@ class SocketFlowInstance:
if streaming: if streaming:
return self._text_completion_generator(result) return self._text_completion_generator(result)
else: else:
return result.get("response", "") return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result: for chunk in result:
if hasattr(chunk, 'content'): if isinstance(chunk, RAGChunk):
yield chunk.content yield chunk
def graph_rag( def graph_rag(
self, self,
@ -577,8 +592,12 @@ class SocketFlowInstance:
edge_limit: int = 25, edge_limit: int = 25,
streaming: bool = False, streaming: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[str, Iterator[str]]: ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute graph-based RAG query with optional streaming.""" """Execute graph-based RAG query with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = { request = {
"query": query, "query": query,
"user": user, "user": user,
@ -598,7 +617,12 @@ class SocketFlowInstance:
if streaming: if streaming:
return self._rag_generator(result) return self._rag_generator(result)
else: else:
return result.get("response", "") return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def graph_rag_explain( def graph_rag_explain(
self, self,
@ -642,8 +666,12 @@ class SocketFlowInstance:
doc_limit: int = 10, doc_limit: int = 10,
streaming: bool = False, streaming: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[str, Iterator[str]]: ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute document-based RAG query with optional streaming.""" """Execute document-based RAG query with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = { request = {
"query": query, "query": query,
"user": user, "user": user,
@ -658,7 +686,12 @@ class SocketFlowInstance:
if streaming: if streaming:
return self._rag_generator(result) return self._rag_generator(result)
else: else:
return result.get("response", "") return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def document_rag_explain( def document_rag_explain(
self, self,
@ -684,10 +717,10 @@ class SocketFlowInstance:
streaming=True, include_provenance=True streaming=True, include_provenance=True
) )
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result: for chunk in result:
if hasattr(chunk, 'content'): if isinstance(chunk, RAGChunk):
yield chunk.content yield chunk
def prompt( def prompt(
self, self,
@ -695,8 +728,12 @@ class SocketFlowInstance:
variables: Dict[str, str], variables: Dict[str, str],
streaming: bool = False, streaming: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[str, Iterator[str]]: ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute a prompt template with optional streaming.""" """Execute a prompt template with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = { request = {
"id": id, "id": id,
"variables": variables, "variables": variables,
@ -709,7 +746,12 @@ class SocketFlowInstance:
if streaming: if streaming:
return self._rag_generator(result) return self._rag_generator(result)
else: else:
return result.get("response", "") return TextCompletionResult(
text=result.get("text", result.get("response", "")),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def graph_embeddings_query( def graph_embeddings_query(
self, self,

View file

@ -189,6 +189,9 @@ class AgentAnswer(StreamingChunk):
chunk_type: str = "final-answer" chunk_type: str = "final-answer"
end_of_dialog: bool = False end_of_dialog: bool = False
message_id: str = "" message_id: str = ""
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass @dataclasses.dataclass
class RAGChunk(StreamingChunk): class RAGChunk(StreamingChunk):
@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk):
content: Generated text content content: Generated text content
end_of_stream: True if this is the final chunk of the stream end_of_stream: True if this is the final chunk of the stream
error: Optional error information if an error occurred error: Optional error information if an error occurred
in_token: Input token count (populated on the final chunk, 0 otherwise)
out_token: Output token count (populated on the final chunk, 0 otherwise)
model: Model identifier (populated on the final chunk, empty otherwise)
chunk_type: Always "rag" chunk_type: Always "rag"
""" """
chunk_type: str = "rag" chunk_type: str = "rag"
end_of_stream: bool = False end_of_stream: bool = False
error: Optional[Dict[str, str]] = None error: Optional[Dict[str, str]] = None
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass
class TextCompletionResult:
"""
Result from a text completion request.
Returned by text_completion() in both streaming and non-streaming modes.
In streaming mode, text is None (chunks are delivered via the iterator).
In non-streaming mode, text contains the complete response.
Attributes:
text: Complete response text (None in streaming mode)
in_token: Input token count (None if not available)
out_token: Output token count (None if not available)
model: Model identifier (None if not available)
"""
text: Optional[str]
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass @dataclasses.dataclass
class ProvenanceEvent: class ProvenanceEvent:

View file

@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec from . text_completion_client import (
from . prompt_client import PromptClientSpec TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
)
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
from . triples_store_service import TriplesStoreService from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService from . document_embeddings_store_service import DocumentEmbeddingsStoreService

View file

@ -1,10 +1,22 @@
import json import json
import asyncio import asyncio
from dataclasses import dataclass
from typing import Optional, Any
from . request_response_spec import RequestResponse, RequestResponseSpec from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse from .. schema import PromptRequest, PromptResponse
@dataclass
class PromptResult:
response_type: str # "text", "json", or "jsonl"
text: Optional[str] = None # populated for "text"
object: Any = None # populated for "json"
objects: Optional[list] = None # populated for "jsonl"
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
class PromptClient(RequestResponse): class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None): async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
if resp.text: return resp.text if resp.text:
return PromptResult(
response_type="text",
text=resp.text,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return json.loads(resp.object) parsed = json.loads(resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
else: else:
last_text = "" last_resp = None
last_object = None
async def forward_chunks(resp): async def forward_chunks(resp):
nonlocal last_text, last_object nonlocal last_resp
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
end_stream = getattr(resp, 'end_of_stream', False) end_stream = getattr(resp, 'end_of_stream', False)
if resp.text is not None: if resp.text is not None:
last_text = resp.text
if chunk_callback: if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback): if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream) await chunk_callback(resp.text, end_stream)
else: else:
chunk_callback(resp.text, end_stream) chunk_callback(resp.text, end_stream)
elif resp.object:
last_object = resp.object last_resp = resp
return end_stream return end_stream
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
timeout=timeout timeout=timeout
) )
if last_text: if last_resp is None:
return last_text return PromptResult(response_type="text")
return json.loads(last_object) if last_object else None if last_resp.object:
parsed = json.loads(last_resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="text",
text=last_resp.text,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
async def extract_definitions(self, text, timeout=600): async def extract_definitions(self, text, timeout=600):
return await self.prompt( return await self.prompt(
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
response_schema = PromptResponse, response_schema = PromptResponse,
impl = PromptClient, impl = PromptClient,
) )

View file

@ -1,47 +1,71 @@
from dataclasses import dataclass
from typing import Optional
from . request_response_spec import RequestResponse, RequestResponseSpec from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse from .. schema import TextCompletionRequest, TextCompletionResponse
@dataclass
class TextCompletionResult:
text: Optional[str]
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
class TextCompletionClient(RequestResponse): class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, streaming=False, timeout=600):
# If not streaming, use original behavior
if not streaming:
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = False
),
timeout=timeout
)
if resp.error: async def text_completion(self, system, prompt, timeout=600):
raise RuntimeError(resp.error.message)
return resp.response resp = await self.request(
# For streaming: collect all chunks and return complete response
full_response = ""
async def collect_chunks(resp):
nonlocal full_response
if resp.error:
raise RuntimeError(resp.error.message)
if resp.response:
full_response += resp.response
# Return True when end_of_stream is reached
return getattr(resp, 'end_of_stream', False)
await self.request(
TextCompletionRequest( TextCompletionRequest(
system = system, prompt = prompt, streaming = True system = system, prompt = prompt, streaming = False
), ),
recipient=collect_chunks,
timeout=timeout timeout=timeout
) )
return full_response if resp.error:
raise RuntimeError(resp.error.message)
return TextCompletionResult(
text = resp.response,
in_token = resp.in_token,
out_token = resp.out_token,
model = resp.model,
)
async def text_completion_stream(
self, system, prompt, handler, timeout=600,
):
"""
Streaming text completion. `handler` is an async callable invoked
once per chunk with the chunk's TextCompletionResponse. Returns a
TextCompletionResult with text=None and token counts / model taken
from the end_of_stream message.
"""
async def on_chunk(resp):
if resp.error:
raise RuntimeError(resp.error.message)
await handler(resp)
return getattr(resp, "end_of_stream", False)
final = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = True
),
recipient=on_chunk,
timeout=timeout,
)
return TextCompletionResult(
text = None,
in_token = final.in_token,
out_token = final.out_token,
model = final.model,
)
class TextCompletionClientSpec(RequestResponseSpec): class TextCompletionClientSpec(RequestResponseSpec):
def __init__( def __init__(
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
response_schema = TextCompletionResponse, response_schema = TextCompletionResponse,
impl = TextCompletionClient, impl = TextCompletionClient,
) )

View file

@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message: if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code} result["error"] = {"message": obj.error.message, "code": obj.error.code}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result return result
def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator):
# Always include end_of_stream flag for streaming support # Always include end_of_stream flag for streaming support
result["end_of_stream"] = getattr(obj, "end_of_stream", False) result["end_of_stream"] = getattr(obj, "end_of_stream", False)
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result return result
def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message: if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type} result["error"] = {"message": obj.error.message, "type": obj.error.type}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result return result
def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message: if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type} result["error"] = {"message": obj.error.message, "type": obj.error.type}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result return result
def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator):
def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]: def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]:
result = {"response": obj.response} result = {"response": obj.response}
if obj.in_token: if obj.in_token is not None:
result["in_token"] = obj.in_token result["in_token"] = obj.in_token
if obj.out_token: if obj.out_token is not None:
result["out_token"] = obj.out_token result["out_token"] = obj.out_token
if obj.model: if obj.model is not None:
result["model"] = obj.model result["model"] = obj.model
# Always include end_of_stream flag for streaming support # Always include end_of_stream flag for streaming support

View file

@ -66,5 +66,10 @@ class AgentResponse:
error: Error | None = None error: Error | None = None
# Token usage (populated on end_of_dialog message)
in_token: int | None = None
out_token: int | None = None
model: str | None = None
############################################################################ ############################################################################

View file

@ -17,9 +17,9 @@ class TextCompletionRequest:
class TextCompletionResponse: class TextCompletionResponse:
error: Error | None = None error: Error | None = None
response: str = "" response: str = ""
in_token: int = 0 in_token: int | None = None
out_token: int = 0 out_token: int | None = None
model: str = "" model: str | None = None
end_of_stream: bool = False # Indicates final message in stream end_of_stream: bool = False # Indicates final message in stream
############################################################################ ############################################################################

View file

@ -41,4 +41,9 @@ class PromptResponse:
# Indicates final message in stream # Indicates final message in stream
end_of_stream: bool = False end_of_stream: bool = False
# Token usage from the underlying text completion
in_token: int | None = None
out_token: int | None = None
model: str | None = None
############################################################################ ############################################################################

View file

@ -29,6 +29,9 @@ class GraphRagResponse:
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain" message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete end_of_session: bool = False # Entire session complete
in_token: int | None = None
out_token: int | None = None
model: str | None = None
############################################################################ ############################################################################
@ -52,3 +55,6 @@ class DocumentRagResponse:
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain" message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete end_of_session: bool = False # Entire session complete
in_token: int | None = None
out_token: int | None = None
model: str | None = None

View file

@ -272,7 +272,8 @@ def question(
url, question, flow_id, user, collection, url, question, flow_id, user, collection,
plan=None, state=None, group=None, pattern=None, plan=None, state=None, group=None, pattern=None,
verbose=False, streaming=True, verbose=False, streaming=True,
token=None, explainable=False, debug=False token=None, explainable=False, debug=False,
show_usage=False
): ):
# Explainable mode uses the API to capture and process provenance events # Explainable mode uses the API to capture and process provenance events
if explainable: if explainable:
@ -323,6 +324,7 @@ def question(
# Track last chunk type and current outputter for streaming # Track last chunk type and current outputter for streaming
last_chunk_type = None last_chunk_type = None
current_outputter = None current_outputter = None
last_answer_chunk = None
for chunk in response: for chunk in response:
chunk_type = chunk.chunk_type chunk_type = chunk.chunk_type
@ -357,6 +359,7 @@ def question(
current_outputter.word_buffer = "" current_outputter.word_buffer = ""
elif chunk_type == "final-answer": elif chunk_type == "final-answer":
print(content, end="", flush=True) print(content, end="", flush=True)
last_answer_chunk = chunk
# Close any remaining outputter # Close any remaining outputter
if current_outputter: if current_outputter:
@ -366,6 +369,14 @@ def question(
elif last_chunk_type == "final-answer": elif last_chunk_type == "final-answer":
print() print()
if show_usage and last_answer_chunk:
print(
f"Input tokens: {last_answer_chunk.in_token} "
f"Output tokens: {last_answer_chunk.out_token} "
f"Model: {last_answer_chunk.model}",
file=sys.stderr,
)
else: else:
# Non-streaming response - but agents use multipart messaging # Non-streaming response - but agents use multipart messaging
# so we iterate through the chunks (which are complete messages, not text chunks) # so we iterate through the chunks (which are complete messages, not text chunks)
@ -477,6 +488,12 @@ def main():
help='Show debug output for troubleshooting' help='Show debug output for troubleshooting'
) )
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -496,6 +513,7 @@ def main():
token = args.token, token = args.token,
explainable = args.explainable, explainable = args.explainable,
debug = args.debug, debug = args.debug,
show_usage = args.show_usage,
) )
except Exception as e: except Exception as e:

View file

@ -99,7 +99,8 @@ def question_explainable(
def question( def question(
url, flow_id, question_text, user, collection, doc_limit, url, flow_id, question_text, user, collection, doc_limit,
streaming=True, token=None, explainable=False, debug=False streaming=True, token=None, explainable=False, debug=False,
show_usage=False
): ):
# Explainable mode uses the API to capture and process provenance events # Explainable mode uses the API to capture and process provenance events
if explainable: if explainable:
@ -133,22 +134,40 @@ def question(
) )
# Stream output # Stream output
last_chunk = None
for chunk in response: for chunk in response:
print(chunk, end="", flush=True) print(chunk.content, end="", flush=True)
last_chunk = chunk
print() # Final newline print() # Final newline
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=sys.stderr,
)
finally: finally:
socket.close() socket.close()
else: else:
# Use REST API for non-streaming # Use REST API for non-streaming
flow = api.flow().id(flow_id) flow = api.flow().id(flow_id)
resp = flow.document_rag( result = flow.document_rag(
query=question_text, query=question_text,
user=user, user=user,
collection=collection, collection=collection,
doc_limit=doc_limit, doc_limit=doc_limit,
) )
print(resp) print(result.text)
if show_usage:
print(
f"Input tokens: {result.in_token} "
f"Output tokens: {result.out_token} "
f"Model: {result.model}",
file=sys.stderr,
)
def main(): def main():
@ -219,6 +238,12 @@ def main():
help='Show debug output for troubleshooting' help='Show debug output for troubleshooting'
) )
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -234,6 +259,7 @@ def main():
token=args.token, token=args.token,
explainable=args.explainable, explainable=args.explainable,
debug=args.debug, debug=args.debug,
show_usage=args.show_usage,
) )
except Exception as e: except Exception as e:

View file

@ -753,7 +753,7 @@ def question(
url, flow_id, question, user, collection, entity_limit, triple_limit, url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=50, max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None, edge_limit=25, streaming=True, token=None,
explainable=False, debug=False explainable=False, debug=False, show_usage=False
): ):
# Explainable mode uses the API to capture and process provenance events # Explainable mode uses the API to capture and process provenance events
@ -798,16 +798,26 @@ def question(
) )
# Stream output # Stream output
last_chunk = None
for chunk in response: for chunk in response:
print(chunk, end="", flush=True) print(chunk.content, end="", flush=True)
last_chunk = chunk
print() # Final newline print() # Final newline
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=sys.stderr,
)
finally: finally:
socket.close() socket.close()
else: else:
# Use REST API for non-streaming # Use REST API for non-streaming
flow = api.flow().id(flow_id) flow = api.flow().id(flow_id)
resp = flow.graph_rag( result = flow.graph_rag(
query=question, query=question,
user=user, user=user,
collection=collection, collection=collection,
@ -818,7 +828,15 @@ def question(
edge_score_limit=edge_score_limit, edge_score_limit=edge_score_limit,
edge_limit=edge_limit, edge_limit=edge_limit,
) )
print(resp) print(result.text)
if show_usage:
print(
f"Input tokens: {result.in_token} "
f"Output tokens: {result.out_token} "
f"Model: {result.model}",
file=sys.stderr,
)
def main(): def main():
@ -923,6 +941,12 @@ def main():
help='Show debug output for troubleshooting' help='Show debug output for troubleshooting'
) )
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -943,6 +967,7 @@ def main():
token=args.token, token=args.token,
explainable=args.explainable, explainable=args.explainable,
debug=args.debug, debug=args.debug,
show_usage=args.show_usage,
) )
except Exception as e: except Exception as e:

View file

@ -10,7 +10,8 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def query(url, flow_id, system, prompt, streaming=True, token=None): def query(url, flow_id, system, prompt, streaming=True, token=None,
show_usage=False):
# Create API client # Create API client
api = Api(url=url, token=token) api = Api(url=url, token=token)
@ -26,14 +27,29 @@ def query(url, flow_id, system, prompt, streaming=True, token=None):
) )
if streaming: if streaming:
# Stream output to stdout without newline last_chunk = None
for chunk in response: for chunk in response:
print(chunk, end="", flush=True) print(chunk.content, end="", flush=True)
# Add final newline after streaming last_chunk = chunk
print() print()
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=__import__('sys').stderr,
)
else: else:
# Non-streaming: print complete response print(response.text)
print(response)
if show_usage:
print(
f"Input tokens: {response.in_token} "
f"Output tokens: {response.out_token} "
f"Model: {response.model}",
file=__import__('sys').stderr,
)
finally: finally:
# Clean up socket connection # Clean up socket connection
@ -82,6 +98,12 @@ def main():
help='Disable streaming (default: streaming enabled)' help='Disable streaming (default: streaming enabled)'
) )
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -93,6 +115,7 @@ def main():
prompt=args.prompt[0], prompt=args.prompt[0],
streaming=not args.no_streaming, streaming=not args.no_streaming,
token=args.token, token=args.token,
show_usage=args.show_usage,
) )
except Exception as e: except Exception as e:

View file

@ -15,7 +15,8 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def query(url, flow_id, template_id, variables, streaming=True, token=None): def query(url, flow_id, template_id, variables, streaming=True, token=None,
show_usage=False):
# Create API client # Create API client
api = Api(url=url, token=token) api = Api(url=url, token=token)
@ -31,16 +32,30 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None):
) )
if streaming: if streaming:
# Stream output (prompt yields strings directly) last_chunk = None
for chunk in response: for chunk in response:
if chunk: if chunk.content:
print(chunk, end="", flush=True) print(chunk.content, end="", flush=True)
# Add final newline after streaming last_chunk = chunk
print() print()
if show_usage and last_chunk:
print(
f"Input tokens: {last_chunk.in_token} "
f"Output tokens: {last_chunk.out_token} "
f"Model: {last_chunk.model}",
file=__import__('sys').stderr,
)
else: else:
# Non-streaming: print complete response print(response.text)
print(response)
if show_usage:
print(
f"Input tokens: {response.in_token} "
f"Output tokens: {response.out_token} "
f"Model: {response.model}",
file=__import__('sys').stderr,
)
finally: finally:
# Clean up socket connection # Clean up socket connection
@ -92,6 +107,12 @@ specified multiple times''',
help='Disable streaming (default: streaming enabled for text responses)' help='Disable streaming (default: streaming enabled for text responses)'
) )
parser.add_argument(
'--show-usage',
action='store_true',
help='Show token usage and model on stderr'
)
args = parser.parse_args() args = parser.parse_args()
variables = {} variables = {}
@ -113,6 +134,7 @@ specified multiple times''',
variables=variables, variables=variables,
streaming=not args.no_streaming, streaming=not args.no_streaming,
token=args.token, token=args.token,
show_usage=args.show_usage,
) )
except Exception as e: except Exception as e:

View file

@ -53,7 +53,7 @@ class MetaRouter:
"general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""}, "general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
} }
async def identify_task_type(self, question, context): async def identify_task_type(self, question, context, usage=None):
""" """
Use the LLM to classify the question into one of the known task types. Use the LLM to classify the question into one of the known task types.
@ -71,7 +71,7 @@ class MetaRouter:
try: try:
client = context("prompt-request") client = context("prompt-request")
response = await client.prompt( result = await client.prompt(
id="task-type-classify", id="task-type-classify",
variables={ variables={
"question": question, "question": question,
@ -81,7 +81,9 @@ class MetaRouter:
], ],
}, },
) )
selected = response.strip().lower().replace('"', '').replace("'", "") if usage:
usage.track(result)
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in self.task_types: if selected in self.task_types:
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING) framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
@ -100,7 +102,7 @@ class MetaRouter:
) )
return DEFAULT_TASK_TYPE, framing return DEFAULT_TASK_TYPE, framing
async def select_pattern(self, question, task_type, context): async def select_pattern(self, question, task_type, context, usage=None):
""" """
Use the LLM to select the best execution pattern for this task type. Use the LLM to select the best execution pattern for this task type.
@ -120,7 +122,7 @@ class MetaRouter:
try: try:
client = context("prompt-request") client = context("prompt-request")
response = await client.prompt( result = await client.prompt(
id="pattern-select", id="pattern-select",
variables={ variables={
"question": question, "question": question,
@ -133,7 +135,9 @@ class MetaRouter:
], ],
}, },
) )
selected = response.strip().lower().replace('"', '').replace("'", "") if usage:
usage.track(result)
selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in valid_patterns: if selected in valid_patterns:
logger.info(f"MetaRouter: selected pattern '{selected}'") logger.info(f"MetaRouter: selected pattern '{selected}'")
@ -148,19 +152,20 @@ class MetaRouter:
logger.warning(f"MetaRouter: pattern selection failed: {e}") logger.warning(f"MetaRouter: pattern selection failed: {e}")
return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN
async def route(self, question, context): async def route(self, question, context, usage=None):
""" """
Full routing pipeline: identify task type, then select pattern. Full routing pipeline: identify task type, then select pattern.
Args: Args:
question: The user's query. question: The user's query.
context: UserAwareContext (flow wrapper). context: UserAwareContext (flow wrapper).
usage: Optional UsageTracker for token counting.
Returns: Returns:
(pattern, task_type, framing) tuple. (pattern, task_type, framing) tuple.
""" """
task_type, framing = await self.identify_task_type(question, context) task_type, framing = await self.identify_task_type(question, context, usage=usage)
pattern = await self.select_pattern(question, task_type, context) pattern = await self.select_pattern(question, task_type, context, usage=usage)
logger.info( logger.info(
f"MetaRouter: route result — " f"MetaRouter: route result — "
f"pattern={pattern}, task_type={task_type}, framing={framing!r}" f"pattern={pattern}, task_type={task_type}, framing={framing!r}"

View file

@ -65,6 +65,37 @@ class UserAwareContext:
return client return client
class UsageTracker:
"""Accumulates token usage across multiple prompt calls."""
def __init__(self):
self.total_in = 0
self.total_out = 0
self.last_model = None
def track(self, result):
"""Track usage from a PromptResult."""
if result is not None:
if getattr(result, "in_token", None) is not None:
self.total_in += result.in_token
if getattr(result, "out_token", None) is not None:
self.total_out += result.out_token
if getattr(result, "model", None) is not None:
self.last_model = result.model
@property
def in_token(self):
return self.total_in if self.total_in > 0 else None
@property
def out_token(self):
return self.total_out if self.total_out > 0 else None
@property
def model(self):
return self.last_model
class PatternBase: class PatternBase:
""" """
Shared infrastructure for all agent patterns. Shared infrastructure for all agent patterns.
@ -571,7 +602,8 @@ class PatternBase:
# ---- Response helpers --------------------------------------------------- # ---- Response helpers ---------------------------------------------------
async def prompt_as_answer(self, client, prompt_id, variables, async def prompt_as_answer(self, client, prompt_id, variables,
respond, streaming, message_id=""): respond, streaming, message_id="",
usage=None):
"""Call a prompt template, forwarding chunks as answer """Call a prompt template, forwarding chunks as answer
AgentResponse messages when streaming is enabled. AgentResponse messages when streaming is enabled.
@ -591,22 +623,28 @@ class PatternBase:
message_id=message_id, message_id=message_id,
)) ))
await client.prompt( result = await client.prompt(
id=prompt_id, id=prompt_id,
variables=variables, variables=variables,
streaming=True, streaming=True,
chunk_callback=on_chunk, chunk_callback=on_chunk,
) )
if usage:
usage.track(result)
return "".join(accumulated) return "".join(accumulated)
else: else:
return await client.prompt( result = await client.prompt(
id=prompt_id, id=prompt_id,
variables=variables, variables=variables,
) )
if usage:
usage.track(result)
return result.text
async def send_final_response(self, respond, streaming, answer_text, async def send_final_response(self, respond, streaming, answer_text,
already_streamed=False, message_id=""): already_streamed=False, message_id="",
usage=None):
"""Send the answer content and end-of-dialog marker. """Send the answer content and end-of-dialog marker.
Args: Args:
@ -614,7 +652,16 @@ class PatternBase:
via streaming callbacks (e.g. ReactPattern). Only the via streaming callbacks (e.g. ReactPattern). Only the
end-of-dialog marker is emitted. end-of-dialog marker is emitted.
message_id: Provenance URI for the answer entity. message_id: Provenance URI for the answer entity.
usage: UsageTracker with accumulated token counts.
""" """
usage_kwargs = {}
if usage:
usage_kwargs = {
"in_token": usage.in_token,
"out_token": usage.out_token,
"model": usage.model,
}
if streaming and not already_streamed: if streaming and not already_streamed:
# Answer wasn't streamed yet — send it as a chunk first # Answer wasn't streamed yet — send it as a chunk first
if answer_text: if answer_text:
@ -626,13 +673,14 @@ class PatternBase:
message_id=message_id, message_id=message_id,
)) ))
if streaming: if streaming:
# End-of-dialog marker # End-of-dialog marker with usage
await respond(AgentResponse( await respond(AgentResponse(
chunk_type="answer", chunk_type="answer",
content="", content="",
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,
message_id=message_id, message_id=message_id,
**usage_kwargs,
)) ))
else: else:
await respond(AgentResponse( await respond(AgentResponse(
@ -641,6 +689,7 @@ class PatternBase:
end_of_message=True, end_of_message=True,
end_of_dialog=True, end_of_dialog=True,
message_id=message_id, message_id=message_id,
**usage_kwargs,
)) ))
def build_next_request(self, request, history, session_id, collection, def build_next_request(self, request, history, session_id, collection,

View file

@ -18,7 +18,7 @@ from trustgraph.provenance import (
agent_synthesis_uri, agent_synthesis_uri,
) )
from . pattern_base import PatternBase from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,7 +35,10 @@ class PlanThenExecutePattern(PatternBase):
Subsequent calls execute the next pending plan step via ReACT. Subsequent calls execute the next pending plan step via ReACT.
""" """
async def iterate(self, request, respond, next, flow): async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False) streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -67,13 +70,13 @@ class PlanThenExecutePattern(PatternBase):
await self._planning_iteration( await self._planning_iteration(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_uri, session_id, collection, streaming, session_uri,
iteration_num, iteration_num, usage=usage,
) )
else: else:
await self._execution_iteration( await self._execution_iteration(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_uri, session_id, collection, streaming, session_uri,
iteration_num, plan, iteration_num, plan, usage=usage,
) )
def _extract_plan(self, history): def _extract_plan(self, history):
@ -98,7 +101,7 @@ class PlanThenExecutePattern(PatternBase):
async def _planning_iteration(self, request, respond, next, flow, async def _planning_iteration(self, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num): session_uri, iteration_num, usage=None):
"""Ask the LLM to produce a structured plan.""" """Ask the LLM to produce a structured plan."""
think = self.make_think_callback(respond, streaming) think = self.make_think_callback(respond, streaming)
@ -113,7 +116,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request") client = context("prompt-request")
# Use the plan-create prompt template # Use the plan-create prompt template
plan_steps = await client.prompt( result = await client.prompt(
id="plan-create", id="plan-create",
variables={ variables={
"question": request.question, "question": request.question,
@ -124,7 +127,10 @@ class PlanThenExecutePattern(PatternBase):
], ],
}, },
) )
if usage:
usage.track(result)
plan_steps = result.objects
# Validate we got a list # Validate we got a list
if not isinstance(plan_steps, list) or not plan_steps: if not isinstance(plan_steps, list) or not plan_steps:
logger.warning("plan-create returned invalid result, falling back to single step") logger.warning("plan-create returned invalid result, falling back to single step")
@ -187,7 +193,8 @@ class PlanThenExecutePattern(PatternBase):
async def _execution_iteration(self, request, respond, next, flow, async def _execution_iteration(self, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, plan): session_uri, iteration_num, plan,
usage=None):
"""Execute the next pending plan step via single-shot tool call.""" """Execute the next pending plan step via single-shot tool call."""
pending_idx = self._find_next_pending_step(plan) pending_idx = self._find_next_pending_step(plan)
@ -198,6 +205,7 @@ class PlanThenExecutePattern(PatternBase):
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, plan, session_uri, iteration_num, plan,
usage=usage,
) )
return return
@ -240,7 +248,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request") client = context("prompt-request")
# Single-shot: ask LLM which tool + arguments to use for this goal # Single-shot: ask LLM which tool + arguments to use for this goal
tool_call = await client.prompt( result = await client.prompt(
id="plan-step-execute", id="plan-step-execute",
variables={ variables={
"goal": goal, "goal": goal,
@ -258,7 +266,10 @@ class PlanThenExecutePattern(PatternBase):
], ],
}, },
) )
if usage:
usage.track(result)
tool_call = result.object
tool_name = tool_call.get("tool", "") tool_name = tool_call.get("tool", "")
tool_arguments = tool_call.get("arguments", {}) tool_arguments = tool_call.get("arguments", {})
@ -330,7 +341,8 @@ class PlanThenExecutePattern(PatternBase):
async def _synthesise(self, request, respond, next, flow, async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, plan): session_uri, iteration_num, plan,
usage=None):
"""Synthesise a final answer from all completed plan step results.""" """Synthesise a final answer from all completed plan step results."""
think = self.make_think_callback(respond, streaming) think = self.make_think_callback(respond, streaming)
@ -365,6 +377,7 @@ class PlanThenExecutePattern(PatternBase):
respond=respond, respond=respond,
streaming=streaming, streaming=streaming,
message_id=synthesis_msg_id, message_id=synthesis_msg_id,
usage=usage,
) )
# Emit synthesis provenance (links back to last step result) # Emit synthesis provenance (links back to last step result)
@ -380,4 +393,5 @@ class PlanThenExecutePattern(PatternBase):
await self.send_final_response( await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming, respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id, message_id=synthesis_msg_id,
usage=usage,
) )

View file

@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager
from ..react.types import Action, Final from ..react.types import Action, Final
from ..tool_filter import get_next_state from ..tool_filter import get_next_state
from . pattern_base import PatternBase from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,7 +37,10 @@ class ReactPattern(PatternBase):
result is appended to history and a next-request is emitted. result is appended to history and a next-request is emitted.
""" """
async def iterate(self, request, respond, next, flow): async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False) streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -121,6 +124,7 @@ class ReactPattern(PatternBase):
context=context, context=context,
streaming=streaming, streaming=streaming,
on_action=on_action, on_action=on_action,
usage=usage,
) )
logger.debug(f"Action: {act}") logger.debug(f"Action: {act}")
@ -144,6 +148,7 @@ class ReactPattern(PatternBase):
await self.send_final_response( await self.send_final_response(
respond, streaming, f, already_streamed=streaming, respond, streaming, f, already_streamed=streaming,
message_id=answer_msg_id, message_id=answer_msg_id,
usage=usage,
) )
return return

View file

@ -23,6 +23,7 @@ from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ..orchestrator.pattern_base import UsageTracker
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue from ... schema import librarian_request_queue, librarian_response_queue
@ -493,6 +494,8 @@ class Processor(AgentService):
async def agent_request(self, request, respond, next, flow): async def agent_request(self, request, respond, next, flow):
usage = UsageTracker()
try: try:
# Intercept subagent completion messages # Intercept subagent completion messages
@ -516,7 +519,7 @@ class Processor(AgentService):
if self.meta_router: if self.meta_router:
pattern, task_type, framing = await self.meta_router.route( pattern, task_type, framing = await self.meta_router.route(
request.question, context, request.question, context, usage=usage,
) )
else: else:
pattern = "react" pattern = "react"
@ -536,16 +539,16 @@ class Processor(AgentService):
# Dispatch to the selected pattern # Dispatch to the selected pattern
if pattern == "plan-then-execute": if pattern == "plan-then-execute":
await self.plan_pattern.iterate( await self.plan_pattern.iterate(
request, respond, next, flow, request, respond, next, flow, usage=usage,
) )
elif pattern == "supervisor": elif pattern == "supervisor":
await self.supervisor_pattern.iterate( await self.supervisor_pattern.iterate(
request, respond, next, flow, request, respond, next, flow, usage=usage,
) )
else: else:
# Default to react # Default to react
await self.react_pattern.iterate( await self.react_pattern.iterate(
request, respond, next, flow, request, respond, next, flow, usage=usage,
) )
except Exception as e: except Exception as e:

View file

@ -22,7 +22,7 @@ from trustgraph.provenance import (
agent_synthesis_uri, agent_synthesis_uri,
) )
from . pattern_base import PatternBase from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,7 +38,10 @@ class SupervisorPattern(PatternBase):
- "synthesise": triggered by aggregator with results in subagent_results - "synthesise": triggered by aggregator with results in subagent_results
""" """
async def iterate(self, request, respond, next, flow): async def iterate(self, request, respond, next, flow, usage=None):
if usage is None:
usage = UsageTracker()
streaming = getattr(request, 'streaming', False) streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@ -72,17 +75,19 @@ class SupervisorPattern(PatternBase):
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, session_uri, iteration_num,
usage=usage,
) )
else: else:
await self._decompose_and_fanout( await self._decompose_and_fanout(
request, respond, next, flow, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num, session_uri, iteration_num,
usage=usage,
) )
async def _decompose_and_fanout(self, request, respond, next, flow, async def _decompose_and_fanout(self, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num): session_uri, iteration_num, usage=None):
"""Decompose the question into sub-goals and fan out subagents.""" """Decompose the question into sub-goals and fan out subagents."""
decompose_msg_id = agent_decomposition_uri(session_id) decompose_msg_id = agent_decomposition_uri(session_id)
@ -100,7 +105,7 @@ class SupervisorPattern(PatternBase):
client = context("prompt-request") client = context("prompt-request")
# Use the supervisor-decompose prompt template # Use the supervisor-decompose prompt template
goals = await client.prompt( result = await client.prompt(
id="supervisor-decompose", id="supervisor-decompose",
variables={ variables={
"question": request.question, "question": request.question,
@ -112,7 +117,10 @@ class SupervisorPattern(PatternBase):
], ],
}, },
) )
if usage:
usage.track(result)
goals = result.objects
# Validate result # Validate result
if not isinstance(goals, list): if not isinstance(goals, list):
goals = [] goals = []
@ -175,7 +183,7 @@ class SupervisorPattern(PatternBase):
async def _synthesise(self, request, respond, next, flow, async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming, session_id, collection, streaming,
session_uri, iteration_num): session_uri, iteration_num, usage=None):
"""Synthesise final answer from subagent results.""" """Synthesise final answer from subagent results."""
synthesis_msg_id = agent_synthesis_uri(session_id) synthesis_msg_id = agent_synthesis_uri(session_id)
@ -216,6 +224,7 @@ class SupervisorPattern(PatternBase):
respond=respond, respond=respond,
streaming=streaming, streaming=streaming,
message_id=synthesis_msg_id, message_id=synthesis_msg_id,
usage=usage,
) )
# Emit synthesis provenance (links back to all findings) # Emit synthesis provenance (links back to all findings)
@ -231,4 +240,5 @@ class SupervisorPattern(PatternBase):
await self.send_final_response( await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming, respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id, message_id=synthesis_msg_id,
usage=usage,
) )

View file

@ -170,7 +170,7 @@ class AgentManager:
raise ValueError(f"Could not parse response: {text}") raise ValueError(f"Could not parse response: {text}")
async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None): async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None):
logger.debug(f"calling reason: {question}") logger.debug(f"calling reason: {question}")
@ -255,11 +255,13 @@ class AgentManager:
client = context("prompt-request") client = context("prompt-request")
# Get streaming response # Get streaming response
response_text = await client.agent_react( prompt_result = await client.agent_react(
variables=variables, variables=variables,
streaming=True, streaming=True,
chunk_callback=on_chunk chunk_callback=on_chunk
) )
if usage:
usage.track(prompt_result)
# Finalize parser # Finalize parser
parser.finalize() parser.finalize()
@ -275,10 +277,13 @@ class AgentManager:
# Non-streaming path - get complete text and parse # Non-streaming path - get complete text and parse
client = context("prompt-request") client = context("prompt-request")
response_text = await client.agent_react( prompt_result = await client.agent_react(
variables=variables, variables=variables,
streaming=False streaming=False
) )
if usage:
usage.track(prompt_result)
response_text = prompt_result.text
logger.debug(f"Response text:\n{response_text}") logger.debug(f"Response text:\n{response_text}")
@ -292,7 +297,8 @@ class AgentManager:
raise RuntimeError(f"Failed to parse agent response: {e}") raise RuntimeError(f"Failed to parse agent response: {e}")
async def react(self, question, history, think, observe, context, async def react(self, question, history, think, observe, context,
streaming=False, answer=None, on_action=None): streaming=False, answer=None, on_action=None,
usage=None):
act = await self.reason( act = await self.reason(
question = question, question = question,
@ -302,6 +308,7 @@ class AgentManager:
think = think, think = think,
observe = observe, observe = observe,
answer = answer, answer = answer,
usage = usage,
) )
if isinstance(act, Final): if isinstance(act, Final):

View file

@ -78,9 +78,10 @@ class TextCompletionImpl:
async def invoke(self, **arguments): async def invoke(self, **arguments):
client = self.context("prompt-request") client = self.context("prompt-request")
logger.debug("Prompt question...") logger.debug("Prompt question...")
return await client.question( result = await client.question(
arguments.get("question") arguments.get("question")
) )
return result.text
# This tool implementation knows how to do MCP tool invocation. This uses # This tool implementation knows how to do MCP tool invocation. This uses
# the mcp-tool service. # the mcp-tool service.
@ -227,10 +228,11 @@ class PromptImpl:
async def invoke(self, **arguments): async def invoke(self, **arguments):
client = self.context("prompt-request") client = self.context("prompt-request")
logger.debug(f"Prompt template invocation: {self.template_id}...") logger.debug(f"Prompt template invocation: {self.template_id}...")
return await client.prompt( result = await client.prompt(
id=self.template_id, id=self.template_id,
variables=arguments variables=arguments
) )
return result.text
# This tool implementation invokes a dynamically configured tool service # This tool implementation invokes a dynamically configured tool service

View file

@ -117,10 +117,11 @@ class Processor(FlowProcessor):
try: try:
defs = await flow("prompt-request").extract_definitions( result = await flow("prompt-request").extract_definitions(
text = chunk text = chunk
) )
defs = result.objects
logger.debug(f"Definitions response: {defs}") logger.debug(f"Definitions response: {defs}")
if type(defs) != list: if type(defs) != list:

View file

@ -376,10 +376,11 @@ class Processor(FlowProcessor):
""" """
try: try:
# Call prompt service with simplified format prompt # Call prompt service with simplified format prompt
extraction_response = await flow("prompt-request").prompt( result = await flow("prompt-request").prompt(
id="extract-with-ontologies", id="extract-with-ontologies",
variables=prompt_variables variables=prompt_variables
) )
extraction_response = result.object
logger.debug(f"Simplified extraction response: {extraction_response}") logger.debug(f"Simplified extraction response: {extraction_response}")
# Parse response into structured format # Parse response into structured format

View file

@ -100,10 +100,11 @@ class Processor(FlowProcessor):
try: try:
rels = await flow("prompt-request").extract_relationships( result = await flow("prompt-request").extract_relationships(
text = chunk text = chunk
) )
rels = result.objects
logger.debug(f"Prompt response: {rels}") logger.debug(f"Prompt response: {rels}")
if type(rels) != list: if type(rels) != list:

View file

@ -148,11 +148,12 @@ class Processor(FlowProcessor):
schema_dict = row_schema_translator.encode(schema) schema_dict = row_schema_translator.encode(schema)
# Use prompt client to extract rows based on schema # Use prompt client to extract rows based on schema
objects = await flow("prompt-request").extract_objects( result = await flow("prompt-request").extract_objects(
schema=schema_dict, schema=schema_dict,
text=text text=text
) )
objects = result.objects
if not isinstance(objects, list): if not isinstance(objects, list):
return [] return []

View file

@ -11,7 +11,6 @@ import logging
from ...schema import Definition, Relationship, Triple from ...schema import Definition, Relationship, Triple
from ...schema import Topic from ...schema import Topic
from ...schema import PromptRequest, PromptResponse, Error from ...schema import PromptRequest, PromptResponse, Error
from ...schema import TextCompletionRequest, TextCompletionResponse
from ...base import FlowProcessor from ...base import FlowProcessor
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
@ -124,35 +123,26 @@ class Processor(FlowProcessor):
logger.debug(f"System prompt: {system}") logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}") logger.debug(f"User prompt: {prompt}")
# Use the text completion client with recipient handler
client = flow("text-completion-request")
async def forward_chunks(resp): async def forward_chunks(resp):
if resp.error:
raise RuntimeError(resp.error.message)
is_final = getattr(resp, 'end_of_stream', False) is_final = getattr(resp, 'end_of_stream', False)
# Always send a message if there's content OR if it's the final message # Always send a message if there's content OR if it's the final message
if resp.response or is_final: if resp.response or is_final:
# Forward each chunk immediately
r = PromptResponse( r = PromptResponse(
text=resp.response if resp.response else "", text=resp.response if resp.response else "",
object=None, object=None,
error=None, error=None,
end_of_stream=is_final, end_of_stream=is_final,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
) )
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
# Return True when end_of_stream await flow("text-completion-request").text_completion_stream(
return is_final system=system, prompt=prompt,
handler=forward_chunks,
await client.request( timeout=600,
TextCompletionRequest(
system=system, prompt=prompt, streaming=True
),
recipient=forward_chunks,
timeout=600
) )
# Return empty string since we already sent all chunks # Return empty string since we already sent all chunks
@ -167,17 +157,21 @@ class Processor(FlowProcessor):
return return
# Non-streaming path (original behavior) # Non-streaming path (original behavior)
usage = {}
async def llm(system, prompt): async def llm(system, prompt):
logger.debug(f"System prompt: {system}") logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}") logger.debug(f"User prompt: {prompt}")
resp = await flow("text-completion-request").text_completion(
system = system, prompt = prompt, streaming = False,
)
try: try:
return resp result = await flow("text-completion-request").text_completion(
system = system, prompt = prompt,
)
usage["in_token"] = result.in_token
usage["out_token"] = result.out_token
usage["model"] = result.model
return result.text
except Exception as e: except Exception as e:
logger.error(f"LLM Exception: {e}", exc_info=True) logger.error(f"LLM Exception: {e}", exc_info=True)
return None return None
@ -199,6 +193,9 @@ class Processor(FlowProcessor):
object=None, object=None,
error=None, error=None,
end_of_stream=True, end_of_stream=True,
in_token=usage.get("in_token", 0),
out_token=usage.get("out_token", 0),
model=usage.get("model", ""),
) )
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
@ -215,6 +212,9 @@ class Processor(FlowProcessor):
object=json.dumps(resp), object=json.dumps(resp),
error=None, error=None,
end_of_stream=True, end_of_stream=True,
in_token=usage.get("in_token", 0),
out_token=usage.get("out_token", 0),
model=usage.get("model", ""),
) )
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})

View file

@ -27,24 +27,27 @@ class Query:
def __init__( def __init__(
self, rag, user, collection, verbose, self, rag, user, collection, verbose,
doc_limit=20 doc_limit=20, track_usage=None,
): ):
self.rag = rag self.rag = rag
self.user = user self.user = user
self.collection = collection self.collection = collection
self.verbose = verbose self.verbose = verbose
self.doc_limit = doc_limit self.doc_limit = doc_limit
self.track_usage = track_usage
async def extract_concepts(self, query): async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding.""" """Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt( result = await self.rag.prompt_client.prompt(
"extract-concepts", "extract-concepts",
variables={"query": query} variables={"query": query}
) )
if self.track_usage:
self.track_usage(result)
concepts = [] concepts = []
if isinstance(response, str): if result.text:
for line in response.strip().split('\n'): for line in result.text.strip().split('\n'):
line = line.strip() line = line.strip()
if line: if line:
concepts.append(line) concepts.append(line)
@ -167,8 +170,23 @@ class DocumentRag:
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns: Returns:
str: The synthesized answer text tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
""" """
total_in = 0
total_out = 0
last_model = None
def track_usage(result):
nonlocal total_in, total_out, last_model
if result is not None:
if result.in_token is not None:
total_in += result.in_token
if result.out_token is not None:
total_out += result.out_token
if result.model is not None:
last_model = result.model
if self.verbose: if self.verbose:
logger.debug("Constructing prompt...") logger.debug("Constructing prompt...")
@ -191,7 +209,7 @@ class DocumentRag:
q = Query( q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose, rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit doc_limit=doc_limit, track_usage=track_usage,
) )
# Extract concepts from query (grounding step) # Extract concepts from query (grounding step)
@ -228,19 +246,22 @@ class DocumentRag:
accumulated_chunks.append(chunk) accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream) await chunk_callback(chunk, end_of_stream)
resp = await self.prompt_client.document_prompt( synthesis_result = await self.prompt_client.document_prompt(
query=query, query=query,
documents=docs, documents=docs,
streaming=True, streaming=True,
chunk_callback=accumulating_callback chunk_callback=accumulating_callback
) )
track_usage(synthesis_result)
# Combine all chunks into full response # Combine all chunks into full response
resp = "".join(accumulated_chunks) resp = "".join(accumulated_chunks)
else: else:
resp = await self.prompt_client.document_prompt( synthesis_result = await self.prompt_client.document_prompt(
query=query, query=query,
documents=docs documents=docs
) )
track_usage(synthesis_result)
resp = synthesis_result.text
if self.verbose: if self.verbose:
logger.debug("Query processing complete") logger.debug("Query processing complete")
@ -273,5 +294,11 @@ class DocumentRag:
if self.verbose: if self.verbose:
logger.debug(f"Emitted explain for session {session_id}") logger.debug(f"Emitted explain for session {session_id}")
return resp usage = {
"in_token": total_in if total_in > 0 else None,
"out_token": total_out if total_out > 0 else None,
"model": last_model,
}
return resp, usage

View file

@ -200,7 +200,7 @@ class Processor(FlowProcessor):
# Query with streaming enabled # Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback # All chunks (including final one with end_of_stream=True) are sent via callback
await self.rag.query( response, usage = await self.rag.query(
v.query, v.query,
user=v.user, user=v.user,
collection=v.collection, collection=v.collection,
@ -217,12 +217,15 @@ class Processor(FlowProcessor):
response=None, response=None,
end_of_session=True, end_of_session=True,
message_type="end", message_type="end",
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
), ),
properties={"id": id} properties={"id": id}
) )
else: else:
# Non-streaming path (existing behavior) # Non-streaming path - single response with answer and token usage
response = await self.rag.query( response, usage = await self.rag.query(
v.query, v.query,
user=v.user, user=v.user,
collection=v.collection, collection=v.collection,
@ -233,11 +236,15 @@ class Processor(FlowProcessor):
await flow("response").send( await flow("response").send(
DocumentRagResponse( DocumentRagResponse(
response = response, response=response,
end_of_stream = True, end_of_stream=True,
error = None end_of_session=True,
error=None,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
), ),
properties = {"id": id} properties={"id": id}
) )
logger.info("Request processing complete") logger.info("Request processing complete")

View file

@ -121,7 +121,7 @@ class Query:
def __init__( def __init__(
self, rag, user, collection, verbose, self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000, entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, max_path_length=2, track_usage=None,
): ):
self.rag = rag self.rag = rag
self.user = user self.user = user
@ -131,17 +131,20 @@ class Query:
self.triple_limit = triple_limit self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length self.max_path_length = max_path_length
self.track_usage = track_usage
async def extract_concepts(self, query): async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding.""" """Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt( result = await self.rag.prompt_client.prompt(
"extract-concepts", "extract-concepts",
variables={"query": query} variables={"query": query}
) )
if self.track_usage:
self.track_usage(result)
concepts = [] concepts = []
if isinstance(response, str): if result.text:
for line in response.strip().split('\n'): for line in result.text.strip().split('\n'):
line = line.strip() line = line.strip()
if line: if line:
concepts.append(line) concepts.append(line)
@ -609,8 +612,24 @@ class GraphRag:
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns: Returns:
str: The synthesized answer text tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
""" """
# Accumulate token usage across all prompt calls
total_in = 0
total_out = 0
last_model = None
def track_usage(result):
nonlocal total_in, total_out, last_model
if result is not None:
if result.in_token is not None:
total_in += result.in_token
if result.out_token is not None:
total_out += result.out_token
if result.model is not None:
last_model = result.model
if self.verbose: if self.verbose:
logger.debug("Constructing prompt...") logger.debug("Constructing prompt...")
@ -641,6 +660,7 @@ class GraphRag:
triple_limit = triple_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length, max_path_length = max_path_length,
track_usage = track_usage,
) )
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
@ -751,21 +771,22 @@ class GraphRag:
logger.debug(f"Built edge map with {len(edge_map)} edges") logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance # Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt( scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring", "kg-edge-scoring",
variables={ variables={
"query": query, "query": query,
"knowledge": edges_with_ids "knowledge": edges_with_ids
} }
) )
track_usage(scoring_result)
if self.verbose: if self.verbose:
logger.debug(f"Edge scoring response: {scoring_response}") logger.debug(f"Edge scoring result: {scoring_result}")
# Parse scoring response to get edge IDs with scores # Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = [] scored_edges = []
def parse_scored_edge(obj): for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj: if isinstance(obj, dict) and "id" in obj and "score" in obj:
try: try:
score = int(obj["score"]) score = int(obj["score"])
@ -773,21 +794,6 @@ class GraphRag:
score = 0 score = 0
scored_edges.append({"id": obj["id"], "score": score}) scored_edges.append({"id": obj["id"], "score": score})
if isinstance(scoring_response, list):
for obj in scoring_response:
parse_scored_edge(obj)
elif isinstance(scoring_response, str):
for line in scoring_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_scored_edge(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score # Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True) scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit] top_edges = scored_edges[:edge_limit]
@ -821,25 +827,30 @@ class GraphRag:
] ]
# Run reasoning and document tracing concurrently # Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt( async def _get_reasoning():
"kg-edge-reasoning", result = await self.prompt_client.prompt(
variables={ "kg-edge-reasoning",
"query": query, variables={
"knowledge": selected_edges_with_ids "query": query,
} "knowledge": selected_edges_with_ids
) }
)
track_usage(result)
return result
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris) doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather( reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True reasoning_task, doc_trace_task, return_exceptions=True
) )
# Handle exceptions from gather # Handle exceptions from gather
if isinstance(reasoning_response, Exception): if isinstance(reasoning_result, Exception):
logger.warning( logger.warning(
f"Edge reasoning failed: {reasoning_response}" f"Edge reasoning failed: {reasoning_result}"
) )
reasoning_response = "" reasoning_result = None
if isinstance(source_documents, Exception): if isinstance(source_documents, Exception):
logger.warning( logger.warning(
f"Document tracing failed: {source_documents}" f"Document tracing failed: {source_documents}"
@ -848,29 +859,15 @@ class GraphRag:
if self.verbose: if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}") logger.debug(f"Edge reasoning result: {reasoning_result}")
# Parse reasoning response and build explainability data # Parse reasoning response (jsonl) and build explainability data
reasoning_map = {} reasoning_map = {}
def parse_reasoning(obj): if reasoning_result is not None:
if isinstance(obj, dict) and "id" in obj: for obj in reasoning_result.objects or []:
reasoning_map[obj["id"]] = obj.get("reasoning", "") if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
selected_edges_with_reasoning = [] selected_edges_with_reasoning = []
for eid in selected_ids: for eid in selected_ids:
@ -919,19 +916,22 @@ class GraphRag:
accumulated_chunks.append(chunk) accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream) await chunk_callback(chunk, end_of_stream)
await self.prompt_client.prompt( synthesis_result = await self.prompt_client.prompt(
"kg-synthesis", "kg-synthesis",
variables=synthesis_variables, variables=synthesis_variables,
streaming=True, streaming=True,
chunk_callback=accumulating_callback chunk_callback=accumulating_callback
) )
track_usage(synthesis_result)
# Combine all chunks into full response # Combine all chunks into full response
resp = "".join(accumulated_chunks) resp = "".join(accumulated_chunks)
else: else:
resp = await self.prompt_client.prompt( synthesis_result = await self.prompt_client.prompt(
"kg-synthesis", "kg-synthesis",
variables=synthesis_variables, variables=synthesis_variables,
) )
track_usage(synthesis_result)
resp = synthesis_result.text
if self.verbose: if self.verbose:
logger.debug("Query processing complete") logger.debug("Query processing complete")
@ -964,5 +964,11 @@ class GraphRag:
if self.verbose: if self.verbose:
logger.debug(f"Emitted explain for session {session_id}") logger.debug(f"Emitted explain for session {session_id}")
return resp usage = {
"in_token": total_in if total_in > 0 else None,
"out_token": total_out if total_out > 0 else None,
"model": last_model,
}
return resp, usage

View file

@ -332,7 +332,7 @@ class Processor(FlowProcessor):
) )
# Query with streaming and real-time explain # Query with streaming and real-time explain
response = await rag.query( response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection, query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
@ -348,7 +348,7 @@ class Processor(FlowProcessor):
else: else:
# Non-streaming path with real-time explain # Non-streaming path with real-time explain
response = await rag.query( response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection, query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
@ -360,23 +360,30 @@ class Processor(FlowProcessor):
parent_uri = v.parent_uri, parent_uri = v.parent_uri,
) )
# Send chunk with response # Send single response with answer and token usage
await flow("response").send( await flow("response").send(
GraphRagResponse( GraphRagResponse(
message_type="chunk", message_type="chunk",
response=response, response=response,
end_of_stream=True, end_of_stream=True,
error=None, end_of_session=True,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
), ),
properties={"id": id} properties={"id": id}
) )
return
# Send final message to close session # Streaming: send final message to close session with token usage
await flow("response").send( await flow("response").send(
GraphRagResponse( GraphRagResponse(
message_type="chunk", message_type="chunk",
response="", response="",
end_of_session=True, end_of_session=True,
in_token=usage.get("in_token"),
out_token=usage.get("out_token"),
model=usage.get("model"),
), ),
properties={"id": id} properties={"id": id}
) )