Fix tests (#666)

This commit is contained in:
cybermaggedon 2026-03-07 23:38:09 +00:00 committed by GitHub
parent 24bbe94136
commit 3bf8a65409
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 510 additions and 446 deletions

View file

@ -99,19 +99,19 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_schema_fields(self): def test_response_schema_fields(self):
"""Test that DocumentEmbeddingsResponse has expected fields""" """Test that DocumentEmbeddingsResponse has expected fields"""
# Create a response with chunks # Create a response with chunk_ids
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunks=["chunk1", "chunk2", "chunk3"] chunk_ids=["chunk1", "chunk2", "chunk3"]
) )
# Verify all expected fields exist # Verify all expected fields exist
assert hasattr(response, 'error') assert hasattr(response, 'error')
assert hasattr(response, 'chunks') assert hasattr(response, 'chunk_ids')
# Verify field values # Verify field values
assert response.error is None assert response.error is None
assert response.chunks == ["chunk1", "chunk2", "chunk3"] assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"]
def test_response_schema_with_error(self): def test_response_schema_with_error(self):
"""Test response schema with error""" """Test response schema with error"""
@ -122,64 +122,53 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=error, error=error,
chunks=None chunk_ids=[]
) )
assert response.error == error assert response.error == error
assert response.chunks is None assert response.chunk_ids == []
def test_response_translator_from_pulsar_with_chunks(self): def test_response_translator_from_pulsar_with_chunk_ids(self):
"""Test response translator converts Pulsar schema with chunks to dict""" """Test response translator converts Pulsar schema with chunk_ids to dict"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunks=["doc1", "doc2", "doc3"] chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"]
) )
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunks" in result assert "chunk_ids" in result
assert result["chunks"] == ["doc1", "doc2", "doc3"] assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"]
def test_response_translator_from_pulsar_with_bytes(self): def test_response_translator_from_pulsar_with_empty_chunk_ids(self):
"""Test response translator handles byte chunks correctly""" """Test response translator handles empty chunk_ids list"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock() response = DocumentEmbeddingsResponse(
response.chunks = [b"byte_chunk1", b"byte_chunk2"] error=None,
chunk_ids=[]
)
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunks" in result assert "chunk_ids" in result
assert result["chunks"] == ["byte_chunk1", "byte_chunk2"] assert result["chunk_ids"] == []
def test_response_translator_from_pulsar_with_empty_chunks(self): def test_response_translator_from_pulsar_with_none_chunk_ids(self):
"""Test response translator handles empty chunks list""" """Test response translator handles None chunk_ids"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock() response = MagicMock()
response.chunks = [] response.chunk_ids = None
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunks" in result assert "chunk_ids" not in result or result.get("chunk_ids") is None
assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = None
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self): def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag""" """Test response translator with completion flag"""
@ -187,14 +176,14 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunks=["chunk1", "chunk2"] chunk_ids=["chunk1", "chunk2"]
) )
result, is_final = translator.from_response_with_completion(response) result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunks" in result assert "chunk_ids" in result
assert result["chunks"] == ["chunk1", "chunk2"] assert result["chunk_ids"] == ["chunk1", "chunk2"]
assert is_final is True # Document embeddings responses are always final assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self): def test_response_translator_to_pulsar_not_implemented(self):
@ -202,7 +191,7 @@ class TestDocumentEmbeddingsResponseContract:
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunks": ["test"]}) translator.to_pulsar({"chunk_ids": ["test"]})
class TestDocumentEmbeddingsMessageCompatibility: class TestDocumentEmbeddingsMessageCompatibility:
@ -225,7 +214,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Simulate service processing and creating response # Simulate service processing and creating response
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunks=["relevant chunk 1", "relevant chunk 2"] chunk_ids=["doc1/c1", "doc2/c2"]
) )
# Convert response back to dict # Convert response back to dict
@ -235,8 +224,8 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify data integrity # Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict) assert isinstance(response_data, dict)
assert "chunks" in response_data assert "chunk_ids" in response_data
assert len(response_data["chunks"]) == 2 assert len(response_data["chunk_ids"]) == 2
def test_error_response_flow(self): def test_error_response_flow(self):
"""Test error response flow""" """Test error response flow"""
@ -248,7 +237,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=error, error=error,
chunks=None chunk_ids=[]
) )
# Convert response to dict # Convert response to dict
@ -257,5 +246,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify error handling # Verify error handling
assert isinstance(response_data, dict) assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunks # The translator doesn't include error in the dict, only chunk_ids
assert "chunks" not in response_data or response_data.get("chunks") is None assert "chunk_ids" in response_data
assert response_data["chunk_ids"] == []

View file

