Expose LLM token usage (in_token, out_token, model) across all

service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", distinguishing from a real zero count.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields when not None (is not None, not truthy)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
This commit is contained in:
Cyber MacGeddon 2026-04-11 20:22:35 +01:00
parent ffe310af7c
commit 56d700f301
60 changed files with 1252 additions and 577 deletions

View file

@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
from trustgraph.agent.react.types import Action, Final, Tool, Argument
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
# Mock prompt client
prompt_client = AsyncMock()
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for information about machine learning
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""
)
# Mock graph RAG client
graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
# Mock text completion client
text_completion_client = AsyncMock()
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
text_completion_client.question.return_value = PromptResult(
response_type="text",
text="Machine learning involves algorithms that improve through experience."
)
# Mock MCP tool client
mcp_tool_client = AsyncMock()
@ -147,8 +154,11 @@ Args: {
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
"""Test agent manager returning final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have enough information to answer the question
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
)
question = "What is machine learning?"
history = []
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
"""Test ReAct cycle ending with final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide a direct answer
Final Answer: Machine learning is a branch of artificial intelligence."""
)
question = "What is machine learning?"
history = []
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
for tool_name, expected_service in tool_scenarios:
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: I need to use {tool_name}
Action: {tool_name}
Args: {{
"question": "test question"
}}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -284,11 +300,14 @@ Args: {{
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
"""Test agent manager error handling for unknown tool"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to use an unknown tool
Action: unknown_tool
Args: {
"param": "value"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -321,11 +340,14 @@ Args: {
question = "Find information about AI and summarize it"
# Mock multi-step reasoning
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for AI information first
Action: knowledge_query
Args: {
"question": "What is artificial intelligence?"
}"""
)
# Act
action = await agent_manager.reason(question, [], mock_flow_context)
@ -372,9 +394,12 @@ Args: {
# Format arguments as JSON
import json
args_json = json.dumps(test_case['arguments'], indent=4)
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: Using {test_case['action']}
Action: {test_case['action']}
Args: {args_json}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -507,7 +532,10 @@ Args: {
]
for test_case in test_cases:
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=test_case["response"]
)
if test_case["error_contains"]:
# Should raise an error
@ -527,13 +555,16 @@ Args: {
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
"""Test edge cases in text parsing"""
# Test response with markdown code blocks
mock_flow_context("prompt-request").agent_react.return_value = """```
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""```
Thought: I need to search for information
Action: knowledge_query
Args: {
"question": "What is AI?"
}
```"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -541,15 +572,18 @@ Args: {
assert action.name == "knowledge_query"
# Test response with extra whitespace
mock_flow_context("prompt-request").agent_react.return_value = """
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""
Thought: I need to think about this
Action: knowledge_query
Thought: I need to think about this
Action: knowledge_query
Args: {
"question": "test"
}
"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -560,7 +594,9 @@ Args: {
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
"""Test handling of multi-line thoughts and final answers"""
# Multi-line thought
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to consider multiple factors:
1. The user's question is complex
2. I should search for comprehensive information
3. This requires using the knowledge query tool
@ -568,6 +604,7 @@ Action: knowledge_query
Args: {
"question": "complex query"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -575,13 +612,16 @@ Args: {
assert "knowledge query tool" in action.thought
# Multi-line final answer
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have gathered enough information
Final Answer: Here is a comprehensive answer:
1. First point about the topic
2. Second point with details
3. Final conclusion
This covers all aspects of the question."""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -593,13 +633,16 @@ This covers all aspects of the question."""
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
"""Test JSON arguments with special characters and edge cases"""
# Test with special characters in JSON (properly escaped)
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Processing special characters
Action: knowledge_query
Args: {
"question": "What about \\"quotes\\" and 'apostrophes'?",
"context": "Line 1\\nLine 2\\tTabbed",
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -608,7 +651,9 @@ Args: {
assert "@#$%^&*" in action.arguments["special"]
# Test with nested JSON
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Complex arguments
Action: web_search
Args: {
"query": "test",
@ -621,6 +666,7 @@ Args: {
}
}
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -632,7 +678,9 @@ Args: {
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
"""Test final answers that contain JSON-like content"""
# Final answer with JSON content
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide the data in JSON format
Final Answer: {
"result": "success",
"data": {
@ -642,6 +690,7 @@ Final Answer: {
},
"confidence": 0.95
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -792,11 +841,14 @@ Final Answer: {
agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search in the research papers
Action: knowledge_query_custom
Args: {
"question": "Latest AI research?"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks,
assert_streaming_chunks_valid,
@ -51,10 +52,10 @@ Args: {
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.agent_react.side_effect = agent_react_streaming
return client
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk + " ", is_final)
return response
return response
return PromptResult(response_type="text", text=response)
return PromptResult(response_type="text", text=response)
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Error
)
from trustgraph.agent.react.service import Processor
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from New York using structured query
Action: structured-query
Args: {
"question": "Find all customers from New York"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -173,11 +177,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query for a table that might not exist
Action: structured-query
Args: {
"question": "Find data from a table that doesn't exist"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -250,11 +257,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from California first
Action: structured-query
Args: {
"question": "Find all customers from California"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -339,11 +349,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query the sales data
Action: structured-query
Args: {
"question": "Query the sales data for recent transactions"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -447,11 +460,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to get customer information
Action: structured-query
Args: {
"question": "Get customer information and format it nicely"
}"""
)
# Set up flow context routing
def flow_context(service_name):

View file

@ -10,6 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
client.document_prompt.return_value = (
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
client.document_prompt.return_value = PromptResult(
response_type="text",
text=(
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
)
)
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
)
# Verify final response
result, usage = result
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
mock_prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="I couldn't find any relevant documents for your query."
)
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
result_text, usage = result
assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
collector.verify_streaming_protocol()
# Verify full response matches concatenated chunks
result_text, usage = result
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
assert result_text == full_from_chunks
# Verify content is reasonable
assert len(result) > 0
assert len(result_text) > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
non_streaming_text, _ = non_streaming_result
streaming_text, _ = streaming_result
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
)
# Assert
result_text, usage = result
assert callback.call_count > 0
assert result is not None
assert result_text is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
result_text, usage = result
assert isinstance(result_text, str)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
)
# Assert - Should still produce streamed response
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
@pytest.mark.asyncio
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
)
# Assert
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
# Verify doc_limit was passed correctly

View file

@ -12,6 +12,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
return PromptResult(
response_type="text",
text=(
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
)
return ""
return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt
return client
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
response, usage = response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
return full_text
return ""
return PromptResult(response_type="text", text=full_text)
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
)
# Assert
response, usage = response
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_response == non_streaming_response
non_streaming_text, _ = non_streaming_response
streaming_text, _ = streaming_response
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_response
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
# Assert - Should complete without error
assert response is not None
assert isinstance(response, str)
response_text, usage = response
assert isinstance(response_text, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,

View file

@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
# Mock prompt client for definitions extraction
prompt_client = AsyncMock()
prompt_client.extract_definitions.return_value = [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
prompt_client.extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
)
# Mock prompt client for relationships extraction
prompt_client.extract_relationships.return_value = [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
prompt_client.extract_relationships.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
)
# Mock producers for output streams
triples_producer = AsyncMock()
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of empty extraction results"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = []
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[]
)
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of invalid extraction response format"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="text",
text="invalid format"
) # Should be jsonl with objects list
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
"""Test entity filtering and validation in extraction"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = [
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
)
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
)
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
if schema_name == "customer_records":
if "john" in text.lower():
return [
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
)
elif "jane" in text.lower():
return [
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
)
else:
return []
return PromptResult(response_type="jsonl", objects=[])
elif schema_name == "product_catalog":
if "laptop" in text.lower():
return [
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
)
elif "book" in text.lower():
return [
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
)
else:
return []
return []
return PromptResult(response_type="jsonl", objects=[])
return PromptResult(response_type="jsonl", objects=[])
prompt_client.extract_objects.side_effect = mock_extract_objects

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from trustgraph.base.text_completion_client import TextCompletionResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -27,34 +28,52 @@ class TestPromptStreaming:
# Mock text completion client with streaming
text_completion_client = AsyncMock()
async def streaming_request(request, recipient=None, timeout=600):
"""Simulate streaming text completion"""
if request.streaming and recipient:
# Simulate streaming chunks
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
# Streaming chunks to send
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
for i, chunk_text in enumerate(chunks):
is_final = (i == len(chunks) - 1)
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=is_final
)
final = await recipient(response)
if final:
break
# Final empty chunk
await recipient(TextCompletionResponse(
response="",
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
"""Simulate streaming text completion via text_completion_stream"""
for i, chunk_text in enumerate(chunks):
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=True
))
end_of_stream=False
)
await handler(response)
text_completion_client.request = streaming_request
# Send final empty chunk with end_of_stream
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
return TextCompletionResult(
text=None,
in_token=10,
out_token=9,
model="test-model",
)
async def non_streaming_text_completion(system, prompt, timeout=600):
"""Simulate non-streaming text completion"""
full_text = "Machine learning is a field of artificial intelligence."
return TextCompletionResult(
text=full_text,
in_token=10,
out_token=9,
model="test-model",
)
text_completion_client.text_completion_stream = AsyncMock(
side_effect=streaming_text_completion_stream
)
text_completion_client.text_completion = AsyncMock(
side_effect=non_streaming_text_completion
)
# Mock response producer
response_producer = AsyncMock()
@ -156,14 +175,6 @@ class TestPromptStreaming:
consumer = MagicMock()
# Mock non-streaming text completion
text_completion_client = mock_flow_context_streaming("text-completion-request")
async def non_streaming_text_completion(system, prompt, streaming=False):
return "AI is the simulation of human intelligence in machines."
text_completion_client.text_completion = non_streaming_text_completion
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
@ -218,17 +229,12 @@ class TestPromptStreaming:
# Mock text completion client that raises an error
text_completion_client = AsyncMock()
async def failing_request(request, recipient=None, timeout=600):
if recipient:
# Send error response with proper Error schema
error_response = TextCompletionResponse(
response="",
error=Error(message="Text completion error", type="processing_error"),
end_of_stream=True
)
await recipient(error_response)
async def failing_stream(system, prompt, handler, timeout=600):
raise RuntimeError("Text completion error")
text_completion_client.request = failing_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=failing_stream
)
# Mock response producer to capture error response
response_producer = AsyncMock()
@ -255,22 +261,15 @@ class TestPromptStreaming:
consumer = MagicMock()
# Act - The service catches errors and sends error responses, doesn't raise
# Act - The service catches errors and sends an error PromptResponse
await prompt_processor_streaming.on_request(message, consumer, context)
# Assert - Verify error response was sent
assert response_producer.send.call_count > 0
# Check that at least one response contains an error
error_sent = False
for call in response_producer.send.call_args_list:
response = call.args[0]
if hasattr(response, 'error') and response.error:
error_sent = True
assert "Text completion error" in response.error.message
break
assert error_sent, "Expected error response to be sent"
# Assert - error response was sent
calls = response_producer.send.call_args_list
assert len(calls) > 0
error_response = calls[-1].args[0]
assert error_response.error is not None
assert "Text completion error" in error_response.error.message
@pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@ -315,21 +314,22 @@ class TestPromptStreaming:
# Mock text completion that sends empty chunks
text_completion_client = AsyncMock()
async def empty_streaming_request(request, recipient=None, timeout=600):
if request.streaming and recipient:
# Send empty chunk followed by final marker
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
async def empty_streaming(system, prompt, handler, timeout=600):
# Send empty chunk followed by final marker
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
text_completion_client.request = empty_streaming_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=empty_streaming
)
response_producer = AsyncMock()
def context_router(service_name):
@ -401,4 +401,4 @@ class TestPromptStreaming:
# Verify chunks concatenate to expected result
full_text = "".join(chunk_texts)
assert full_text == "Machine learning is a field of artificial intelligence"
assert full_text == "Machine learning is a field of artificial intelligence."

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
from trustgraph.base import PromptResult
class TestGraphRagStreamingProtocol:
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected)
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
return PromptResult(response_type="text", text="")
else:
return "The answer is here."
return ""
return PromptResult(response_type="text", text="The answer is here.")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
await chunk_callback("Document", False)
await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk
return ""
return PromptResult(response_type="text", text="")
else:
return "Document summary."
return PromptResult(response_type="text", text="Document summary.")
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
return PromptResult(response_type="text", text="")
else:
return "textmore"
return ""
return PromptResult(response_type="text", text="textmore")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_with_empties

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None):
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
client.prompt = AsyncMock(
return_value=PromptResult(response_type="text", text=prompt_response)
)
def context(service_name):
return client
@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count
call_count += 1
if call_count == 1:
return "research" # task type
return "plan-then-execute" # pattern
return PromptResult(response_type="text", text="research")
return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt
context = lambda name: client

View file

@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
)
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final
if on_action:
await on_action(Action(
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(action)
return action
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(Action(
thought="done", name="final",
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
],
)
def flow_factory(name):
if name == "prompt-request":
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = {
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
}
mock_prompt_client.prompt.return_value = PromptResult(
response_type="json",
object={
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
},
)
def flow_factory(name):
if name == "prompt-request":
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name):
if name == "prompt-request":
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
"What is quantum computing?",
"What are qubits?",
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
"What is quantum computing?",
"What are qubits?",
],
)
def flow_factory(name):
if name == "prompt-request":
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name):
if name == "prompt-request":
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B", "Goal C"],
)
def flow_factory(name):
if name == "prompt-request":

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.base import PromptResult
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
if isinstance(prompt_result, list):
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
else:
wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
return_value=wrapped
)
def flow(name):

View file

@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
# ---------------------------------------------------------------------------
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
return_value=PromptResult(
response_type="jsonl",
objects=prompt_result,
)
)
def flow(name):

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
from trustgraph.base import PromptResult
# Sample chunk content mapping for tests
@ -132,7 +133,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
query = Query(
rag=mock_rag,
@ -157,7 +158,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@ -258,7 +259,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -273,7 +274,7 @@ class TestQuery:
expected_response = "This is the document RAG response"
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -315,7 +316,8 @@ class TestQuery:
assert "Relevant document content" in docs
assert "Another document" in docs
assert result == expected_response
result_text, usage = result
assert result_text == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
@ -325,7 +327,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
@ -333,7 +335,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -352,7 +354,8 @@ class TestQuery:
collection="default" # Default collection
)
assert result == "Default response"
result_text, usage = result
assert result_text == "Default response"
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
@ -401,7 +404,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
@ -409,7 +412,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -428,7 +431,8 @@ class TestQuery:
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response"
result_text, usage = result
assert result_text == "Verbose RAG response"
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
@ -469,11 +473,11 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -490,7 +494,8 @@ class TestQuery:
documents=[]
)
assert result == "No documents found response"
result_text, usage = result
assert result_text == "No documents found response"
@pytest.mark.asyncio
async def test_get_vectors_with_verbose(self):
@ -525,7 +530,7 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
@ -541,7 +546,7 @@ class TestQuery:
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -584,7 +589,8 @@ class TestQuery:
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
assert result == final_response
result_text, usage = result
assert result_text == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -89,8 +90,8 @@ def build_mock_clients():
# 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return "return policy\nrefund"
return ""
return PromptResult(response_type="text", text="return policy\nrefund")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -113,8 +114,9 @@ def build_mock_clients():
fetch_chunk.side_effect = mock_fetch
# 5. Synthesis
prompt_client.document_prompt.return_value = (
"Items can be returned within 30 days for a full refund."
prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="Items can be returned within 30 days for a full refund.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
assert result == "Items can be returned within 30 days for a full refund."
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund."
result_text, usage = await rag.query(query="What is the return policy?")
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -34,7 +34,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "test response"
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection
msg = MagicMock()
@ -97,7 +97,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A document about cats."
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
# Setup message with non-streaming request
msg = MagicMock()
@ -130,4 +130,5 @@ class TestDocumentRagService:
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.end_of_session is True
assert sent_response.error is None

View file

@ -7,6 +7,7 @@ import unittest.mock
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
from trustgraph.base import PromptResult
class TestGraphRag:
@ -172,7 +173,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
query = Query(
rag=mock_rag,
@ -196,7 +197,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@ -220,7 +221,7 @@ class TestQuery:
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -565,14 +566,14 @@ class TestQuery:
# Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
elif prompt_name == "kg-synthesis":
return expected_response
return ""
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt = mock_prompt
@ -607,7 +608,8 @@ class TestQuery:
explain_callback=collect_provenance
)
assert response == expected_response
response_text, usage = response
assert response_text == expected_response
# 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5

View file

@ -13,6 +13,7 @@ from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -136,24 +137,36 @@ def build_mock_clients():
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return prompt_responses["extract-concepts"]
return PromptResult(
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return synthesis_answer
return ""
return PromptResult(
response_type="text",
text=synthesis_answer,
)
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_parent_uri_links_question_to_parent(self):
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
clients = build_mock_clients()
rag = GraphRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -44,7 +44,7 @@ class TestGraphRagService:
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test")
return "A small domesticated mammal."
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -79,8 +79,8 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
assert mock_response_producer.send.call_count == 6
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
assert mock_response_producer.send.call_count == 5
# First 4 messages are explain (emitted in real-time during query)
for i in range(4):
@ -88,17 +88,12 @@ class TestGraphRagService:
assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None
# 5th message is chunk with response
# 5th message is chunk with response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True
# 6th message is empty chunk with end_of_session=True
close_msg = mock_response_producer.send.call_args_list[5][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4
@ -187,7 +182,7 @@ class TestGraphRagService:
async def mock_query(**kwargs):
# Don't call explain_callback
return "Response text"
return "Response text", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -218,17 +213,12 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 2 messages (chunk + empty chunk to close)
assert mock_response_producer.send.call_count == 2
# Verify: 1 combined message (chunk with end_of_session)
assert mock_response_producer.send.call_count == 1
# First is the response chunk
# Single message has response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True
# Second is empty chunk to close session
close_msg = mock_response_producer.send.call_args_list[1][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True