mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Expose LLM token usage (in_token, out_token, model) across all service layers

Propagate token counts from LLM services through the prompt, text-completion, graph-RAG, document-RAG, and agent orchestrator pipelines to the API gateway and Python SDK. All fields are Optional — None means "not available", distinguishing it from a real zero count.

Key changes:
- Schema: Add in_token/out_token/model to TextCompletionResponse, PromptResponse, GraphRagResponse, DocumentRagResponse, AgentResponse
- TextCompletionClient: New TextCompletionResult return type. Split into text_completion() (non-streaming) and text_completion_stream() (streaming with per-chunk handler callback)
- PromptClient: New PromptResult with response_type (text/json/jsonl), typed fields (text/object/objects), and token usage. All callers updated.
- RAG services: Accumulate token usage across all prompt calls (extract-concepts, edge-scoring, edge-reasoning, synthesis). Non-streaming path sends single combined response instead of chunk + end_of_session.
- Agent orchestrator: UsageTracker accumulates tokens across meta-router, pattern prompt calls, and react reasoning. Attached to end_of_dialog.
- Translators: Encode token fields when not None (is not None, not truthy)
- Python SDK: RAG and text-completion methods return TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with token fields (streaming)
- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt, tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
358 lines
15 KiB
Python
358 lines
15 KiB
Python
"""
Integration tests for DocumentRAG retrieval system

These tests verify the end-to-end functionality of the DocumentRAG system,
testing the coordination between embeddings, document retrieval, and prompt services.
They follow the TEST_STRATEGY.md approach for integration testing.
"""
|
|
|
|
import logging
import time
from unittest.mock import AsyncMock, MagicMock

import pytest

from trustgraph.base import PromptResult
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
|
|
|
|
|
|
# Sample chunk content for testing - maps chunk_id to content.
# The keys match the chunk_ids returned by the mocked document-embeddings
# query, so the mocked fetch_chunk can resolve each match to its text.
CHUNK_CONTENT = {
    "doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
    "doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
    "doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
}
|
|
|
|
|
|
@pytest.mark.integration
class TestDocumentRagIntegration:
    """Integration tests for DocumentRAG system coordination.

    Every collaborator handed to ``DocumentRag`` (embeddings client, document
    embeddings client, prompt client and chunk fetcher) is a mock, so these
    tests check how ``DocumentRag.query`` drives those collaborators and what
    it returns, not the behaviour of the real services.
    """

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client that returns realistic vector embeddings."""
        client = AsyncMock()
        # New batch format: [[[vectors_for_text1], ...]]
        # One text input returns one vector set containing two vectors
        client.embed.return_value = [
            [
                [0.1, 0.2, 0.3, 0.4, 0.5],  # First vector for text
                [0.6, 0.7, 0.8, 0.9, 1.0],  # Second vector for text
            ]
        ]
        return client

    @pytest.fixture
    def mock_doc_embeddings_client(self):
        """Mock document embeddings client that returns chunk matches."""
        client = AsyncMock()
        # Returns ChunkMatch objects with chunk_id and score
        client.query.return_value = [
            ChunkMatch(chunk_id="doc/c1", score=0.95),
            ChunkMatch(chunk_id="doc/c2", score=0.90),
            ChunkMatch(chunk_id="doc/c3", score=0.85),
        ]
        return client

    @pytest.fixture
    def mock_fetch_chunk(self):
        """Mock fetch_chunk function that retrieves chunk content from librarian.

        Known chunk_ids resolve through ``CHUNK_CONTENT``; any other id gets a
        synthetic ``"Content for <chunk_id>"`` string.
        """
        async def fetch(chunk_id, user):
            return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")

        return fetch

    @pytest.fixture
    def mock_prompt_client(self):
        """Mock prompt client that generates realistic responses."""
        client = AsyncMock()
        client.document_prompt.return_value = PromptResult(
            response_type="text",
            text=(
                "Machine learning is a field of artificial intelligence that enables computers to learn "
                "and improve from experience without being explicitly programmed. It uses algorithms "
                "to find patterns in data and make predictions or decisions."
            ),
        )
        # Mock prompt() for extract-concepts call in DocumentRag
        client.prompt.return_value = PromptResult(response_type="text", text="")
        return client

    @pytest.fixture
    def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
                     mock_prompt_client, mock_fetch_chunk):
        """Create DocumentRag instance with mocked dependencies (verbose on)."""
        return DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=True,
        )

    @pytest.mark.asyncio
    async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client):
        """Test complete DocumentRAG pipeline from query to response."""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "ml_knowledge"
        doc_limit = 10

        # Act
        response = await document_rag.query(
            query=query,
            user=user,
            collection=collection,
            doc_limit=doc_limit,
        )

        # Assert - Verify service coordination
        mock_embeddings_client.embed.assert_called_once_with([query])

        mock_doc_embeddings_client.query.assert_called_once_with(
            vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
            limit=doc_limit,
            user=user,
            collection=collection,
        )

        # Documents are fetched from librarian using chunk_ids
        mock_prompt_client.document_prompt.assert_called_once_with(
            query=query,
            documents=[
                "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
                "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
                "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
            ],
        )

        # Verify final response: query() yields a (text, usage) pair; only the
        # text is checked here.
        answer, _usage = response
        assert answer is not None
        assert isinstance(answer, str)
        assert "machine learning" in answer.lower()
        assert "artificial intelligence" in answer.lower()

    @pytest.mark.asyncio
    async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
                                                        mock_doc_embeddings_client, mock_prompt_client,
                                                        mock_fetch_chunk):
        """Test DocumentRAG behavior when no documents are retrieved."""
        # Arrange
        mock_doc_embeddings_client.query.return_value = []  # No chunk_ids found
        mock_prompt_client.document_prompt.return_value = PromptResult(
            response_type="text",
            text="I couldn't find any relevant documents for your query.",
        )
        mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        # Act
        response = await document_rag.query("very obscure query")

        # Assert
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        # The prompt is still invoked, just with an empty document list.
        mock_prompt_client.document_prompt.assert_called_once_with(
            query="very obscure query",
            documents=[],
        )

        answer, _usage = response
        assert answer == "I couldn't find any relevant documents for your query."

    @pytest.mark.asyncio
    async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
                                                           mock_doc_embeddings_client, mock_prompt_client,
                                                           mock_fetch_chunk):
        """Test DocumentRAG error handling when embeddings service fails."""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        # Act & Assert - the mock raises a bare Exception, so the message is
        # what identifies the failure.
        with pytest.raises(Exception, match="Embeddings service unavailable"):
            await document_rag.query("test query")

        # The pipeline must stop at the failing stage.
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_not_called()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_document_service_failure(self, mock_embeddings_client,
                                                         mock_doc_embeddings_client, mock_prompt_client,
                                                         mock_fetch_chunk):
        """Test DocumentRAG error handling when document service fails."""
        # Arrange
        mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        # Act & Assert
        with pytest.raises(Exception, match="Document service connection failed"):
            await document_rag.query("test query")

        # Embeddings ran, the document lookup failed, the LLM was never reached.
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
                                                       mock_doc_embeddings_client, mock_prompt_client,
                                                       mock_fetch_chunk):
        """Test DocumentRAG error handling when prompt service fails."""
        # Arrange
        mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        # Act & Assert
        with pytest.raises(Exception, match="LLM service rate limited"):
            await document_rag.query("test query")

        # Every stage up to and including the failing prompt call ran once.
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once()

    @pytest.mark.asyncio
    async def test_document_rag_with_different_document_limits(self, document_rag,
                                                               mock_doc_embeddings_client):
        """Test DocumentRAG with various document limit configurations."""
        # Test different document limits
        test_cases = [1, 5, 10, 25, 50]

        for limit in test_cases:
            # Reset mock call history
            mock_doc_embeddings_client.reset_mock()

            # Act
            await document_rag.query(f"query with limit {limit}", doc_limit=limit)

            # Assert
            mock_doc_embeddings_client.query.assert_called_once()
            call_args = mock_doc_embeddings_client.query.call_args
            assert call_args.kwargs['limit'] == limit

    @pytest.mark.asyncio
    async def test_document_rag_multi_user_isolation(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG properly isolates queries by user and collection."""
        # Arrange
        test_scenarios = [
            ("user1", "collection1"),
            ("user2", "collection2"),
            ("user1", "collection2"),  # Same user, different collection
            ("user2", "collection1"),  # Different user, same collection
        ]

        for user, collection in test_scenarios:
            # Reset mock call history
            mock_doc_embeddings_client.reset_mock()

            # Act
            await document_rag.query(
                f"query from {user} in {collection}",
                user=user,
                collection=collection,
            )

            # Assert
            mock_doc_embeddings_client.query.assert_called_once()
            call_args = mock_doc_embeddings_client.query.call_args
            assert call_args.kwargs['user'] == user
            assert call_args.kwargs['collection'] == collection

    @pytest.mark.asyncio
    async def test_document_rag_verbose_logging(self, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client,
                                                mock_fetch_chunk,
                                                caplog):
        """Test DocumentRAG verbose logging functionality."""
        # Arrange - Configure logging to capture debug messages
        caplog.set_level(logging.DEBUG)

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=True,
        )

        # Act
        await document_rag.query("test query for verbose logging")

        # Assert - Check for new logging messages
        log_messages = caplog.text
        assert "DocumentRag initialized" in log_messages
        assert "Constructing prompt..." in log_messages
        assert "Computing embeddings..." in log_messages
        assert "chunks" in log_messages.lower()
        assert "Invoking LLM..." in log_messages
        assert "Query processing complete" in log_messages

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_document_rag_performance_with_large_document_set(self, document_rag,
                                                                    mock_doc_embeddings_client):
        """Test DocumentRAG performance with large document retrieval."""
        # Arrange - Mock large chunk match set (100 chunks)
        large_chunk_matches = [
            ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001)
            for i in range(100)
        ]
        mock_doc_embeddings_client.query.return_value = large_chunk_matches

        # Act - time with a monotonic clock so a wall-clock adjustment during
        # the run cannot distort the measured duration.
        start_time = time.perf_counter()

        result = await document_rag.query("performance test query", doc_limit=100)

        execution_time = time.perf_counter() - start_time

        # Assert
        assert result is not None
        assert execution_time < 5.0  # Should complete within 5 seconds
        mock_doc_embeddings_client.query.assert_called_once()
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['limit'] == 100

    @pytest.mark.asyncio
    async def test_document_rag_default_parameters(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG uses correct default parameters."""
        # Act
        await document_rag.query("test query with defaults")

        # Assert - with only a query given, these are the values the
        # document-embeddings lookup is expected to receive.
        mock_doc_embeddings_client.query.assert_called_once()
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['user'] == "trustgraph"
        assert call_args.kwargs['collection'] == "default"
        assert call_args.kwargs['limit'] == 20
|