@ -11,6 +11,14 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
}
@pytest.mark.integration @pytest.mark.integration
class TestDocumentRagIntegration: class TestDocumentRagIntegration:
"""Integration tests for DocumentRAG system coordination""" """Integration tests for DocumentRAG system coordination"""
@ -27,15 +35,19 @@ class TestDocumentRagIntegration:
@pytest.fixture @pytest.fixture
def mock_doc_embeddings_client(self): def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns realistic document chunks""" """Mock document embeddings client that returns chunk IDs"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = [ # Now returns chunk_ids instead of actual content
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.", client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
]
return client return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture @pytest.fixture
def mock_prompt_client(self): def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses""" """Mock prompt client that generates realistic responses"""
@ -48,12 +60,14 @@ class TestDocumentRagIntegration:
return client return client
@pytest.fixture @pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client): def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies""" """Create DocumentRag instance with mocked dependencies"""
return DocumentRag( return DocumentRag(
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True verbose=True
) )
@ -85,6 +99,7 @@ class TestDocumentRagIntegration:
collection=collection collection=collection
) )
# Documents are fetched from librarian using chunk_ids
mock_prompt_client.document_prompt.assert_called_once_with( mock_prompt_client.document_prompt.assert_called_once_with(
query=query, query=query,
documents=[ documents=[
@ -102,16 +117,18 @@ class TestDocumentRagIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client, async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client): mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG behavior when no documents are retrieved""" """Test DocumentRAG behavior when no documents are retrieved"""
# Arrange # Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents found mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query." mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
document_rag = DocumentRag( document_rag = DocumentRag(
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -130,7 +147,8 @@ class TestDocumentRagIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client): mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when embeddings service fails""" """Test DocumentRAG error handling when embeddings service fails"""
# Arrange # Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable") mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
@ -139,6 +157,7 @@ class TestDocumentRagIntegration:
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -153,7 +172,8 @@ class TestDocumentRagIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_document_service_failure(self, mock_embeddings_client, async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client): mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when document service fails""" """Test DocumentRAG error handling when document service fails"""
# Arrange # Arrange
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed") mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
@ -162,6 +182,7 @@ class TestDocumentRagIntegration:
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -176,7 +197,8 @@ class TestDocumentRagIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client, async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client): mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when prompt service fails""" """Test DocumentRAG error handling when prompt service fails"""
# Arrange # Arrange
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited") mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
@ -185,6 +207,7 @@ class TestDocumentRagIntegration:
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -247,6 +270,7 @@ class TestDocumentRagIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_verbose_logging(self, mock_embeddings_client, async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client, mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk,
caplog): caplog):
"""Test DocumentRAG verbose logging functionality""" """Test DocumentRAG verbose logging functionality"""
import logging import logging
@ -258,6 +282,7 @@ class TestDocumentRagIntegration:
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True verbose=True
) )
@ -269,7 +294,7 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages assert "Computing embeddings..." in log_messages
assert "Getting documents..." in log_messages assert "chunk_ids" in log_messages.lower()
assert "Invoking LLM..." in log_messages assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages assert "Query processing complete" in log_messages
@ -278,9 +303,9 @@ class TestDocumentRagIntegration:
async def test_document_rag_performance_with_large_document_set(self, document_rag, async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client): mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval""" """Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large document set (100 documents) # Arrange - Mock large chunk_id set (100 chunks)
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)] large_chunk_ids = [f"doc/c{i}" for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_doc_set mock_doc_embeddings_client.query.return_value = large_chunk_ids
# Act # Act
import time import time

View file

@ -14,6 +14,14 @@ from tests.utils.streaming_assertions import (
) )
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of AI.",
"doc/c2": "Deep learning uses neural networks.",
"doc/c3": "Supervised learning needs labeled data.",
}
@pytest.mark.integration @pytest.mark.integration
class TestDocumentRagStreaming: class TestDocumentRagStreaming:
"""Integration tests for DocumentRAG streaming""" """Integration tests for DocumentRAG streaming"""
@ -27,15 +35,19 @@ class TestDocumentRagStreaming:
@pytest.fixture @pytest.fixture
def mock_doc_embeddings_client(self): def mock_doc_embeddings_client(self):
"""Mock document embeddings client""" """Mock document embeddings client that returns chunk IDs"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = [ # Now returns chunk_ids instead of actual content
"Machine learning is a subset of AI.", client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
"Deep learning uses neural networks.",
"Supervised learning needs labeled data."
]
return client return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture @pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response): def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support""" """Mock prompt client with streaming support"""
@ -66,12 +78,13 @@ class TestDocumentRagStreaming:
@pytest.fixture @pytest.fixture
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client, def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client): mock_streaming_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with streaming support""" """Create DocumentRag instance with streaming support"""
return DocumentRag( return DocumentRag(
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client, prompt_client=mock_streaming_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True verbose=True
) )
@ -190,7 +203,7 @@ class TestDocumentRagStreaming:
mock_doc_embeddings_client): mock_doc_embeddings_client):
"""Test streaming with no documents found""" """Test streaming with no documents found"""
# Arrange # Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents mock_doc_embeddings_client.query.return_value = [] # No chunk_ids
callback = AsyncMock() callback = AsyncMock()
# Act # Act

View file

@ -202,11 +202,18 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture @pytest.fixture
def mock_doc_embeddings_client(self): def mock_doc_embeddings_client(self):
"""Mock document embeddings client""" """Mock document embeddings client that returns chunk IDs"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = ["doc1", "doc2"] client.query.return_value = ["doc/c1", "doc/c2"]
return client return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return f"Content for {chunk_id}"
return fetch
@pytest.fixture @pytest.fixture
def mock_streaming_prompt_client(self): def mock_streaming_prompt_client(self):
"""Mock prompt client with streaming support""" """Mock prompt client with streaming support"""
@ -227,12 +234,13 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture @pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client): mock_streaming_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies""" """Create DocumentRag instance with mocked dependencies"""
return DocumentRag( return DocumentRag(
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client, prompt_client=mock_streaming_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )

View file

@ -22,7 +22,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunks = ["chunk1", "chunk2", "chunk3"] mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
# Mock the request method # Mock the request method
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
@ -75,7 +75,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunks = [] mock_response.chunk_ids = []
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
@ -93,7 +93,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunks = ["test_chunk"] mock_response.chunk_ids = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
@ -115,7 +115,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunks = ["chunk1"] mock_response.chunk_ids = ["chunk1"]
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
@ -136,7 +136,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunks = ["test_chunk"] mock_response.chunk_ids = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)

View file

@ -77,9 +77,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results # Mock search results
mock_results = [ mock_results = [
{"entity": {"doc": "First document chunk"}}, {"entity": {"chunk_id": "First document chunk"}},
{"entity": {"doc": "Second document chunk"}}, {"entity": {"chunk_id": "Second document chunk"}},
{"entity": {"doc": "Third document chunk"}}, {"entity": {"chunk_id": "Third document chunk"}},
] ]
processor.vecstore.search.return_value = mock_results processor.vecstore.search.return_value = mock_results
@ -108,11 +108,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results - different results for each vector # Mock search results - different results for each vector
mock_results_1 = [ mock_results_1 = [
{"entity": {"doc": "Document from first vector"}}, {"entity": {"chunk_id": "Document from first vector"}},
{"entity": {"doc": "Another doc from first vector"}}, {"entity": {"chunk_id": "Another doc from first vector"}},
] ]
mock_results_2 = [ mock_results_2 = [
{"entity": {"doc": "Document from second vector"}}, {"entity": {"chunk_id": "Document from second vector"}},
] ]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
@ -147,10 +147,10 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results - more results than limit # Mock search results - more results than limit
mock_results = [ mock_results = [
{"entity": {"doc": "Document 1"}}, {"entity": {"chunk_id": "Document 1"}},
{"entity": {"doc": "Document 2"}}, {"entity": {"chunk_id": "Document 2"}},
{"entity": {"doc": "Document 3"}}, {"entity": {"chunk_id": "Document 3"}},
{"entity": {"doc": "Document 4"}}, {"entity": {"chunk_id": "Document 4"}},
] ]
processor.vecstore.search.return_value = mock_results processor.vecstore.search.return_value = mock_results
@ -217,9 +217,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with Unicode content # Mock search results with Unicode content
mock_results = [ mock_results = [
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}}, {"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}},
{"entity": {"doc": "Regular ASCII document"}}, {"entity": {"chunk_id": "Regular ASCII document"}},
{"entity": {"doc": "Document with émojis: 😀🎉"}}, {"entity": {"chunk_id": "Document with émojis: 😀🎉"}},
] ]
processor.vecstore.search.return_value = mock_results processor.vecstore.search.return_value = mock_results
@ -244,8 +244,8 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with large content # Mock search results with large content
large_doc = "A" * 10000 # 10KB of content large_doc = "A" * 10000 # 10KB of content
mock_results = [ mock_results = [
{"entity": {"doc": large_doc}}, {"entity": {"chunk_id": large_doc}},
{"entity": {"doc": "Small document"}}, {"entity": {"chunk_id": "Small document"}},
] ]
processor.vecstore.search.return_value = mock_results processor.vecstore.search.return_value = mock_results
@ -268,9 +268,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with special characters # Mock search results with special characters
mock_results = [ mock_results = [
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}}, {"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}},
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}}, {"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}},
{"entity": {"doc": "Document with special chars: @#$%^&*()"}}, {"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}},
] ]
processor.vecstore.search.return_value = mock_results processor.vecstore.search.return_value = mock_results
@ -350,9 +350,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
) )
# Mock search results for each vector # Mock search results for each vector
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}] mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}] mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}] mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
@ -378,12 +378,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
# Mock search results with duplicates across vectors # Mock search results with duplicates across vectors
mock_results_1 = [ mock_results_1 = [
{"entity": {"doc": "Document A"}}, {"entity": {"chunk_id": "Document A"}},
{"entity": {"doc": "Document B"}}, {"entity": {"chunk_id": "Document B"}},
] ]
mock_results_2 = [ mock_results_2 = [
{"entity": {"doc": "Document B"}}, # Duplicate {"entity": {"chunk_id": "Document B"}}, # Duplicate
{"entity": {"doc": "Document C"}}, {"entity": {"chunk_id": "Document C"}},
] ]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
@ -458,5 +458,5 @@ class TestMilvusDocEmbeddingsQueryProcessor:
mock_launch.assert_called_once_with( mock_launch.assert_called_once_with(
default_ident, default_ident,
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n" "\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n"
) )

View file

@ -77,9 +77,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response # Mock query response
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'first document chunk'} mock_point1.payload = {'chunk_id': 'first document chunk'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'second document chunk'} mock_point2.payload = {'chunk_id': 'second document chunk'}
mock_response = MagicMock() mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2] mock_response.points = [mock_point1, mock_point2]
@ -132,11 +132,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query responses for different vectors # Mock query responses for different vectors
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document from vector 1'} mock_point1.payload = {'chunk_id': 'document from vector 1'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document from vector 2'} mock_point2.payload = {'chunk_id': 'document from vector 2'}
mock_point3 = MagicMock() mock_point3 = MagicMock()
mock_point3.payload = {'doc': 'another document from vector 2'} mock_point3.payload = {'chunk_id': 'another document from vector 2'}
mock_response1 = MagicMock() mock_response1 = MagicMock()
mock_response1.points = [mock_point1] mock_response1.points = [mock_point1]
@ -192,7 +192,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_points = [] mock_points = []
for i in range(10): for i in range(10):
mock_point = MagicMock() mock_point = MagicMock()
mock_point.payload = {'doc': f'document chunk {i}'} mock_point.payload = {'chunk_id': f'document chunk {i}'}
mock_points.append(mock_point) mock_points.append(mock_point)
mock_response = MagicMock() mock_response = MagicMock()
@ -270,9 +270,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query responses # Mock query responses
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document from 2D vector'} mock_point1.payload = {'chunk_id': 'document from 2D vector'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document from 3D vector'} mock_point2.payload = {'chunk_id': 'document from 3D vector'}
mock_response1 = MagicMock() mock_response1 = MagicMock()
mock_response1.points = [mock_point1] mock_response1.points = [mock_point1]
@ -326,9 +326,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response with UTF-8 content # Mock query response with UTF-8 content
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'} mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'Chinese text: 你好世界'} mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'}
mock_response = MagicMock() mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2] mock_response.points = [mock_point1, mock_point2]
@ -399,7 +399,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response # Mock query response
mock_point = MagicMock() mock_point = MagicMock()
mock_point.payload = {'doc': 'document chunk'} mock_point.payload = {'chunk_id': 'document chunk'}
mock_response = MagicMock() mock_response = MagicMock()
mock_response.points = [mock_point] mock_response.points = [mock_point]
mock_qdrant_instance.query_points.return_value = mock_response mock_qdrant_instance.query_points.return_value = mock_response
@ -442,9 +442,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response with fewer results than limit # Mock query response with fewer results than limit
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'document 1'} mock_point1.payload = {'chunk_id': 'document 1'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'doc': 'document 2'} mock_point2.payload = {'chunk_id': 'document 2'}
mock_response = MagicMock() mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2] mock_response.points = [mock_point1, mock_point2]
@ -487,11 +487,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query response with missing 'doc' key # Mock query response with missing 'chunk_id' key
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'doc': 'valid document'} mock_point1.payload = {'chunk_id': 'valid document'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {} # Missing 'doc' key mock_point2.payload = {} # Missing 'chunk_id' key
mock_point3 = MagicMock() mock_point3 = MagicMock()
mock_point3.payload = {'other_key': 'invalid'} # Wrong key mock_point3.payload = {'other_key': 'invalid'} # Wrong key
@ -514,7 +514,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_message.collection = 'payload_collection' mock_message.collection = 'payload_collection'
# Act & Assert # Act & Assert
# This should raise a KeyError when trying to access payload['doc'] # This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError): with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message) await processor.query_document_embeddings(mock_message)

View file

@ -8,10 +8,33 @@ from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
# Sample chunk content mapping for tests
CHUNK_CONTENT = {
"doc/c1": "Document 1 content",
"doc/c2": "Document 2 content",
"doc/c3": "Relevant document content",
"doc/c4": "Another document",
"doc/c5": "Default doc",
"doc/c6": "Verbose test doc",
"doc/c7": "Verbose doc content",
"doc/ml1": "Machine learning is a subset of artificial intelligence...",
"doc/ml2": "ML algorithms learn patterns from data to make predictions...",
"doc/ml3": "Common ML techniques include supervised and unsupervised learning...",
}
@pytest.fixture
def mock_fetch_chunk():
"""Create a mock fetch_chunk function"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
class TestDocumentRag: class TestDocumentRag:
"""Test cases for DocumentRag class""" """Test cases for DocumentRag class"""
def test_document_rag_initialization_with_defaults(self): def test_document_rag_initialization_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag initialization with default verbose setting""" """Test DocumentRag initialization with default verbose setting"""
# Create mock clients # Create mock clients
mock_prompt_client = MagicMock() mock_prompt_client = MagicMock()
@ -22,16 +45,18 @@ class TestDocumentRag:
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk
) )
# Verify initialization # Verify initialization
assert document_rag.prompt_client == mock_prompt_client assert document_rag.prompt_client == mock_prompt_client
assert document_rag.embeddings_client == mock_embeddings_client assert document_rag.embeddings_client == mock_embeddings_client
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
assert document_rag.fetch_chunk == mock_fetch_chunk
assert document_rag.verbose is False # Default value assert document_rag.verbose is False # Default value
def test_document_rag_initialization_with_verbose(self): def test_document_rag_initialization_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag initialization with verbose enabled""" """Test DocumentRag initialization with verbose enabled"""
# Create mock clients # Create mock clients
mock_prompt_client = MagicMock() mock_prompt_client = MagicMock()
@ -43,6 +68,7 @@ class TestDocumentRag:
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=True verbose=True
) )
@ -50,6 +76,7 @@ class TestDocumentRag:
assert document_rag.prompt_client == mock_prompt_client assert document_rag.prompt_client == mock_prompt_client
assert document_rag.embeddings_client == mock_embeddings_client assert document_rag.embeddings_client == mock_embeddings_client
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
assert document_rag.fetch_chunk == mock_fetch_chunk
assert document_rag.verbose is True assert document_rag.verbose is True
@ -137,13 +164,18 @@ class TestQuery:
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk function
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses # Mock the embedding and document query responses
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors mock_embeddings_client.embed.return_value = test_vectors
# Mock document results # Mock document embeddings returns chunk_ids
test_docs = ["Document 1 content", "Document 2 content"] test_chunk_ids = ["doc/c1", "doc/c2"]
mock_doc_embeddings_client.query.return_value = test_docs mock_doc_embeddings_client.query.return_value = test_chunk_ids
# Initialize Query # Initialize Query
query = Query( query = Query(
@ -169,24 +201,25 @@ class TestQuery:
collection="test_collection" collection="test_collection"
) )
# Verify result is list of documents # Verify result is list of fetched document content
assert result == test_docs assert "Document 1 content" in result
assert "Document 2 content" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_method(self): async def test_document_rag_query_method(self, mock_fetch_chunk):
"""Test DocumentRag.query method orchestrates full document RAG pipeline""" """Test DocumentRag.query method orchestrates full document RAG pipeline"""
# Create mock clients # Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document responses # Mock embeddings and document embeddings responses
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
test_docs = ["Relevant document content", "Another document"] test_chunk_ids = ["doc/c3", "doc/c4"]
expected_response = "This is the document RAG response" expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = test_vectors mock_embeddings_client.embed.return_value = test_vectors
mock_doc_embeddings_client.query.return_value = test_docs mock_doc_embeddings_client.query.return_value = test_chunk_ids
mock_prompt_client.document_prompt.return_value = expected_response mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag # Initialize DocumentRag
@ -194,6 +227,7 @@ class TestQuery:
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -216,17 +250,20 @@ class TestQuery:
collection="test_collection" collection="test_collection"
) )
# Verify prompt client was called with documents and query # Verify prompt client was called with fetched documents and query
mock_prompt_client.document_prompt.assert_called_once_with( mock_prompt_client.document_prompt.assert_called_once()
query="test query", call_args = mock_prompt_client.document_prompt.call_args
documents=test_docs assert call_args.kwargs["query"] == "test query"
) # Documents should be fetched content, not chunk_ids
docs = call_args.kwargs["documents"]
assert "Relevant document content" in docs
assert "Another document" in docs
# Verify result # Verify result
assert result == expected_response assert result == expected_response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self): async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag.query method with default parameters""" """Test DocumentRag.query method with default parameters"""
# Create mock clients # Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
@ -235,14 +272,15 @@ class TestQuery:
# Mock responses # Mock responses
mock_embeddings_client.embed.return_value = [[0.1, 0.2]] mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_doc_embeddings_client.query.return_value = ["Default doc"] mock_doc_embeddings_client.query.return_value = ["doc/c5"]
mock_prompt_client.document_prompt.return_value = "Default response" mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag # Initialize DocumentRag
document_rag = DocumentRag( document_rag = DocumentRag(
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk
) )
# Call DocumentRag.query with minimal parameters # Call DocumentRag.query with minimal parameters
@ -268,9 +306,14 @@ class TestQuery:
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock responses # Mock responses
mock_embeddings_client.embed.return_value = [[0.7, 0.8]] mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
mock_doc_embeddings_client.query.return_value = ["Verbose test doc"] mock_doc_embeddings_client.query.return_value = ["doc/c6"]
# Initialize Query with verbose=True # Initialize Query with verbose=True
query = Query( query = Query(
@ -288,11 +331,11 @@ class TestQuery:
mock_embeddings_client.embed.assert_called_once_with("verbose test") mock_embeddings_client.embed.assert_called_once_with("verbose test")
mock_doc_embeddings_client.query.assert_called_once() mock_doc_embeddings_client.query.assert_called_once()
# Verify result # Verify result contains fetched content
assert result == ["Verbose test doc"] assert "Verbose test doc" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_verbose(self): async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag.query method with verbose logging enabled""" """Test DocumentRag.query method with verbose logging enabled"""
# Create mock clients # Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
@ -301,7 +344,7 @@ class TestQuery:
# Mock responses # Mock responses
mock_embeddings_client.embed.return_value = [[0.3, 0.4]] mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
mock_doc_embeddings_client.query.return_value = ["Verbose doc content"] mock_doc_embeddings_client.query.return_value = ["doc/c7"]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response" mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True # Initialize DocumentRag with verbose=True
@ -309,6 +352,7 @@ class TestQuery:
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=True verbose=True
) )
@ -318,10 +362,11 @@ class TestQuery:
# Verify all clients were called # Verify all clients were called
mock_embeddings_client.embed.assert_called_once_with("verbose query test") mock_embeddings_client.embed.assert_called_once_with("verbose query test")
mock_doc_embeddings_client.query.assert_called_once() mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once_with(
query="verbose query test", # Verify prompt client was called with fetched content
documents=["Verbose doc content"] call_args = mock_prompt_client.document_prompt.call_args
) assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response" assert result == "Verbose RAG response"
@ -335,9 +380,14 @@ class TestQuery:
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock responses - empty document list # Mock fetch_chunk (won't be called if no chunk_ids)
async def mock_fetch(chunk_id, user):
return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list
mock_embeddings_client.embed.return_value = [[0.1, 0.2]] mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_doc_embeddings_client.query.return_value = [] # No documents found mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
# Initialize Query # Initialize Query
query = Query( query = Query(
@ -358,16 +408,16 @@ class TestQuery:
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_query_with_empty_documents(self): async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
"""Test DocumentRag.query method when no documents are retrieved""" """Test DocumentRag.query method when no documents are retrieved"""
# Create mock clients # Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock() mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock() mock_doc_embeddings_client = AsyncMock()
# Mock responses - no documents found # Mock responses - no chunk_ids found
mock_embeddings_client.embed.return_value = [[0.5, 0.6]] mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
mock_doc_embeddings_client.query.return_value = [] # Empty document list mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
mock_prompt_client.document_prompt.return_value = "No documents found response" mock_prompt_client.document_prompt.return_value = "No documents found response"
# Initialize DocumentRag # Initialize DocumentRag
@ -375,6 +425,7 @@ class TestQuery:
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -419,7 +470,7 @@ class TestQuery:
assert result == expected_vectors assert result == expected_vectors
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_document_rag_integration_flow(self): async def test_document_rag_integration_flow(self, mock_fetch_chunk):
"""Test complete DocumentRag integration with realistic data flow""" """Test complete DocumentRag integration with realistic data flow"""
# Create mock clients # Create mock clients
mock_prompt_client = AsyncMock() mock_prompt_client = AsyncMock()
@ -429,15 +480,11 @@ class TestQuery:
# Mock realistic responses # Mock realistic responses
query_text = "What is machine learning?" query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_docs = [ retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
"Machine learning is a subset of artificial intelligence...",
"ML algorithms learn patterns from data to make predictions...",
"Common ML techniques include supervised and unsupervised learning..."
]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = query_vectors mock_embeddings_client.embed.return_value = query_vectors
mock_doc_embeddings_client.query.return_value = retrieved_docs mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
mock_prompt_client.document_prompt.return_value = final_response mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag # Initialize DocumentRag
@ -445,6 +492,7 @@ class TestQuery:
prompt_client=mock_prompt_client, prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client, embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client, doc_embeddings_client=mock_doc_embeddings_client,
fetch_chunk=mock_fetch_chunk,
verbose=False verbose=False
) )
@ -466,10 +514,16 @@ class TestQuery:
collection="ml_knowledge" collection="ml_knowledge"
) )
mock_prompt_client.document_prompt.assert_called_once_with( # Verify prompt client was called with fetched document content
query=query_text, mock_prompt_client.document_prompt.assert_called_once()
documents=retrieved_docs call_args = mock_prompt_client.document_prompt.call_args
) assert call_args.kwargs["query"] == query_text
# Verify documents were fetched from chunk_ids
docs = call_args.kwargs["documents"]
assert "Machine learning is a subset of artificial intelligence..." in docs
assert "ML algorithms learn patterns from data to make predictions..." in docs
assert "Common ML techniques include supervised and unsupervised learning..." in docs
# Verify final result # Verify final result
assert result == final_response assert result == final_response

