Fix tests (#666)

This commit is contained in:
cybermaggedon 2026-03-07 23:38:09 +00:00 committed by GitHub
parent 24bbe94136
commit 3bf8a65409
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 510 additions and 446 deletions

View file

@ -11,6 +11,14 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
}
@pytest.mark.integration
class TestDocumentRagIntegration:
"""Integration tests for DocumentRAG system coordination"""
@ -27,15 +35,19 @@ class TestDocumentRagIntegration:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns realistic document chunks"""
"""Mock document embeddings client that returns chunk IDs"""
client = AsyncMock()
client.query.return_value = [
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
]
# Now returns chunk_ids instead of actual content
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
@ -48,17 +60,19 @@ class TestDocumentRagIntegration:
return client
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@pytest.mark.asyncio
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
"""Test complete DocumentRAG pipeline from query to response"""
# Arrange
@ -77,14 +91,15 @@ class TestDocumentRagIntegration:
# Assert - Verify service coordination
mock_embeddings_client.embed.assert_called_once_with(query)
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
user=user,
collection=collection
)
# Documents are fetched from librarian using chunk_ids
mock_prompt_client.document_prompt.assert_called_once_with(
query=query,
documents=[
@ -101,17 +116,19 @@ class TestDocumentRagIntegration:
assert "artificial intelligence" in result.lower()
@pytest.mark.asyncio
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents found
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
@ -125,92 +142,98 @@ class TestDocumentRagIntegration:
query="very obscure query",
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when embeddings service fails"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Embeddings service unavailable" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_not_called()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when document service fails"""
# Arrange
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Document service connection failed" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when prompt service fails"""
# Arrange
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "LLM service rate limited" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once()
@pytest.mark.asyncio
async def test_document_rag_with_different_document_limits(self, document_rag,
async def test_document_rag_with_different_document_limits(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG with various document limit configurations"""
# Test different document limits
test_cases = [1, 5, 10, 25, 50]
for limit in test_cases:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
@ -230,14 +253,14 @@ class TestDocumentRagIntegration:
for user, collection in test_scenarios:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(
f"query from {user} in {collection}",
user=user,
collection=collection
)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
@ -245,19 +268,21 @@ class TestDocumentRagIntegration:
assert call_args.kwargs['collection'] == collection
@pytest.mark.asyncio
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk,
caplog):
"""Test DocumentRAG verbose logging functionality"""
import logging
# Arrange - Configure logging to capture debug messages
caplog.set_level(logging.DEBUG)
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@ -269,25 +294,25 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages
assert "Getting documents..." in log_messages
assert "chunk_ids" in log_messages.lower()
assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages
@pytest.mark.asyncio
@pytest.mark.slow
async def test_document_rag_performance_with_large_document_set(self, document_rag,
async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large document set (100 documents)
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_doc_set
# Arrange - Mock large chunk_id set (100 chunks)
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_ids
# Act
import time
start_time = time.time()
result = await document_rag.query("performance test query", doc_limit=100)
end_time = time.time()
execution_time = end_time - start_time
@ -309,4 +334,4 @@ class TestDocumentRagIntegration:
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == "trustgraph"
assert call_args.kwargs['collection'] == "default"
assert call_args.kwargs['limit'] == 20
assert call_args.kwargs['limit'] == 20