trustgraph/tests/integration/test_document_rag_integration.py
cybermaggedon 14e49d83c7
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all
service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", distinguishing from a real zero count.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). The non-streaming path sends a single combined response
  instead of a chunk followed by end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields whenever they are not None (checked
  with `is not None` rather than truthiness, so a real zero count is
  still encoded)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
2026-04-13 14:38:34 +01:00

358 lines
15 KiB
Python

"""
Integration tests for the DocumentRAG retrieval system.

These tests verify the end-to-end functionality of the DocumentRAG system,
testing the coordination between embeddings, document retrieval, and prompt services.
Following the TEST_STRATEGY.md approach for integration testing.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Canned chunk bodies used by the tests, keyed by chunk_id.
CHUNK_CONTENT = {
    "doc/c1": (
        "Machine learning is a subset of artificial intelligence "
        "that focuses on algorithms that learn from data."
    ),
    "doc/c2": (
        "Deep learning uses neural networks with multiple layers "
        "to model complex patterns in data."
    ),
    "doc/c3": (
        "Supervised learning algorithms learn from labeled training data "
        "to make predictions on new data."
    ),
}
@pytest.mark.integration
class TestDocumentRagIntegration:
    """Integration tests for DocumentRAG system coordination.

    Each test wires a ``DocumentRag`` instance to mocked embeddings,
    document-embeddings and prompt clients plus a stub ``fetch_chunk``
    coroutine, then checks how ``query()`` drives those collaborators.
    """

    # ------------------------------------------------------------------
    # Helpers and fixtures
    # ------------------------------------------------------------------

    @staticmethod
    def _build_rag(embeddings_client, doc_embeddings_client, prompt_client,
                   fetch_chunk, verbose=False):
        """Construct a DocumentRag around the supplied (mock) dependencies."""
        return DocumentRag(
            embeddings_client=embeddings_client,
            doc_embeddings_client=doc_embeddings_client,
            prompt_client=prompt_client,
            fetch_chunk=fetch_chunk,
            verbose=verbose,
        )

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client that returns realistic vector embeddings."""
        client = AsyncMock()
        # New batch format: [[[vectors_for_text1], ...]]
        # One text input returns one vector set containing two vectors
        client.embed.return_value = [
            [
                [0.1, 0.2, 0.3, 0.4, 0.5],  # First vector for text
                [0.6, 0.7, 0.8, 0.9, 1.0],  # Second vector for text
            ]
        ]
        return client

    @pytest.fixture
    def mock_doc_embeddings_client(self):
        """Mock document embeddings client that returns chunk matches."""
        client = AsyncMock()
        # Returns ChunkMatch objects with chunk_id and score
        client.query.return_value = [
            ChunkMatch(chunk_id="doc/c1", score=0.95),
            ChunkMatch(chunk_id="doc/c2", score=0.90),
            ChunkMatch(chunk_id="doc/c3", score=0.85),
        ]
        return client

    @pytest.fixture
    def mock_fetch_chunk(self):
        """Stub ``fetch_chunk`` coroutine standing in for the librarian.

        Known chunk ids resolve through ``CHUNK_CONTENT``; any other id
        yields a synthetic ``"Content for <chunk_id>"`` string.
        """
        async def fetch(chunk_id, user):
            return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")

        return fetch

    @pytest.fixture
    def mock_prompt_client(self):
        """Mock prompt client that generates realistic responses."""
        client = AsyncMock()
        client.document_prompt.return_value = PromptResult(
            response_type="text",
            text=(
                "Machine learning is a field of artificial intelligence that enables computers to learn "
                "and improve from experience without being explicitly programmed. It uses algorithms "
                "to find patterns in data and make predictions or decisions."
            ),
        )
        # Mock prompt() for extract-concepts call in DocumentRag
        client.prompt.return_value = PromptResult(response_type="text", text="")
        return client

    @pytest.fixture
    def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
                     mock_prompt_client, mock_fetch_chunk):
        """Create a verbose DocumentRag instance with mocked dependencies."""
        return self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            mock_fetch_chunk,
            verbose=True,
        )

    # ------------------------------------------------------------------
    # Happy-path behaviour
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client):
        """Test complete DocumentRAG pipeline from query to response."""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "ml_knowledge"
        doc_limit = 10

        # Act
        result = await document_rag.query(
            query=query,
            user=user,
            collection=collection,
            doc_limit=doc_limit,
        )

        # Assert - Verify service coordination
        mock_embeddings_client.embed.assert_called_once_with([query])
        mock_doc_embeddings_client.query.assert_called_once_with(
            vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
            limit=doc_limit,
            user=user,
            collection=collection,
        )

        # Documents are fetched from librarian using chunk_ids
        mock_prompt_client.document_prompt.assert_called_once_with(
            query=query,
            documents=[
                "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
                "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
                "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
            ],
        )

        # Verify final response: query() yields a (text, usage) pair
        answer, _usage = result
        assert answer is not None
        assert isinstance(answer, str)
        assert "machine learning" in answer.lower()
        assert "artificial intelligence" in answer.lower()

    @pytest.mark.asyncio
    async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
                                                        mock_doc_embeddings_client, mock_prompt_client,
                                                        mock_fetch_chunk):
        """Test DocumentRAG behavior when no documents are retrieved."""
        # Arrange
        mock_doc_embeddings_client.query.return_value = []  # No chunk_ids found
        mock_prompt_client.document_prompt.return_value = PromptResult(
            response_type="text",
            text="I couldn't find any relevant documents for your query.",
        )
        mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
        document_rag = self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            mock_fetch_chunk,
        )

        # Act
        result = await document_rag.query("very obscure query")

        # Assert
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once_with(
            query="very obscure query",
            documents=[],
        )
        result_text, _usage = result
        assert result_text == "I couldn't find any relevant documents for your query."

    # ------------------------------------------------------------------
    # Failure propagation
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
                                                           mock_doc_embeddings_client, mock_prompt_client,
                                                           mock_fetch_chunk):
        """Test DocumentRAG error handling when embeddings service fails."""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
        document_rag = self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            mock_fetch_chunk,
        )

        # Act & Assert - the mock raises a bare Exception, so the message
        # pattern is what pins down which failure surfaced.
        with pytest.raises(Exception, match="Embeddings service unavailable"):
            await document_rag.query("test query")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_not_called()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_document_service_failure(self, mock_embeddings_client,
                                                         mock_doc_embeddings_client, mock_prompt_client,
                                                         mock_fetch_chunk):
        """Test DocumentRAG error handling when document service fails."""
        # Arrange
        mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
        document_rag = self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            mock_fetch_chunk,
        )

        # Act & Assert
        with pytest.raises(Exception, match="Document service connection failed"):
            await document_rag.query("test query")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
                                                       mock_doc_embeddings_client, mock_prompt_client,
                                                       mock_fetch_chunk):
        """Test DocumentRAG error handling when prompt service fails."""
        # Arrange
        mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
        document_rag = self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            mock_fetch_chunk,
        )

        # Act & Assert
        with pytest.raises(Exception, match="LLM service rate limited"):
            await document_rag.query("test query")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once()

    # ------------------------------------------------------------------
    # Parameter forwarding
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_document_rag_with_different_document_limits(self, document_rag,
                                                               mock_doc_embeddings_client):
        """Test DocumentRAG with various document limit configurations."""
        for limit in (1, 5, 10, 25, 50):
            # Reset mock call history so assert_called_once() is per-iteration
            mock_doc_embeddings_client.reset_mock()

            # Act
            await document_rag.query(f"query with limit {limit}", doc_limit=limit)

            # Assert
            mock_doc_embeddings_client.query.assert_called_once()
            forwarded = mock_doc_embeddings_client.query.call_args.kwargs
            assert forwarded['limit'] == limit

    @pytest.mark.asyncio
    async def test_document_rag_multi_user_isolation(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG properly isolates queries by user and collection."""
        # Arrange
        test_scenarios = [
            ("user1", "collection1"),
            ("user2", "collection2"),
            ("user1", "collection2"),  # Same user, different collection
            ("user2", "collection1"),  # Different user, same collection
        ]

        for user, collection in test_scenarios:
            # Reset mock call history
            mock_doc_embeddings_client.reset_mock()

            # Act
            await document_rag.query(
                f"query from {user} in {collection}",
                user=user,
                collection=collection,
            )

            # Assert
            mock_doc_embeddings_client.query.assert_called_once()
            forwarded = mock_doc_embeddings_client.query.call_args.kwargs
            assert forwarded['user'] == user
            assert forwarded['collection'] == collection

    @pytest.mark.asyncio
    async def test_document_rag_verbose_logging(self, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client,
                                                mock_fetch_chunk,
                                                caplog):
        """Test DocumentRAG verbose logging functionality."""
        import logging

        # Arrange - Configure logging to capture debug messages.  The
        # instance is built after this so its construction is captured too.
        caplog.set_level(logging.DEBUG)
        document_rag = self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            mock_fetch_chunk,
            verbose=True,
        )

        # Act
        await document_rag.query("test query for verbose logging")

        # Assert - Check for new logging messages
        log_messages = caplog.text
        assert "DocumentRag initialized" in log_messages
        assert "Constructing prompt..." in log_messages
        assert "Computing embeddings..." in log_messages
        assert "chunks" in log_messages.lower()
        assert "Invoking LLM..." in log_messages
        assert "Query processing complete" in log_messages

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_document_rag_performance_with_large_document_set(self, document_rag,
                                                                    mock_doc_embeddings_client):
        """Test DocumentRAG performance with large document retrieval."""
        import time

        # Arrange - Mock large chunk match set (100 chunks)
        large_chunk_matches = [
            ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001)
            for i in range(100)
        ]
        mock_doc_embeddings_client.query.return_value = large_chunk_matches

        # Act - use a monotonic clock so a wall-clock adjustment during the
        # run cannot distort (or negate) the measured duration.
        start_time = time.perf_counter()
        result = await document_rag.query("performance test query", doc_limit=100)
        execution_time = time.perf_counter() - start_time

        # Assert
        assert result is not None
        assert execution_time < 5.0  # Should complete within 5 seconds
        mock_doc_embeddings_client.query.assert_called_once()
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['limit'] == 100

    @pytest.mark.asyncio
    async def test_document_rag_default_parameters(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG uses correct default parameters."""
        # Act
        await document_rag.query("test query with defaults")

        # Assert
        mock_doc_embeddings_client.query.assert_called_once()
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['user'] == "trustgraph"
        assert call_args.kwargs['collection'] == "default"
        assert call_args.kwargs['limit'] == 20