View file

@ -22,11 +22,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Create test document embeddings # Create test document embeddings
chunk1 = ChunkEmbeddings( chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk", chunk_id="This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
) )
chunk2 = ChunkEmbeddings( chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk", chunk_id="This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]] vectors=[[0.7, 0.8, 0.9]]
) )
message.chunks = [chunk1, chunk2] message.chunks = [chunk1, chunk2]
@ -84,7 +84,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Test document content", chunk_id="Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -136,7 +136,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"", chunk_id="",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -148,51 +148,62 @@ class TestMilvusDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_none_chunk(self, processor): async def test_store_document_embeddings_none_chunk(self, processor):
"""Test storing document embeddings with None chunk (should be skipped)""" """Test storing document embeddings with None chunk_id"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=None, chunk_id=None,
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify no insert was called for None chunk # Note: Implementation passes through None chunk_ids (only skips empty string "")
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], None, 'test_user', 'test_collection'
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor): async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
"""Test storing document embeddings with mix of valid and invalid chunks""" """Test storing document embeddings with mix of valid and empty chunks"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings( valid_chunk = ChunkEmbeddings(
chunk=b"Valid document content", chunk_id="Valid document content",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
empty_chunk = ChunkEmbeddings( empty_chunk = ChunkEmbeddings(
chunk=b"", chunk_id="",
vectors=[[0.4, 0.5, 0.6]] vectors=[[0.4, 0.5, 0.6]]
) )
none_chunk = ChunkEmbeddings( another_valid = ChunkEmbeddings(
chunk=None, chunk_id="Another valid chunk",
vectors=[[0.7, 0.8, 0.9]] vectors=[[0.7, 0.8, 0.9]]
) )
message.chunks = [valid_chunk, empty_chunk, none_chunk] message.chunks = [valid_chunk, empty_chunk, another_valid]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify only valid chunk was inserted with user/collection parameters # Verify valid chunks were inserted, empty string chunk was skipped
processor.vecstore.insert.assert_called_once_with( expected_calls = [
[0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection' ([0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'),
) ([0.7, 0.8, 0.9], "Another valid chunk", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_chunk_id, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_chunk_id
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor): async def test_store_document_embeddings_empty_chunks_list(self, processor):
@ -217,7 +228,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Document with no vectors", chunk_id="Document with no vectors",
vectors=[] vectors=[]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -236,7 +247,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions", chunk_id="Document with mixed dimensions",
vectors=[ vectors=[
[0.1, 0.2], # 2D vector [0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector [0.3, 0.4, 0.5, 0.6], # 4D vector
@ -264,46 +275,46 @@ class TestMilvusDocEmbeddingsStorageProcessor:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor): async def test_store_document_embeddings_unicode_content(self, processor):
"""Test storing document embeddings with Unicode content""" """Test storing document embeddings with Unicode content in chunk_id"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), chunk_id="chunk/doc/unicode-éñ中文🚀",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify Unicode content was properly decoded and inserted with user/collection parameters # Verify Unicode chunk_id was stored correctly with user/collection parameters
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀", 'test_user', 'test_collection' [0.1, 0.2, 0.3], "chunk/doc/unicode-éñ中文🚀", 'test_user', 'test_collection'
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_large_chunks(self, processor): async def test_store_document_embeddings_large_chunk_id(self, processor):
"""Test storing document embeddings with large document chunks""" """Test storing document embeddings with long chunk_id"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create a large document chunk # Create a long chunk_id
large_content = "A" * 10000 # 10KB of content long_chunk_id = "chunk/doc/" + "a" * 200
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'), chunk_id=long_chunk_id,
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify large content was inserted with user/collection parameters # Verify long chunk_id was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], large_content, 'test_user', 'test_collection' [0.1, 0.2, 0.3], long_chunk_id, 'test_user', 'test_collection'
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -315,7 +326,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b" \n\t ", chunk_id=" \n\t ",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -346,7 +357,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = collection message.metadata.collection = collection
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Test content", chunk_id="Test content",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -367,7 +378,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message1.metadata.user = 'user1' message1.metadata.user = 'user1'
message1.metadata.collection = 'collection1' message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings( chunk1 = ChunkEmbeddings(
chunk=b"User1 content", chunk_id="User1 content",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message1.chunks = [chunk1] message1.chunks = [chunk1]
@ -378,7 +389,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message2.metadata.user = 'user2' message2.metadata.user = 'user2'
message2.metadata.collection = 'collection2' message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings( chunk2 = ChunkEmbeddings(
chunk=b"User2 content", chunk_id="User2 content",
vectors=[[0.4, 0.5, 0.6]] vectors=[[0.4, 0.5, 0.6]]
) )
message2.chunks = [chunk2] message2.chunks = [chunk2]
@ -409,7 +420,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata.collection = 'test-collection.v1' # Collection with special chars message.metadata.collection = 'test-collection.v1' # Collection with special chars
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Special chars test", chunk_id="Special chars test",
vectors=[[0.1, 0.2, 0.3]] vectors=[[0.1, 0.2, 0.3]]
) )
message.chunks = [chunk] message.chunks = [chunk]

View file

@ -88,7 +88,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'test_collection' mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test document chunk' mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -111,7 +111,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
point = upsert_call_args[1]['points'][0] point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3] assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['doc'] == 'test document chunk' assert point.payload['chunk_id'] == 'doc/c1'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -142,11 +142,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'multi_collection' mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first document chunk' mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2]] mock_chunk1.vectors = [[0.1, 0.2]]
mock_chunk2 = MagicMock() mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second document chunk' mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.3, 0.4]] mock_chunk2.vectors = [[0.3, 0.4]]
mock_message.chunks = [mock_chunk1, mock_chunk2] mock_message.chunks = [mock_chunk1, mock_chunk2]
@ -165,13 +165,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
first_call = upsert_calls[0] first_call = upsert_calls[0]
first_point = first_call[1]['points'][0] first_point = first_call[1]['points'][0]
assert first_point.vector == [0.1, 0.2] assert first_point.vector == [0.1, 0.2]
assert first_point.payload['doc'] == 'first document chunk' assert first_point.payload['chunk_id'] == 'doc/c1'
# Second chunk # Second chunk
second_call = upsert_calls[1] second_call = upsert_calls[1]
second_point = second_call[1]['points'][0] second_point = second_call[1]['points'][0]
assert second_point.vector == [0.3, 0.4] assert second_point.vector == [0.3, 0.4]
assert second_point.payload['doc'] == 'second document chunk' assert second_point.payload['chunk_id'] == 'doc/c2'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -202,7 +202,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'vector_collection' mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk' mock_chunk.chunk_id = 'doc/multi-vector'
mock_chunk.vectors = [ mock_chunk.vectors = [
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
[0.4, 0.5, 0.6], [0.4, 0.5, 0.6],
@ -230,11 +230,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
for i, call in enumerate(upsert_calls): for i, call in enumerate(upsert_calls):
point = call[1]['points'][0] point = call[1]['points'][0]
assert point.vector == expected_vectors[i] assert point.vector == expected_vectors[i]
assert point.payload['doc'] == 'multi-vector document chunk' assert point.payload['chunk_id'] == 'doc/multi-vector'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client): async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
"""Test storing document embeddings skips empty chunks""" """Test storing document embeddings skips empty chunk_ids"""
# Arrange # Arrange
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists mock_qdrant_instance.collection_exists.return_value = True # Collection exists
@ -249,13 +249,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Create mock message with empty chunk # Create mock message with empty chunk_id
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'empty_user' mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection' mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock() mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk.decode.return_value = "" # Empty string mock_chunk_empty.chunk_id = "" # Empty chunk_id
mock_chunk_empty.vectors = [[0.1, 0.2]] mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk_empty] mock_message.chunks = [mock_chunk_empty]
@ -264,9 +264,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings(mock_message)
# Assert # Assert
# Should not call upsert for empty chunks # Should not call upsert for empty chunk_ids
mock_qdrant_instance.upsert.assert_not_called() mock_qdrant_instance.upsert.assert_not_called()
# collection_exists should NOT be called since we return early for empty chunks # collection_exists should NOT be called since we return early for empty chunk_ids
mock_qdrant_instance.collection_exists.assert_not_called() mock_qdrant_instance.collection_exists.assert_not_called()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@ -298,7 +298,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'new_collection' mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk' mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -350,7 +350,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'error_collection' mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk' mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2]] mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -388,7 +388,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message1.metadata.collection = 'cache_collection' mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first chunk' mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]] mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_message1.chunks = [mock_chunk1] mock_message1.chunks = [mock_chunk1]
@ -406,7 +406,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message2.metadata.collection = 'cache_collection' mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock() mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second chunk' mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3) mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_message2.chunks = [mock_chunk2] mock_message2.chunks = [mock_chunk2]
@ -452,7 +452,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'dim_collection' mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'dimension test chunk' mock_chunk.chunk_id = 'doc/dim-test'
mock_chunk.vectors = [ mock_chunk.vectors = [
[0.1, 0.2], # 2 dimensions [0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions [0.3, 0.4, 0.5] # 3 dimensions
@ -498,8 +498,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client): async def test_chunk_id_with_special_characters(self, mock_uuid, mock_qdrant_client):
"""Test proper UTF-8 decoding of chunk text""" """Test storing chunk_id with special characters (URIs)"""
# Arrange # Arrange
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_instance.collection_exists.return_value = True
@ -517,15 +517,15 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Add collection to known_collections (simulates config push) # Add collection to known_collections (simulates config push)
processor.known_collections[('utf8_user', 'utf8_collection')] = {} processor.known_collections[('uri_user', 'uri_collection')] = {}
# Create mock message with UTF-8 encoded text # Create mock message with URI-style chunk_id
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'utf8_user' mock_message.metadata.user = 'uri_user'
mock_message.metadata.collection = 'utf8_collection' mock_message.metadata.collection = 'uri_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé' mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
mock_chunk.vectors = [[0.1, 0.2]] mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -534,47 +534,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings(mock_message)
# Assert # Assert
# Verify chunk.decode was called with 'utf-8' # Verify the chunk_id was stored correctly
mock_chunk.chunk.decode.assert_called_with('utf-8')
# Verify the decoded text was stored in payload
upsert_call_args = mock_qdrant_instance.upsert.call_args upsert_call_args = mock_qdrant_instance.upsert.call_args
point = upsert_call_args[1]['points'][0] point = upsert_call_args[1]['points'][0]
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé' assert point.payload['chunk_id'] == 'https://trustgraph.ai/doc/my-document/p1/c3'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_chunk_decode_exception_handling(self, mock_qdrant_client):
"""Test handling of chunk decode exceptions"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-doc-qdrant-processor'
}
processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('decode_user', 'decode_collection')] = {}
# Create mock message with decode error
mock_message = MagicMock()
mock_message.metadata.user = 'decode_user'
mock_message.metadata.collection = 'decode_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(UnicodeDecodeError):
await processor.store_document_embeddings(mock_message)
if __name__ == '__main__': if __name__ == '__main__':