Fix tests (#666)

This commit is contained in:
cybermaggedon 2026-03-07 23:38:09 +00:00 committed by GitHub
parent 24bbe94136
commit 3bf8a65409
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 510 additions and 446 deletions

View file

@ -25,13 +25,13 @@ class TestDocumentEmbeddingsRequestContract:
user="test_user",
collection="test_collection"
)
# Verify all expected fields exist
assert hasattr(request, 'vectors')
assert hasattr(request, 'limit')
assert hasattr(request, 'user')
assert hasattr(request, 'collection')
# Verify field values
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
assert request.limit == 10
@ -41,16 +41,16 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_translator_to_pulsar(self):
"""Test request translator converts dict to Pulsar schema"""
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2], [0.3, 0.4]],
"limit": 5,
"user": "custom_user",
"collection": "custom_collection"
}
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
assert result.limit == 5
@ -60,14 +60,14 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_translator_to_pulsar_with_defaults(self):
"""Test request translator uses correct defaults"""
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2]]
# No limit, user, or collection provided
}
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2]]
assert result.limit == 10 # Default
@ -77,16 +77,16 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_translator_from_pulsar(self):
"""Test request translator converts Pulsar schema to dict"""
translator = DocumentEmbeddingsRequestTranslator()
request = DocumentEmbeddingsRequest(
vectors=[[0.5, 0.6]],
limit=20,
user="test_user",
collection="test_collection"
)
result = translator.from_pulsar(request)
assert isinstance(result, dict)
assert result["vectors"] == [[0.5, 0.6]]
assert result["limit"] == 20
@ -99,19 +99,19 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_schema_fields(self):
"""Test that DocumentEmbeddingsResponse has expected fields"""
# Create a response with chunks
# Create a response with chunk_ids
response = DocumentEmbeddingsResponse(
error=None,
chunks=["chunk1", "chunk2", "chunk3"]
chunk_ids=["chunk1", "chunk2", "chunk3"]
)
# Verify all expected fields exist
assert hasattr(response, 'error')
assert hasattr(response, 'chunks')
assert hasattr(response, 'chunk_ids')
# Verify field values
assert response.error is None
assert response.chunks == ["chunk1", "chunk2", "chunk3"]
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"]
def test_response_schema_with_error(self):
"""Test response schema with error"""
@ -119,90 +119,79 @@ class TestDocumentEmbeddingsResponseContract:
type="query_error",
message="Database connection failed"
)
response = DocumentEmbeddingsResponse(
error=error,
chunks=None
chunk_ids=[]
)
assert response.error == error
assert response.chunks is None
def test_response_translator_from_pulsar_with_chunks(self):
"""Test response translator converts Pulsar schema with chunks to dict"""
assert response.error == error
assert response.chunk_ids == []
def test_response_translator_from_pulsar_with_chunk_ids(self):
"""Test response translator converts Pulsar schema with chunk_ids to dict"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunks=["doc1", "doc2", "doc3"]
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == ["doc1", "doc2", "doc3"]
def test_response_translator_from_pulsar_with_bytes(self):
"""Test response translator handles byte chunks correctly"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = [b"byte_chunk1", b"byte_chunk2"]
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == ["byte_chunk1", "byte_chunk2"]
def test_response_translator_from_pulsar_with_empty_chunks(self):
"""Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = []
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == []
assert "chunk_ids" in result
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"]
def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunks"""
def test_response_translator_from_pulsar_with_empty_chunk_ids(self):
"""Test response translator handles empty chunk_ids list"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = None
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=[]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" not in result or result.get("chunks") is None
assert "chunk_ids" in result
assert result["chunk_ids"] == []
def test_response_translator_from_pulsar_with_none_chunk_ids(self):
"""Test response translator handles None chunk_ids"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunk_ids = None
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" not in result or result.get("chunk_ids") is None
def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunks=["chunk1", "chunk2"]
chunk_ids=["chunk1", "chunk2"]
)
result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == ["chunk1", "chunk2"]
assert "chunk_ids" in result
assert result["chunk_ids"] == ["chunk1", "chunk2"]
assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self):
"""Test that to_pulsar raises NotImplementedError for responses"""
translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunks": ["test"]})
translator.to_pulsar({"chunk_ids": ["test"]})
class TestDocumentEmbeddingsMessageCompatibility:
@ -217,26 +206,26 @@ class TestDocumentEmbeddingsMessageCompatibility:
"user": "test_user",
"collection": "test_collection"
}
# Convert to Pulsar request
req_translator = DocumentEmbeddingsRequestTranslator()
pulsar_request = req_translator.to_pulsar(request_data)
# Simulate service processing and creating response
response = DocumentEmbeddingsResponse(
error=None,
chunks=["relevant chunk 1", "relevant chunk 2"]
chunk_ids=["doc1/c1", "doc2/c2"]
)
# Convert response back to dict
resp_translator = DocumentEmbeddingsResponseTranslator()
response_data = resp_translator.from_pulsar(response)
# Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict)
assert "chunks" in response_data
assert len(response_data["chunks"]) == 2
assert "chunk_ids" in response_data
assert len(response_data["chunk_ids"]) == 2
def test_error_response_flow(self):
"""Test error response flow"""
@ -245,17 +234,18 @@ class TestDocumentEmbeddingsMessageCompatibility:
type="vector_db_error",
message="Collection not found"
)
response = DocumentEmbeddingsResponse(
error=error,
chunks=None
chunk_ids=[]
)
# Convert response to dict
translator = DocumentEmbeddingsResponseTranslator()
response_data = translator.from_pulsar(response)
# Verify error handling
assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunks
assert "chunks" not in response_data or response_data.get("chunks") is None
# The translator doesn't include error in the dict, only chunk_ids
assert "chunk_ids" in response_data
assert response_data["chunk_ids"] == []

View file

@ -11,6 +11,14 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
}
@pytest.mark.integration
class TestDocumentRagIntegration:
"""Integration tests for DocumentRAG system coordination"""
@ -27,15 +35,19 @@ class TestDocumentRagIntegration:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns realistic document chunks"""
"""Mock document embeddings client that returns chunk IDs"""
client = AsyncMock()
client.query.return_value = [
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
]
# Now returns chunk_ids instead of actual content
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
@ -48,17 +60,19 @@ class TestDocumentRagIntegration:
return client
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@pytest.mark.asyncio
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
"""Test complete DocumentRAG pipeline from query to response"""
# Arrange
@ -77,14 +91,15 @@ class TestDocumentRagIntegration:
# Assert - Verify service coordination
mock_embeddings_client.embed.assert_called_once_with(query)
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
user=user,
collection=collection
)
# Documents are fetched from librarian using chunk_ids
mock_prompt_client.document_prompt.assert_called_once_with(
query=query,
documents=[
@ -101,17 +116,19 @@ class TestDocumentRagIntegration:
assert "artificial intelligence" in result.lower()
@pytest.mark.asyncio
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents found
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
@ -125,92 +142,98 @@ class TestDocumentRagIntegration:
query="very obscure query",
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when embeddings service fails"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Embeddings service unavailable" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_not_called()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when document service fails"""
# Arrange
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Document service connection failed" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when prompt service fails"""
# Arrange
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "LLM service rate limited" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once()
@pytest.mark.asyncio
async def test_document_rag_with_different_document_limits(self, document_rag,
async def test_document_rag_with_different_document_limits(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG with various document limit configurations"""
# Test different document limits
test_cases = [1, 5, 10, 25, 50]
for limit in test_cases:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
@ -230,14 +253,14 @@ class TestDocumentRagIntegration:
for user, collection in test_scenarios:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(
f"query from {user} in {collection}",
user=user,
collection=collection
)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
@ -245,19 +268,21 @@ class TestDocumentRagIntegration:
assert call_args.kwargs['collection'] == collection
@pytest.mark.asyncio
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk,
caplog):
"""Test DocumentRAG verbose logging functionality"""
import logging
# Arrange - Configure logging to capture debug messages
caplog.set_level(logging.DEBUG)
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@ -269,25 +294,25 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages
assert "Getting documents..." in log_messages
assert "chunk_ids" in log_messages.lower()
assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages
@pytest.mark.asyncio
@pytest.mark.slow
async def test_document_rag_performance_with_large_document_set(self, document_rag,
async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large document set (100 documents)
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_doc_set
# Arrange - Mock large chunk_id set (100 chunks)
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_ids
# Act
import time
start_time = time.time()
result = await document_rag.query("performance test query", doc_limit=100)
end_time = time.time()
execution_time = end_time - start_time
@ -309,4 +334,4 @@ class TestDocumentRagIntegration:
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == "trustgraph"
assert call_args.kwargs['collection'] == "default"
assert call_args.kwargs['limit'] == 20
assert call_args.kwargs['limit'] == 20

View file

@ -14,6 +14,14 @@ from tests.utils.streaming_assertions import (
)
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of AI.",
"doc/c2": "Deep learning uses neural networks.",
"doc/c3": "Supervised learning needs labeled data.",
}
@pytest.mark.integration
class TestDocumentRagStreaming:
"""Integration tests for DocumentRAG streaming"""
@ -27,15 +35,19 @@ class TestDocumentRagStreaming:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client"""
"""Mock document embeddings client that returns chunk IDs"""
client = AsyncMock()
client.query.return_value = [
"Machine learning is a subset of AI.",
"Deep learning uses neural networks.",
"Supervised learning needs labeled data."
]
# Now returns chunk_ids instead of actual content
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support"""
@ -66,12 +78,13 @@ class TestDocumentRagStreaming:
@pytest.fixture
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client):
mock_streaming_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with streaming support"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@ -190,7 +203,7 @@ class TestDocumentRagStreaming:
mock_doc_embeddings_client):
"""Test streaming with no documents found"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids
callback = AsyncMock()
# Act

View file

@ -202,11 +202,18 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client"""
"""Mock document embeddings client that returns chunk IDs"""
client = AsyncMock()
client.query.return_value = ["doc1", "doc2"]
client.query.return_value = ["doc/c1", "doc/c2"]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return f"Content for {chunk_id}"
return fetch
@pytest.fixture
def mock_streaming_prompt_client(self):
"""Mock prompt client with streaming support"""
@ -227,12 +234,13 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client):
mock_streaming_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)

View file

@ -22,7 +22,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
# Mock the request method
client.request = AsyncMock(return_value=mock_response)
@ -75,7 +75,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = []
mock_response.chunk_ids = []
client.request = AsyncMock(return_value=mock_response)
@ -93,7 +93,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["test_chunk"]
mock_response.chunk_ids = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
@ -115,7 +115,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["chunk1"]
mock_response.chunk_ids = ["chunk1"]
client.request = AsyncMock(return_value=mock_response)
@ -136,7 +136,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["test_chunk"]
mock_response.chunk_ids = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)

View file

@ -77,9 +77,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results
mock_results = [
{"entity": {"doc": "First document chunk"}},
{"entity": {"doc": "Second document chunk"}},
{"entity": {"doc": "Third document chunk"}},
{"entity": {"chunk_id": "First document chunk"}},
{"entity": {"chunk_id": "Second document chunk"}},
{"entity": {"chunk_id": "Third document chunk"}},
]
processor.vecstore.search.return_value = mock_results
@ -108,11 +108,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results - different results for each vector
mock_results_1 = [
{"entity": {"doc": "Document from first vector"}},
{"entity": {"doc": "Another doc from first vector"}},
{"entity": {"chunk_id": "Document from first vector"}},
{"entity": {"chunk_id": "Another doc from first vector"}},
]
mock_results_2 = [
{"entity": {"doc": "Document from second vector"}},
{"entity": {"chunk_id": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
@ -147,10 +147,10 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results - more results than limit
mock_results = [
{"entity": {"doc": "Document 1"}},
{"entity": {"doc": "Document 2"}},
{"entity": {"doc": "Document 3"}},
{"entity": {"doc": "Document 4"}},
{"entity": {"chunk_id": "Document 1"}},
{"entity": {"chunk_id": "Document 2"}},
{"entity": {"chunk_id": "Document 3"}},
{"entity": {"chunk_id": "Document 4"}},
]
processor.vecstore.search.return_value = mock_results
@ -217,9 +217,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with Unicode content
mock_results = [
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
{"entity": {"doc": "Regular ASCII document"}},
{"entity": {"doc": "Document with émojis: 😀🎉"}},
{"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}},
{"entity": {"chunk_id": "Regular ASCII document"}},
{"entity": {"chunk_id": "Document with émojis: 😀🎉"}},
]
processor.vecstore.search.return_value = mock_results
@ -244,8 +244,8 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with large content
large_doc = "A" * 10000 # 10KB of content
mock_results = [
{"entity": {"doc": large_doc}},
{"entity": {"doc": "Small document"}},
{"entity": {"chunk_id": large_doc}},
{"entity": {"chunk_id": "Small document"}},
]
processor.vecstore.search.return_value = mock_results
@ -268,9 +268,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with special characters
mock_results = [
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
{"entity": {"doc": "Document with special chars: @#$%^&*()"}},
{"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}},
{"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}},
{"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}},
]
processor.vecstore.search.return_value = mock_results
@ -350,9 +350,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
result = await processor.query_document_embeddings(query)
@ -378,12 +378,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with duplicates across vectors
mock_results_1 = [
{"entity": {"doc": "Document A"}},
{"entity": {"doc": "Document B"}},
{"entity": {"chunk_id": "Document A"}},
{"entity": {"chunk_id": "Document B"}},
]
mock_results_2 = [
{"entity": {"doc": "Document B"}}, # Duplicate
{"entity": {"doc": "Document C"}},
{"entity": {"chunk_id": "Document B"}}, # Duplicate
{"entity": {"chunk_id": "Document C"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
@ -458,5 +458,5 @@ class TestMilvusDocEmbeddingsQueryProcessor:
mock_launch.assert_called_once_with(
default_ident,
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
"\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n"
)

View file

@ -77,9 +77,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'first document chunk'}
mock_point1.payload = {'chunk_id': 'first document chunk'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'second document chunk'}
mock_point2.payload = {'chunk_id': 'second document chunk'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
@ -132,11 +132,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query responses for different vectors
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document from vector 1'}
mock_point1.payload = {'chunk_id': 'document from vector 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document from vector 2'}
mock_point2.payload = {'chunk_id': 'document from vector 2'}
mock_point3 = MagicMock()
mock_point3.payload = {'doc': 'another document from vector 2'}
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
@ -192,7 +192,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_points = []
for i in range(10):
mock_point = MagicMock()
mock_point.payload = {'doc': f'document chunk {i}'}
mock_point.payload = {'chunk_id': f'document chunk {i}'}
mock_points.append(mock_point)
mock_response = MagicMock()
@ -270,9 +270,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query responses
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document from 2D vector'}
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document from 3D vector'}
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
@ -326,9 +326,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response with UTF-8 content
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
@ -399,7 +399,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response
mock_point = MagicMock()
mock_point.payload = {'doc': 'document chunk'}
mock_point.payload = {'chunk_id': 'document chunk'}
mock_response = MagicMock()
mock_response.points = [mock_point]
mock_qdrant_instance.query_points.return_value = mock_response
@ -442,9 +442,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response with fewer results than limit
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document 1'}
mock_point1.payload = {'chunk_id': 'document 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document 2'}
mock_point2.payload = {'chunk_id': 'document 2'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
@ -487,11 +487,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with missing 'doc' key
# Mock query response with missing 'chunk_id' key
mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'valid document'}
mock_point1.payload = {'chunk_id': 'valid document'}
mock_point2 = MagicMock()
mock_point2.payload = {} # Missing 'doc' key
mock_point2.payload = {} # Missing 'chunk_id' key
mock_point3 = MagicMock()
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
@ -514,7 +514,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'payload_collection'
# Act & Assert
# This should raise a KeyError when trying to access payload['doc']
# This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)

View file

@ -8,48 +8,75 @@ from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
# Sample chunk content mapping for tests
CHUNK_CONTENT = {
"doc/c1": "Document 1 content",
"doc/c2": "Document 2 content",
"doc/c3": "Relevant document content",
"doc/c4": "Another document",
"doc/c5": "Default doc",
"doc/c6": "Verbose test doc",
"doc/c7": "Verbose doc content",
"doc/ml1": "Machine learning is a subset of artificial intelligence...",
"doc/ml2": "ML algorithms learn patterns from data to make predictions...",
"doc/ml3": "Common ML techniques include supervised and unsupervised learning...",
}
@pytest.fixture
def mock_fetch_chunk():
"""Create a mock fetch_chunk function"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
class TestDocumentRag:
"""Test cases for DocumentRag class"""
def test_document_rag_initialization_with_defaults(self):
def test_document_rag_initialization_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag initialization with default verbose setting"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_doc_embeddings_client = MagicMock()
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk
)
# Verify initialization
assert document_rag.prompt_client == mock_prompt_client
assert document_rag.embeddings_client == mock_embeddings_client
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
assert document_rag.fetch_chunk == mock_fetch_chunk
assert document_rag.verbose is False # Default value
def test_document_rag_initialization_with_verbose(self):
def test_document_rag_initialization_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag initialization with verbose enabled"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_doc_embeddings_client = MagicMock()
# Initialize DocumentRag with verbose=True
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
# Verify initialization
assert document_rag.prompt_client == mock_prompt_client
assert document_rag.embeddings_client == mock_embeddings_client
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
assert document_rag.fetch_chunk == mock_fetch_chunk
assert document_rag.verbose is True
@ -60,7 +87,7 @@ class TestQuery:
"""Test Query initialization with default parameters"""
# Create mock DocumentRag
mock_rag = MagicMock()
# Initialize Query with defaults
query = Query(
rag=mock_rag,
@ -68,7 +95,7 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "test_user"
@ -80,7 +107,7 @@ class TestQuery:
"""Test Query initialization with custom doc_limit"""
# Create mock DocumentRag
mock_rag = MagicMock()
# Initialize Query with custom doc_limit
query = Query(
rag=mock_rag,
@ -89,7 +116,7 @@ class TestQuery:
verbose=True,
doc_limit=50
)
# Verify initialization
assert query.rag == mock_rag
assert query.user == "custom_user"
@ -104,11 +131,11 @@ class TestQuery:
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method to return test vectors
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
# Initialize Query
query = Query(
rag=mock_rag,
@ -116,14 +143,14 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call get_vector
test_query = "What documents are relevant?"
result = await query.get_vector(test_query)
# Verify embeddings client was called correctly
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify result matches expected vectors
assert result == expected_vectors
@ -136,15 +163,20 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk function
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
# Mock document results
test_docs = ["Document 1 content", "Document 2 content"]
mock_doc_embeddings_client.query.return_value = test_docs
# Mock document embeddings returns chunk_ids
test_chunk_ids = ["doc/c1", "doc/c2"]
mock_doc_embeddings_client.query.return_value = test_chunk_ids
# Initialize Query
query = Query(
rag=mock_rag,
@ -153,14 +185,14 @@ class TestQuery:
verbose=False,
doc_limit=15
)
# Call get_docs
test_query = "Find relevant documents"
result = await query.get_docs(test_query)
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with(test_query)
# Verify doc embeddings client was called correctly
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
@ -168,35 +200,37 @@ class TestQuery:
user="test_user",
collection="test_collection"
)
# Verify result is list of documents
assert result == test_docs
# Verify result is list of fetched document content
assert "Document 1 content" in result
assert "Document 2 content" in result
@pytest.mark.asyncio
async def test_document_rag_query_method(self):
async def test_document_rag_query_method(self, mock_fetch_chunk):
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document responses
# Mock embeddings and document embeddings responses
test_vectors = [[0.1, 0.2, 0.3]]
test_docs = ["Relevant document content", "Another document"]
test_chunk_ids = ["doc/c3", "doc/c4"]
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = test_vectors
mock_doc_embeddings_client.query.return_value = test_docs
mock_doc_embeddings_client.query.return_value = test_chunk_ids
mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query(
query="test query",
@ -204,10 +238,10 @@ class TestQuery:
collection="test_collection",
doc_limit=10
)
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with("test query")
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
@ -215,39 +249,43 @@ class TestQuery:
user="test_user",
collection="test_collection"
)
# Verify prompt client was called with documents and query
mock_prompt_client.document_prompt.assert_called_once_with(
query="test query",
documents=test_docs
)
# Verify prompt client was called with fetched documents and query
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "test query"
# Documents should be fetched content, not chunk_ids
docs = call_args.kwargs["documents"]
assert "Relevant document content" in docs
assert "Another document" in docs
# Verify result
assert result == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self):
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag.query method with default parameters"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_doc_embeddings_client.query.return_value = ["Default doc"]
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk
)
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2]],
@ -255,7 +293,7 @@ class TestQuery:
user="trustgraph", # Default user
collection="default" # Default collection
)
assert result == "Default response"
@pytest.mark.asyncio
@ -267,11 +305,16 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock responses
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
mock_doc_embeddings_client.query.return_value = ["Verbose test doc"]
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
@ -280,49 +323,51 @@ class TestQuery:
verbose=True,
doc_limit=5
)
# Call get_docs
result = await query.get_docs("verbose test")
# Verify calls were made
mock_embeddings_client.embed.assert_called_once_with("verbose test")
mock_doc_embeddings_client.query.assert_called_once()
# Verify result
assert result == ["Verbose test doc"]
# Verify result contains fetched content
assert "Verbose test doc" in result
@pytest.mark.asyncio
async def test_document_rag_query_with_verbose(self):
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag.query method with verbose logging enabled"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
mock_doc_embeddings_client.query.return_value = ["Verbose doc content"]
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
# Call DocumentRag.query
result = await document_rag.query("verbose query test")
# Verify all clients were called
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once_with(
query="verbose query test",
documents=["Verbose doc content"]
)
# Verify prompt client was called with fetched content
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response"
@pytest.mark.asyncio
@ -334,11 +379,16 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock responses - empty document list
# Mock fetch_chunk (won't be called if no chunk_ids)
async def mock_fetch(chunk_id, user):
return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_doc_embeddings_client.query.return_value = [] # No documents found
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
# Initialize Query
query = Query(
rag=mock_rag,
@ -346,47 +396,48 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call get_docs
result = await query.get_docs("query with no results")
# Verify calls were made
mock_embeddings_client.embed.assert_called_once_with("query with no results")
mock_doc_embeddings_client.query.assert_called_once()
# Verify empty result is returned
assert result == []
@pytest.mark.asyncio
async def test_document_rag_query_with_empty_documents(self):
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
"""Test DocumentRag.query method when no documents are retrieved"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses - no documents found
# Mock responses - no chunk_ids found
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
mock_doc_embeddings_client.query.return_value = [] # Empty document list
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
mock_prompt_client.document_prompt.return_value = "No documents found response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query("query with no matching docs")
# Verify prompt client was called with empty document list
mock_prompt_client.document_prompt.assert_called_once_with(
query="query with no matching docs",
documents=[]
)
assert result == "No documents found response"
@pytest.mark.asyncio
@ -396,11 +447,11 @@ class TestQuery:
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method
expected_vectors = [[0.9, 1.0, 1.1]]
mock_embeddings_client.embed.return_value = expected_vectors
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
@ -408,68 +459,71 @@ class TestQuery:
collection="test_collection",
verbose=True
)
# Call get_vector
result = await query.get_vector("verbose vector test")
# Verify embeddings client was called
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
# Verify result
assert result == expected_vectors
@pytest.mark.asyncio
async def test_document_rag_integration_flow(self):
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
"""Test complete DocumentRag integration with realistic data flow"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock realistic responses
query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_docs = [
"Machine learning is a subset of artificial intelligence...",
"ML algorithms learn patterns from data to make predictions...",
"Common ML techniques include supervised and unsupervised learning..."
]
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = query_vectors
mock_doc_embeddings_client.query.return_value = retrieved_docs
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Execute full pipeline
result = await document_rag.query(
query=query_text,
user="research_user",
user="research_user",
collection="ml_knowledge",
doc_limit=25
)
# Verify complete pipeline execution
mock_embeddings_client.embed.assert_called_once_with(query_text)
mock_doc_embeddings_client.query.assert_called_once_with(
query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"
)
mock_prompt_client.document_prompt.assert_called_once_with(
query=query_text,
documents=retrieved_docs
)
# Verify prompt client was called with fetched document content
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == query_text
# Verify documents were fetched from chunk_ids
docs = call_args.kwargs["documents"]
assert "Machine learning is a subset of artificial intelligence..." in docs
assert "ML algorithms learn patterns from data to make predictions..." in docs
assert "Common ML techniques include supervised and unsupervised learning..." in docs
# Verify final result
assert result == final_response
assert result == final_response

View file

@ -22,11 +22,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
chunk_id="This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
chunk_id="This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [chunk1, chunk2]
@ -84,7 +84,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
chunk_id="Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.chunks = [chunk]
@ -136,7 +136,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"",
chunk_id="",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
@ -148,51 +148,62 @@ class TestMilvusDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)"""
"""Test storing document embeddings with None chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=None,
chunk_id=None,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify no insert was called for None chunk
processor.vecstore.insert.assert_not_called()
# Note: Implementation passes through None chunk_ids (only skips empty string "")
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], None, 'test_user', 'test_collection'
)
@pytest.mark.asyncio
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
"""Test storing document embeddings with mix of valid and invalid chunks"""
"""Test storing document embeddings with mix of valid and empty chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings(
chunk=b"Valid document content",
chunk_id="Valid document content",
vectors=[[0.1, 0.2, 0.3]]
)
empty_chunk = ChunkEmbeddings(
chunk=b"",
chunk_id="",
vectors=[[0.4, 0.5, 0.6]]
)
none_chunk = ChunkEmbeddings(
chunk=None,
another_valid = ChunkEmbeddings(
chunk_id="Another valid chunk",
vectors=[[0.7, 0.8, 0.9]]
)
message.chunks = [valid_chunk, empty_chunk, none_chunk]
message.chunks = [valid_chunk, empty_chunk, another_valid]
await processor.store_document_embeddings(message)
# Verify only valid chunk was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'
)
# Verify valid chunks were inserted, empty string chunk was skipped
expected_calls = [
([0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "Another valid chunk", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_chunk_id, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_chunk_id
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
@ -217,7 +228,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
chunk_id="Document with no vectors",
vectors=[]
)
message.chunks = [chunk]
@ -236,7 +247,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
chunk_id="Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
@ -264,46 +275,46 @@ class TestMilvusDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content"""
"""Test storing document embeddings with Unicode content in chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
chunk_id="chunk/doc/unicode-éñ中文🚀",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and inserted with user/collection parameters
# Verify Unicode chunk_id was stored correctly with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀", 'test_user', 'test_collection'
[0.1, 0.2, 0.3], "chunk/doc/unicode-éñ中文🚀", 'test_user', 'test_collection'
)
@pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor):
"""Test storing document embeddings with large document chunks"""
async def test_store_document_embeddings_large_chunk_id(self, processor):
"""Test storing document embeddings with long chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
large_content = "A" * 10000 # 10KB of content
# Create a long chunk_id
long_chunk_id = "chunk/doc/" + "a" * 200
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
chunk_id=long_chunk_id,
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify large content was inserted with user/collection parameters
# Verify long chunk_id was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], large_content, 'test_user', 'test_collection'
[0.1, 0.2, 0.3], long_chunk_id, 'test_user', 'test_collection'
)
@pytest.mark.asyncio
@ -315,7 +326,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b" \n\t ",
chunk_id=" \n\t ",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
@ -346,7 +357,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = collection
chunk = ChunkEmbeddings(
chunk=b"Test content",
chunk_id="Test content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
@ -367,7 +378,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message1.metadata.user = 'user1'
message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings(
chunk=b"User1 content",
chunk_id="User1 content",
vectors=[[0.1, 0.2, 0.3]]
)
message1.chunks = [chunk1]
@ -378,7 +389,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message2.metadata.user = 'user2'
message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings(
chunk=b"User2 content",
chunk_id="User2 content",
vectors=[[0.4, 0.5, 0.6]]
)
message2.chunks = [chunk2]
@ -409,7 +420,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test-collection.v1' # Collection with special chars
chunk = ChunkEmbeddings(
chunk=b"Special chars test",
chunk_id="Special chars test",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]

View file

@ -20,7 +20,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -34,7 +34,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert
# Verify QdrantClient was created with correct parameters
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@ -45,7 +45,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
@ -69,7 +69,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -86,13 +86,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test document chunk'
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
@ -100,18 +100,18 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Verify collection existence was checked (with dimension suffix)
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3'
assert len(upsert_call_args[1]['points']) == 1
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['doc'] == 'test document chunk'
assert point.payload['chunk_id'] == 'doc/c1'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -123,7 +123,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -140,38 +140,38 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first document chunk'
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2]]
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second document chunk'
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.3, 0.4]]
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should be called twice (once per chunk)
assert mock_qdrant_instance.upsert.call_count == 2
# Verify both chunks were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
# First chunk
first_call = upsert_calls[0]
first_point = first_call[1]['points'][0]
assert first_point.vector == [0.1, 0.2]
assert first_point.payload['doc'] == 'first document chunk'
assert first_point.payload['chunk_id'] == 'doc/c1'
# Second chunk
second_call = upsert_calls[1]
second_point = second_call[1]['points'][0]
assert second_point.vector == [0.3, 0.4]
assert second_point.payload['doc'] == 'second document chunk'
assert second_point.payload['chunk_id'] == 'doc/c2'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -183,7 +183,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -200,41 +200,41 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
mock_chunk.chunk_id = 'doc/multi-vector'
mock_chunk.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['doc'] == 'multi-vector document chunk'
assert point.payload['chunk_id'] == 'doc/multi-vector'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client):
"""Test storing document embeddings skips empty chunks"""
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
"""Test storing document embeddings skips empty chunk_ids"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
@ -249,13 +249,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Create mock message with empty chunk
# Create mock message with empty chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
mock_chunk_empty.chunk_id = "" # Empty chunk_id
mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk_empty]
@ -264,9 +264,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message)
# Assert
# Should not call upsert for empty chunks
# Should not call upsert for empty chunk_ids
mock_qdrant_instance.upsert.assert_not_called()
# collection_exists should NOT be called since we return early for empty chunks
# collection_exists should NOT be called since we return early for empty chunk_ids
mock_qdrant_instance.collection_exists.assert_not_called()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@ -298,7 +298,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_message.chunks = [mock_chunk]
@ -350,7 +350,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
@ -388,7 +388,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first chunk'
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_message1.chunks = [mock_chunk1]
@ -406,7 +406,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second chunk'
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_message2.chunks = [mock_chunk2]
@ -452,7 +452,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
mock_chunk.chunk_id = 'doc/dim-test'
mock_chunk.vectors = [
[0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions
@ -485,28 +485,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Arrange
mock_qdrant_client.return_value = MagicMock()
mock_parser = MagicMock()
# Act
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
# Verify processor-specific arguments were added
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client):
"""Test proper UTF-8 decoding of chunk text"""
async def test_chunk_id_with_special_characters(self, mock_uuid, mock_qdrant_client):
"""Test storing chunk_id with special characters (URIs)"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -517,65 +517,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('utf8_user', 'utf8_collection')] = {}
processor.known_collections[('uri_user', 'uri_collection')] = {}
# Create mock message with UTF-8 encoded text
# Create mock message with URI-style chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'utf8_user'
mock_message.metadata.collection = 'utf8_collection'
mock_message.metadata.user = 'uri_user'
mock_message.metadata.collection = 'uri_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé'
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Verify chunk.decode was called with 'utf-8'
mock_chunk.chunk.decode.assert_called_with('utf-8')
# Verify the decoded text was stored in payload
# Verify the chunk_id was stored correctly
upsert_call_args = mock_qdrant_instance.upsert.call_args
point = upsert_call_args[1]['points'][0]
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_chunk_decode_exception_handling(self, mock_qdrant_client):
"""Test handling of chunk decode exceptions"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('decode_user', 'decode_collection')] = {}
# Create mock message with decode error
mock_message = MagicMock()
mock_message.metadata.user = 'decode_user'
mock_message.metadata.collection = 'decode_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(UnicodeDecodeError):
await processor.store_document_embeddings(mock_message)
assert point.payload['chunk_id'] == 'https://trustgraph.ai/doc/my-document/p1/c3'
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])