mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Expose LLM token usage (in_token, out_token, model) across all service layers Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional — None means "not available", distinguishing from a real zero count. Key changes: - Schema: Add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse - TextCompletionClient: New TextCompletionResult return type. Split into text_completion() (non-streaming) and text_completion_stream() (streaming with per-chunk handler callback) - PromptClient: New PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated. - RAG services: Accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). Non-streaming path sends single combined response instead of chunk + end_of_session. - Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and react reasoning. Attached to end_of_dialog. - Translators: Encode token fields when not None (is not None, not truthy) - Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming) - CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
134 lines
No EOL
5.2 KiB
Python
"""
|
|
Unit test for DocumentRAG service parameter passing fix.
|
|
Tests that user and collection parameters from the message are correctly
|
|
passed to the DocumentRag.query() method.
|
|
"""

import pytest
from unittest.mock import MagicMock, AsyncMock, patch, ANY

from trustgraph.retrieval.document_rag.rag import Processor
from trustgraph.schema import DocumentRagQuery, DocumentRagResponse


class TestDocumentRagService:
    """Unit tests covering parameter propagation in the DocumentRAG service."""

    @patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
    @pytest.mark.asyncio
    async def test_user_and_collection_parameters_passed_to_query(self, mock_document_rag_class):
        """
        Verify that the user/collection carried by the request message reach
        DocumentRag.query() unchanged.

        Regression test for the bug where user/collection parameters were
        ignored, producing wrong collection names like
        'd_trustgraph_default_384' instead of 'd_my_user_test_coll_1_384'.
        """
        # Processor under test.
        processor = Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            doc_limit=10,
        )

        # DocumentRag is patched at module level; wire up the instance the
        # processor will construct. query() returns (text, usage) per the
        # token-usage contract, with None meaning "not available".
        rag = AsyncMock()
        rag.query.return_value = (
            "test response",
            {"in_token": None, "out_token": None, "model": None},
        )
        mock_document_rag_class.return_value = rag

        # Incoming request carrying non-default user/collection values.
        msg = MagicMock()
        msg.value.return_value = DocumentRagQuery(
            query="test query",
            user="my_user",              # custom user, not the default "trustgraph"
            collection="test_coll_1",    # custom collection, not the default "default"
            doc_limit=5,
        )
        msg.properties.return_value = {"id": "test-id"}

        # Flow lookup: the "response" service resolves to a dedicated producer
        # so the outgoing message can be inspected; every other service name
        # (embeddings, doc-embeddings, prompt) gets a throwaway AsyncMock.
        consumer = MagicMock()
        producer = AsyncMock()
        flow = MagicMock(
            side_effect=lambda name: producer if name == "response" else AsyncMock()
        )

        # Execute the request handler.
        await processor.on_request(msg, consumer, flow)

        # query() must receive the values from the message, not hardcoded
        # defaults; the two callbacks are always supplied.
        rag.query.assert_called_once_with(
            "test query",
            user="my_user",
            collection="test_coll_1",
            doc_limit=5,
            explain_callback=ANY,        # explainability hook
            save_answer_callback=ANY,    # librarian save hook
        )

        # Exactly one well-formed response goes out on the response producer.
        producer.send.assert_called_once()
        response = producer.send.call_args[0][0]
        assert isinstance(response, DocumentRagResponse)
        assert response.response == "test response"
        assert response.error is None

    @patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
    @pytest.mark.asyncio
    async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_document_rag_class):
        """
        Verify that a non-streaming request produces a response with
        end_of_stream=True.

        Regression test for the bug where non-streaming responses omitted
        end_of_stream, leaving clients hanging while they waited for more data.
        """
        # Processor under test.
        processor = Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            doc_limit=10,
        )

        # Patched DocumentRag instance returning (text, usage).
        rag = AsyncMock()
        rag.query.return_value = (
            "A document about cats.",
            {"in_token": None, "out_token": None, "model": None},
        )
        mock_document_rag_class.return_value = rag

        # Request that explicitly exercises the non-streaming path.
        msg = MagicMock()
        msg.value.return_value = DocumentRagQuery(
            query="What is a cat?",
            user="trustgraph",
            collection="default",
            doc_limit=10,
            streaming=False,    # non-streaming mode
        )
        msg.properties.return_value = {"id": "test-id"}

        # Same flow-routing scheme as above: only "response" is observable.
        consumer = MagicMock()
        producer = AsyncMock()
        flow = MagicMock(
            side_effect=lambda name: producer if name == "response" else AsyncMock()
        )

        # Execute the request handler.
        await processor.on_request(msg, consumer, flow)

        # A single combined response must be sent, terminated with both
        # end_of_stream and end_of_session set.
        producer.send.assert_called_once()
        response = producer.send.call_args[0][0]
        assert isinstance(response, DocumentRagResponse)
        assert response.response == "A document about cats."
        assert response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
        assert response.end_of_session is True
        assert response.error is None