mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix tests (#666)
This commit is contained in:
parent
24bbe94136
commit
3bf8a65409
10 changed files with 510 additions and 446 deletions
|
|
@ -11,6 +11,14 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
CHUNK_CONTENT = {
|
||||
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestDocumentRagIntegration:
|
||||
"""Integration tests for DocumentRAG system coordination"""
|
||||
|
|
@ -27,15 +35,19 @@ class TestDocumentRagIntegration:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client that returns realistic document chunks"""
|
||||
"""Mock document embeddings client that returns chunk IDs"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
|
||||
]
|
||||
# Now returns chunk_ids instead of actual content
|
||||
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fetch_chunk(self):
|
||||
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||
async def fetch(chunk_id, user):
|
||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
return fetch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
|
|
@ -48,17 +60,19 @@ class TestDocumentRagIntegration:
|
|||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
|
||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||
mock_prompt_client, mock_fetch_chunk):
|
||||
"""Create DocumentRag instance with mocked dependencies"""
|
||||
return DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
||||
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test complete DocumentRAG pipeline from query to response"""
|
||||
# Arrange
|
||||
|
|
@ -77,14 +91,15 @@ class TestDocumentRagIntegration:
|
|||
|
||||
# Assert - Verify service coordination
|
||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
||||
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
||||
# Documents are fetched from librarian using chunk_ids
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query=query,
|
||||
documents=[
|
||||
|
|
@ -101,17 +116,19 @@ class TestDocumentRagIntegration:
|
|||
assert "artificial intelligence" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
|
@ -125,92 +142,98 @@ class TestDocumentRagIntegration:
|
|||
query="very obscure query",
|
||||
documents=[]
|
||||
)
|
||||
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG error handling when embeddings service fails"""
|
||||
# Arrange
|
||||
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
|
||||
assert "Embeddings service unavailable" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_not_called()
|
||||
mock_prompt_client.document_prompt.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG error handling when document service fails"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
|
||||
assert "Document service connection failed" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG error handling when prompt service fails"""
|
||||
# Arrange
|
||||
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
|
||||
assert "LLM service rate limited" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG with various document limit configurations"""
|
||||
# Test different document limits
|
||||
test_cases = [1, 5, 10, 25, 50]
|
||||
|
||||
|
||||
for limit in test_cases:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
|
||||
# Act
|
||||
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
|
||||
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
|
|
@ -230,14 +253,14 @@ class TestDocumentRagIntegration:
|
|||
for user, collection in test_scenarios:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
|
||||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
|
|
@ -245,19 +268,21 @@ class TestDocumentRagIntegration:
|
|||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk,
|
||||
caplog):
|
||||
"""Test DocumentRAG verbose logging functionality"""
|
||||
import logging
|
||||
|
||||
|
||||
# Arrange - Configure logging to capture debug messages
|
||||
caplog.set_level(logging.DEBUG)
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
|
@ -269,25 +294,25 @@ class TestDocumentRagIntegration:
|
|||
assert "DocumentRag initialized" in log_messages
|
||||
assert "Constructing prompt..." in log_messages
|
||||
assert "Computing embeddings..." in log_messages
|
||||
assert "Getting documents..." in log_messages
|
||||
assert "chunk_ids" in log_messages.lower()
|
||||
assert "Invoking LLM..." in log_messages
|
||||
assert "Query processing complete" in log_messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG performance with large document retrieval"""
|
||||
# Arrange - Mock large document set (100 documents)
|
||||
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_doc_set
|
||||
# Arrange - Mock large chunk_id set (100 chunks)
|
||||
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_chunk_ids
|
||||
|
||||
# Act
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
result = await document_rag.query("performance test query", doc_limit=100)
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
|
|
@ -309,4 +334,4 @@ class TestDocumentRagIntegration:
|
|||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == "trustgraph"
|
||||
assert call_args.kwargs['collection'] == "default"
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue