mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix tests (#666)
This commit is contained in:
parent
24bbe94136
commit
3bf8a65409
10 changed files with 510 additions and 446 deletions
|
|
@ -25,13 +25,13 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify all expected fields exist
|
# Verify all expected fields exist
|
||||||
assert hasattr(request, 'vectors')
|
assert hasattr(request, 'vectors')
|
||||||
assert hasattr(request, 'limit')
|
assert hasattr(request, 'limit')
|
||||||
assert hasattr(request, 'user')
|
assert hasattr(request, 'user')
|
||||||
assert hasattr(request, 'collection')
|
assert hasattr(request, 'collection')
|
||||||
|
|
||||||
# Verify field values
|
# Verify field values
|
||||||
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
assert request.limit == 10
|
assert request.limit == 10
|
||||||
|
|
@ -41,16 +41,16 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
def test_request_translator_to_pulsar(self):
|
def test_request_translator_to_pulsar(self):
|
||||||
"""Test request translator converts dict to Pulsar schema"""
|
"""Test request translator converts dict to Pulsar schema"""
|
||||||
translator = DocumentEmbeddingsRequestTranslator()
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"vectors": [[0.1, 0.2], [0.3, 0.4]],
|
"vectors": [[0.1, 0.2], [0.3, 0.4]],
|
||||||
"limit": 5,
|
"limit": 5,
|
||||||
"user": "custom_user",
|
"user": "custom_user",
|
||||||
"collection": "custom_collection"
|
"collection": "custom_collection"
|
||||||
}
|
}
|
||||||
|
|
||||||
result = translator.to_pulsar(data)
|
result = translator.to_pulsar(data)
|
||||||
|
|
||||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||||
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
||||||
assert result.limit == 5
|
assert result.limit == 5
|
||||||
|
|
@ -60,14 +60,14 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
def test_request_translator_to_pulsar_with_defaults(self):
|
def test_request_translator_to_pulsar_with_defaults(self):
|
||||||
"""Test request translator uses correct defaults"""
|
"""Test request translator uses correct defaults"""
|
||||||
translator = DocumentEmbeddingsRequestTranslator()
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"vectors": [[0.1, 0.2]]
|
"vectors": [[0.1, 0.2]]
|
||||||
# No limit, user, or collection provided
|
# No limit, user, or collection provided
|
||||||
}
|
}
|
||||||
|
|
||||||
result = translator.to_pulsar(data)
|
result = translator.to_pulsar(data)
|
||||||
|
|
||||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||||
assert result.vectors == [[0.1, 0.2]]
|
assert result.vectors == [[0.1, 0.2]]
|
||||||
assert result.limit == 10 # Default
|
assert result.limit == 10 # Default
|
||||||
|
|
@ -77,16 +77,16 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
def test_request_translator_from_pulsar(self):
|
def test_request_translator_from_pulsar(self):
|
||||||
"""Test request translator converts Pulsar schema to dict"""
|
"""Test request translator converts Pulsar schema to dict"""
|
||||||
translator = DocumentEmbeddingsRequestTranslator()
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
request = DocumentEmbeddingsRequest(
|
request = DocumentEmbeddingsRequest(
|
||||||
vectors=[[0.5, 0.6]],
|
vectors=[[0.5, 0.6]],
|
||||||
limit=20,
|
limit=20,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
result = translator.from_pulsar(request)
|
result = translator.from_pulsar(request)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert result["vectors"] == [[0.5, 0.6]]
|
assert result["vectors"] == [[0.5, 0.6]]
|
||||||
assert result["limit"] == 20
|
assert result["limit"] == 20
|
||||||
|
|
@ -99,19 +99,19 @@ class TestDocumentEmbeddingsResponseContract:
|
||||||
|
|
||||||
def test_response_schema_fields(self):
|
def test_response_schema_fields(self):
|
||||||
"""Test that DocumentEmbeddingsResponse has expected fields"""
|
"""Test that DocumentEmbeddingsResponse has expected fields"""
|
||||||
# Create a response with chunks
|
# Create a response with chunk_ids
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunks=["chunk1", "chunk2", "chunk3"]
|
chunk_ids=["chunk1", "chunk2", "chunk3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify all expected fields exist
|
# Verify all expected fields exist
|
||||||
assert hasattr(response, 'error')
|
assert hasattr(response, 'error')
|
||||||
assert hasattr(response, 'chunks')
|
assert hasattr(response, 'chunk_ids')
|
||||||
|
|
||||||
# Verify field values
|
# Verify field values
|
||||||
assert response.error is None
|
assert response.error is None
|
||||||
assert response.chunks == ["chunk1", "chunk2", "chunk3"]
|
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
def test_response_schema_with_error(self):
|
def test_response_schema_with_error(self):
|
||||||
"""Test response schema with error"""
|
"""Test response schema with error"""
|
||||||
|
|
@ -119,90 +119,79 @@ class TestDocumentEmbeddingsResponseContract:
|
||||||
type="query_error",
|
type="query_error",
|
||||||
message="Database connection failed"
|
message="Database connection failed"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=error,
|
error=error,
|
||||||
chunks=None
|
chunk_ids=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.error == error
|
|
||||||
assert response.chunks is None
|
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_chunks(self):
|
assert response.error == error
|
||||||
"""Test response translator converts Pulsar schema with chunks to dict"""
|
assert response.chunk_ids == []
|
||||||
|
|
||||||
|
def test_response_translator_from_pulsar_with_chunk_ids(self):
|
||||||
|
"""Test response translator converts Pulsar schema with chunk_ids to dict"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunks=["doc1", "doc2", "doc3"]
|
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "chunks" in result
|
|
||||||
assert result["chunks"] == ["doc1", "doc2", "doc3"]
|
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_bytes(self):
|
|
||||||
"""Test response translator handles byte chunks correctly"""
|
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
|
||||||
|
|
||||||
response = MagicMock()
|
|
||||||
response.chunks = [b"byte_chunk1", b"byte_chunk2"]
|
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "chunks" in result
|
|
||||||
assert result["chunks"] == ["byte_chunk1", "byte_chunk2"]
|
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_empty_chunks(self):
|
|
||||||
"""Test response translator handles empty chunks list"""
|
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
|
||||||
|
|
||||||
response = MagicMock()
|
|
||||||
response.chunks = []
|
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunks" in result
|
assert "chunk_ids" in result
|
||||||
assert result["chunks"] == []
|
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"]
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_none_chunks(self):
|
def test_response_translator_from_pulsar_with_empty_chunk_ids(self):
|
||||||
"""Test response translator handles None chunks"""
|
"""Test response translator handles empty chunk_ids list"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
response = MagicMock()
|
response = DocumentEmbeddingsResponse(
|
||||||
response.chunks = None
|
error=None,
|
||||||
|
chunk_ids=[]
|
||||||
|
)
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunks" not in result or result.get("chunks") is None
|
assert "chunk_ids" in result
|
||||||
|
assert result["chunk_ids"] == []
|
||||||
|
|
||||||
|
def test_response_translator_from_pulsar_with_none_chunk_ids(self):
|
||||||
|
"""Test response translator handles None chunk_ids"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.chunk_ids = None
|
||||||
|
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "chunk_ids" not in result or result.get("chunk_ids") is None
|
||||||
|
|
||||||
def test_response_translator_from_response_with_completion(self):
|
def test_response_translator_from_response_with_completion(self):
|
||||||
"""Test response translator with completion flag"""
|
"""Test response translator with completion flag"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunks=["chunk1", "chunk2"]
|
chunk_ids=["chunk1", "chunk2"]
|
||||||
)
|
)
|
||||||
|
|
||||||
result, is_final = translator.from_response_with_completion(response)
|
result, is_final = translator.from_response_with_completion(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunks" in result
|
assert "chunk_ids" in result
|
||||||
assert result["chunks"] == ["chunk1", "chunk2"]
|
assert result["chunk_ids"] == ["chunk1", "chunk2"]
|
||||||
assert is_final is True # Document embeddings responses are always final
|
assert is_final is True # Document embeddings responses are always final
|
||||||
|
|
||||||
def test_response_translator_to_pulsar_not_implemented(self):
|
def test_response_translator_to_pulsar_not_implemented(self):
|
||||||
"""Test that to_pulsar raises NotImplementedError for responses"""
|
"""Test that to_pulsar raises NotImplementedError for responses"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
translator.to_pulsar({"chunks": ["test"]})
|
translator.to_pulsar({"chunk_ids": ["test"]})
|
||||||
|
|
||||||
|
|
||||||
class TestDocumentEmbeddingsMessageCompatibility:
|
class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
|
|
@ -217,26 +206,26 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
"user": "test_user",
|
"user": "test_user",
|
||||||
"collection": "test_collection"
|
"collection": "test_collection"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Convert to Pulsar request
|
# Convert to Pulsar request
|
||||||
req_translator = DocumentEmbeddingsRequestTranslator()
|
req_translator = DocumentEmbeddingsRequestTranslator()
|
||||||
pulsar_request = req_translator.to_pulsar(request_data)
|
pulsar_request = req_translator.to_pulsar(request_data)
|
||||||
|
|
||||||
# Simulate service processing and creating response
|
# Simulate service processing and creating response
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunks=["relevant chunk 1", "relevant chunk 2"]
|
chunk_ids=["doc1/c1", "doc2/c2"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert response back to dict
|
# Convert response back to dict
|
||||||
resp_translator = DocumentEmbeddingsResponseTranslator()
|
resp_translator = DocumentEmbeddingsResponseTranslator()
|
||||||
response_data = resp_translator.from_pulsar(response)
|
response_data = resp_translator.from_pulsar(response)
|
||||||
|
|
||||||
# Verify data integrity
|
# Verify data integrity
|
||||||
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
||||||
assert isinstance(response_data, dict)
|
assert isinstance(response_data, dict)
|
||||||
assert "chunks" in response_data
|
assert "chunk_ids" in response_data
|
||||||
assert len(response_data["chunks"]) == 2
|
assert len(response_data["chunk_ids"]) == 2
|
||||||
|
|
||||||
def test_error_response_flow(self):
|
def test_error_response_flow(self):
|
||||||
"""Test error response flow"""
|
"""Test error response flow"""
|
||||||
|
|
@ -245,17 +234,18 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
type="vector_db_error",
|
type="vector_db_error",
|
||||||
message="Collection not found"
|
message="Collection not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=error,
|
error=error,
|
||||||
chunks=None
|
chunk_ids=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert response to dict
|
# Convert response to dict
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
response_data = translator.from_pulsar(response)
|
response_data = translator.from_pulsar(response)
|
||||||
|
|
||||||
# Verify error handling
|
# Verify error handling
|
||||||
assert isinstance(response_data, dict)
|
assert isinstance(response_data, dict)
|
||||||
# The translator doesn't include error in the dict, only chunks
|
# The translator doesn't include error in the dict, only chunk_ids
|
||||||
assert "chunks" not in response_data or response_data.get("chunks") is None
|
assert "chunk_ids" in response_data
|
||||||
|
assert response_data["chunk_ids"] == []
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,14 @@ from unittest.mock import AsyncMock, MagicMock
|
||||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||||
|
|
||||||
|
|
||||||
|
# Sample chunk content for testing - maps chunk_id to content
|
||||||
|
CHUNK_CONTENT = {
|
||||||
|
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||||
|
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||||
|
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
class TestDocumentRagIntegration:
|
class TestDocumentRagIntegration:
|
||||||
"""Integration tests for DocumentRAG system coordination"""
|
"""Integration tests for DocumentRAG system coordination"""
|
||||||
|
|
@ -27,15 +35,19 @@ class TestDocumentRagIntegration:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_doc_embeddings_client(self):
|
def mock_doc_embeddings_client(self):
|
||||||
"""Mock document embeddings client that returns realistic document chunks"""
|
"""Mock document embeddings client that returns chunk IDs"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = [
|
# Now returns chunk_ids instead of actual content
|
||||||
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
||||||
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
|
||||||
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
|
|
||||||
]
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_fetch_chunk(self):
|
||||||
|
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||||
|
async def fetch(chunk_id, user):
|
||||||
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
|
return fetch
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_prompt_client(self):
|
def mock_prompt_client(self):
|
||||||
"""Mock prompt client that generates realistic responses"""
|
"""Mock prompt client that generates realistic responses"""
|
||||||
|
|
@ -48,17 +60,19 @@ class TestDocumentRagIntegration:
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
|
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||||
|
mock_prompt_client, mock_fetch_chunk):
|
||||||
"""Create DocumentRag instance with mocked dependencies"""
|
"""Create DocumentRag instance with mocked dependencies"""
|
||||||
return DocumentRag(
|
return DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=True
|
verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
||||||
mock_doc_embeddings_client, mock_prompt_client):
|
mock_doc_embeddings_client, mock_prompt_client):
|
||||||
"""Test complete DocumentRAG pipeline from query to response"""
|
"""Test complete DocumentRAG pipeline from query to response"""
|
||||||
# Arrange
|
# Arrange
|
||||||
|
|
@ -77,14 +91,15 @@ class TestDocumentRagIntegration:
|
||||||
|
|
||||||
# Assert - Verify service coordination
|
# Assert - Verify service coordination
|
||||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
mock_embeddings_client.embed.assert_called_once_with(query)
|
||||||
|
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||||
limit=doc_limit,
|
limit=doc_limit,
|
||||||
user=user,
|
user=user,
|
||||||
collection=collection
|
collection=collection
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Documents are fetched from librarian using chunk_ids
|
||||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||||
query=query,
|
query=query,
|
||||||
documents=[
|
documents=[
|
||||||
|
|
@ -101,17 +116,19 @@ class TestDocumentRagIntegration:
|
||||||
assert "artificial intelligence" in result.lower()
|
assert "artificial intelligence" in result.lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
||||||
mock_doc_embeddings_client, mock_prompt_client):
|
mock_doc_embeddings_client, mock_prompt_client,
|
||||||
|
mock_fetch_chunk):
|
||||||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||||
|
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -125,92 +142,98 @@ class TestDocumentRagIntegration:
|
||||||
query="very obscure query",
|
query="very obscure query",
|
||||||
documents=[]
|
documents=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == "I couldn't find any relevant documents for your query."
|
assert result == "I couldn't find any relevant documents for your query."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||||
mock_doc_embeddings_client, mock_prompt_client):
|
mock_doc_embeddings_client, mock_prompt_client,
|
||||||
|
mock_fetch_chunk):
|
||||||
"""Test DocumentRAG error handling when embeddings service fails"""
|
"""Test DocumentRAG error handling when embeddings service fails"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
|
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
|
||||||
|
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
await document_rag.query("test query")
|
await document_rag.query("test query")
|
||||||
|
|
||||||
assert "Embeddings service unavailable" in str(exc_info.value)
|
assert "Embeddings service unavailable" in str(exc_info.value)
|
||||||
mock_embeddings_client.embed.assert_called_once()
|
mock_embeddings_client.embed.assert_called_once()
|
||||||
mock_doc_embeddings_client.query.assert_not_called()
|
mock_doc_embeddings_client.query.assert_not_called()
|
||||||
mock_prompt_client.document_prompt.assert_not_called()
|
mock_prompt_client.document_prompt.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
||||||
mock_doc_embeddings_client, mock_prompt_client):
|
mock_doc_embeddings_client, mock_prompt_client,
|
||||||
|
mock_fetch_chunk):
|
||||||
"""Test DocumentRAG error handling when document service fails"""
|
"""Test DocumentRAG error handling when document service fails"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
|
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
|
||||||
|
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
await document_rag.query("test query")
|
await document_rag.query("test query")
|
||||||
|
|
||||||
assert "Document service connection failed" in str(exc_info.value)
|
assert "Document service connection failed" in str(exc_info.value)
|
||||||
mock_embeddings_client.embed.assert_called_once()
|
mock_embeddings_client.embed.assert_called_once()
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
mock_prompt_client.document_prompt.assert_not_called()
|
mock_prompt_client.document_prompt.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
||||||
mock_doc_embeddings_client, mock_prompt_client):
|
mock_doc_embeddings_client, mock_prompt_client,
|
||||||
|
mock_fetch_chunk):
|
||||||
"""Test DocumentRAG error handling when prompt service fails"""
|
"""Test DocumentRAG error handling when prompt service fails"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
|
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
|
||||||
|
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
await document_rag.query("test query")
|
await document_rag.query("test query")
|
||||||
|
|
||||||
assert "LLM service rate limited" in str(exc_info.value)
|
assert "LLM service rate limited" in str(exc_info.value)
|
||||||
mock_embeddings_client.embed.assert_called_once()
|
mock_embeddings_client.embed.assert_called_once()
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
mock_prompt_client.document_prompt.assert_called_once()
|
mock_prompt_client.document_prompt.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||||
mock_doc_embeddings_client):
|
mock_doc_embeddings_client):
|
||||||
"""Test DocumentRAG with various document limit configurations"""
|
"""Test DocumentRAG with various document limit configurations"""
|
||||||
# Test different document limits
|
# Test different document limits
|
||||||
test_cases = [1, 5, 10, 25, 50]
|
test_cases = [1, 5, 10, 25, 50]
|
||||||
|
|
||||||
for limit in test_cases:
|
for limit in test_cases:
|
||||||
# Reset mock call history
|
# Reset mock call history
|
||||||
mock_doc_embeddings_client.reset_mock()
|
mock_doc_embeddings_client.reset_mock()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
|
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
call_args = mock_doc_embeddings_client.query.call_args
|
call_args = mock_doc_embeddings_client.query.call_args
|
||||||
|
|
@ -230,14 +253,14 @@ class TestDocumentRagIntegration:
|
||||||
for user, collection in test_scenarios:
|
for user, collection in test_scenarios:
|
||||||
# Reset mock call history
|
# Reset mock call history
|
||||||
mock_doc_embeddings_client.reset_mock()
|
mock_doc_embeddings_client.reset_mock()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await document_rag.query(
|
await document_rag.query(
|
||||||
f"query from {user} in {collection}",
|
f"query from {user} in {collection}",
|
||||||
user=user,
|
user=user,
|
||||||
collection=collection
|
collection=collection
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
call_args = mock_doc_embeddings_client.query.call_args
|
call_args = mock_doc_embeddings_client.query.call_args
|
||||||
|
|
@ -245,19 +268,21 @@ class TestDocumentRagIntegration:
|
||||||
assert call_args.kwargs['collection'] == collection
|
assert call_args.kwargs['collection'] == collection
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||||
mock_doc_embeddings_client, mock_prompt_client,
|
mock_doc_embeddings_client, mock_prompt_client,
|
||||||
|
mock_fetch_chunk,
|
||||||
caplog):
|
caplog):
|
||||||
"""Test DocumentRAG verbose logging functionality"""
|
"""Test DocumentRAG verbose logging functionality"""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
# Arrange - Configure logging to capture debug messages
|
# Arrange - Configure logging to capture debug messages
|
||||||
caplog.set_level(logging.DEBUG)
|
caplog.set_level(logging.DEBUG)
|
||||||
|
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=True
|
verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -269,25 +294,25 @@ class TestDocumentRagIntegration:
|
||||||
assert "DocumentRag initialized" in log_messages
|
assert "DocumentRag initialized" in log_messages
|
||||||
assert "Constructing prompt..." in log_messages
|
assert "Constructing prompt..." in log_messages
|
||||||
assert "Computing embeddings..." in log_messages
|
assert "Computing embeddings..." in log_messages
|
||||||
assert "Getting documents..." in log_messages
|
assert "chunk_ids" in log_messages.lower()
|
||||||
assert "Invoking LLM..." in log_messages
|
assert "Invoking LLM..." in log_messages
|
||||||
assert "Query processing complete" in log_messages
|
assert "Query processing complete" in log_messages
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||||
mock_doc_embeddings_client):
|
mock_doc_embeddings_client):
|
||||||
"""Test DocumentRAG performance with large document retrieval"""
|
"""Test DocumentRAG performance with large document retrieval"""
|
||||||
# Arrange - Mock large document set (100 documents)
|
# Arrange - Mock large chunk_id set (100 chunks)
|
||||||
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
|
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
|
||||||
mock_doc_embeddings_client.query.return_value = large_doc_set
|
mock_doc_embeddings_client.query.return_value = large_chunk_ids
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
result = await document_rag.query("performance test query", doc_limit=100)
|
result = await document_rag.query("performance test query", doc_limit=100)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
execution_time = end_time - start_time
|
execution_time = end_time - start_time
|
||||||
|
|
||||||
|
|
@ -309,4 +334,4 @@ class TestDocumentRagIntegration:
|
||||||
call_args = mock_doc_embeddings_client.query.call_args
|
call_args = mock_doc_embeddings_client.query.call_args
|
||||||
assert call_args.kwargs['user'] == "trustgraph"
|
assert call_args.kwargs['user'] == "trustgraph"
|
||||||
assert call_args.kwargs['collection'] == "default"
|
assert call_args.kwargs['collection'] == "default"
|
||||||
assert call_args.kwargs['limit'] == 20
|
assert call_args.kwargs['limit'] == 20
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,14 @@ from tests.utils.streaming_assertions import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Sample chunk content for testing - maps chunk_id to content
|
||||||
|
CHUNK_CONTENT = {
|
||||||
|
"doc/c1": "Machine learning is a subset of AI.",
|
||||||
|
"doc/c2": "Deep learning uses neural networks.",
|
||||||
|
"doc/c3": "Supervised learning needs labeled data.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
class TestDocumentRagStreaming:
|
class TestDocumentRagStreaming:
|
||||||
"""Integration tests for DocumentRAG streaming"""
|
"""Integration tests for DocumentRAG streaming"""
|
||||||
|
|
@ -27,15 +35,19 @@ class TestDocumentRagStreaming:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_doc_embeddings_client(self):
|
def mock_doc_embeddings_client(self):
|
||||||
"""Mock document embeddings client"""
|
"""Mock document embeddings client that returns chunk IDs"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = [
|
# Now returns chunk_ids instead of actual content
|
||||||
"Machine learning is a subset of AI.",
|
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
||||||
"Deep learning uses neural networks.",
|
|
||||||
"Supervised learning needs labeled data."
|
|
||||||
]
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_fetch_chunk(self):
|
||||||
|
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||||
|
async def fetch(chunk_id, user):
|
||||||
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
|
return fetch
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
|
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
|
||||||
"""Mock prompt client with streaming support"""
|
"""Mock prompt client with streaming support"""
|
||||||
|
|
@ -66,12 +78,13 @@ class TestDocumentRagStreaming:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
|
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||||
mock_streaming_prompt_client):
|
mock_streaming_prompt_client, mock_fetch_chunk):
|
||||||
"""Create DocumentRag instance with streaming support"""
|
"""Create DocumentRag instance with streaming support"""
|
||||||
return DocumentRag(
|
return DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_streaming_prompt_client,
|
prompt_client=mock_streaming_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=True
|
verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -190,7 +203,7 @@ class TestDocumentRagStreaming:
|
||||||
mock_doc_embeddings_client):
|
mock_doc_embeddings_client):
|
||||||
"""Test streaming with no documents found"""
|
"""Test streaming with no documents found"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_doc_embeddings_client.query.return_value = [] # No documents
|
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids
|
||||||
callback = AsyncMock()
|
callback = AsyncMock()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|
|
||||||
|
|
@ -202,11 +202,18 @@ class TestDocumentRagStreamingProtocol:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_doc_embeddings_client(self):
|
def mock_doc_embeddings_client(self):
|
||||||
"""Mock document embeddings client"""
|
"""Mock document embeddings client that returns chunk IDs"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = ["doc1", "doc2"]
|
client.query.return_value = ["doc/c1", "doc/c2"]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_fetch_chunk(self):
|
||||||
|
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||||
|
async def fetch(chunk_id, user):
|
||||||
|
return f"Content for {chunk_id}"
|
||||||
|
return fetch
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_streaming_prompt_client(self):
|
def mock_streaming_prompt_client(self):
|
||||||
"""Mock prompt client with streaming support"""
|
"""Mock prompt client with streaming support"""
|
||||||
|
|
@ -227,12 +234,13 @@ class TestDocumentRagStreamingProtocol:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||||
mock_streaming_prompt_client):
|
mock_streaming_prompt_client, mock_fetch_chunk):
|
||||||
"""Create DocumentRag instance with mocked dependencies"""
|
"""Create DocumentRag instance with mocked dependencies"""
|
||||||
return DocumentRag(
|
return DocumentRag(
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
prompt_client=mock_streaming_prompt_client,
|
prompt_client=mock_streaming_prompt_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
# Mock the request method
|
# Mock the request method
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
@ -75,7 +75,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunks = []
|
mock_response.chunk_ids = []
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
|
@ -93,7 +93,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunks = ["test_chunk"]
|
mock_response.chunk_ids = ["test_chunk"]
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
|
@ -115,7 +115,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunks = ["chunk1"]
|
mock_response.chunk_ids = ["chunk1"]
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
|
@ -136,7 +136,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunks = ["test_chunk"]
|
mock_response.chunk_ids = ["test_chunk"]
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,9 +77,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Mock search results
|
# Mock search results
|
||||||
mock_results = [
|
mock_results = [
|
||||||
{"entity": {"doc": "First document chunk"}},
|
{"entity": {"chunk_id": "First document chunk"}},
|
||||||
{"entity": {"doc": "Second document chunk"}},
|
{"entity": {"chunk_id": "Second document chunk"}},
|
||||||
{"entity": {"doc": "Third document chunk"}},
|
{"entity": {"chunk_id": "Third document chunk"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.return_value = mock_results
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
|
|
@ -108,11 +108,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Mock search results - different results for each vector
|
# Mock search results - different results for each vector
|
||||||
mock_results_1 = [
|
mock_results_1 = [
|
||||||
{"entity": {"doc": "Document from first vector"}},
|
{"entity": {"chunk_id": "Document from first vector"}},
|
||||||
{"entity": {"doc": "Another doc from first vector"}},
|
{"entity": {"chunk_id": "Another doc from first vector"}},
|
||||||
]
|
]
|
||||||
mock_results_2 = [
|
mock_results_2 = [
|
||||||
{"entity": {"doc": "Document from second vector"}},
|
{"entity": {"chunk_id": "Document from second vector"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||||
|
|
||||||
|
|
@ -147,10 +147,10 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Mock search results - more results than limit
|
# Mock search results - more results than limit
|
||||||
mock_results = [
|
mock_results = [
|
||||||
{"entity": {"doc": "Document 1"}},
|
{"entity": {"chunk_id": "Document 1"}},
|
||||||
{"entity": {"doc": "Document 2"}},
|
{"entity": {"chunk_id": "Document 2"}},
|
||||||
{"entity": {"doc": "Document 3"}},
|
{"entity": {"chunk_id": "Document 3"}},
|
||||||
{"entity": {"doc": "Document 4"}},
|
{"entity": {"chunk_id": "Document 4"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.return_value = mock_results
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
|
|
@ -217,9 +217,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Mock search results with Unicode content
|
# Mock search results with Unicode content
|
||||||
mock_results = [
|
mock_results = [
|
||||||
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
|
{"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}},
|
||||||
{"entity": {"doc": "Regular ASCII document"}},
|
{"entity": {"chunk_id": "Regular ASCII document"}},
|
||||||
{"entity": {"doc": "Document with émojis: 😀🎉"}},
|
{"entity": {"chunk_id": "Document with émojis: 😀🎉"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.return_value = mock_results
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
|
|
@ -244,8 +244,8 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
# Mock search results with large content
|
# Mock search results with large content
|
||||||
large_doc = "A" * 10000 # 10KB of content
|
large_doc = "A" * 10000 # 10KB of content
|
||||||
mock_results = [
|
mock_results = [
|
||||||
{"entity": {"doc": large_doc}},
|
{"entity": {"chunk_id": large_doc}},
|
||||||
{"entity": {"doc": "Small document"}},
|
{"entity": {"chunk_id": "Small document"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.return_value = mock_results
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
|
|
@ -268,9 +268,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Mock search results with special characters
|
# Mock search results with special characters
|
||||||
mock_results = [
|
mock_results = [
|
||||||
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
|
{"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}},
|
||||||
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
|
{"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}},
|
||||||
{"entity": {"doc": "Document with special chars: @#$%^&*()"}},
|
{"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.return_value = mock_results
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
|
|
@ -350,9 +350,9 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results for each vector
|
# Mock search results for each vector
|
||||||
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
|
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
|
||||||
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
|
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
|
||||||
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
|
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
@ -378,12 +378,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Mock search results with duplicates across vectors
|
# Mock search results with duplicates across vectors
|
||||||
mock_results_1 = [
|
mock_results_1 = [
|
||||||
{"entity": {"doc": "Document A"}},
|
{"entity": {"chunk_id": "Document A"}},
|
||||||
{"entity": {"doc": "Document B"}},
|
{"entity": {"chunk_id": "Document B"}},
|
||||||
]
|
]
|
||||||
mock_results_2 = [
|
mock_results_2 = [
|
||||||
{"entity": {"doc": "Document B"}}, # Duplicate
|
{"entity": {"chunk_id": "Document B"}}, # Duplicate
|
||||||
{"entity": {"doc": "Document C"}},
|
{"entity": {"chunk_id": "Document C"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||||
|
|
||||||
|
|
@ -458,5 +458,5 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
mock_launch.assert_called_once_with(
|
mock_launch.assert_called_once_with(
|
||||||
default_ident,
|
default_ident,
|
||||||
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
|
"\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n"
|
||||||
)
|
)
|
||||||
|
|
@ -77,9 +77,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock query response
|
# Mock query response
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'doc': 'first document chunk'}
|
mock_point1.payload = {'chunk_id': 'first document chunk'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'doc': 'second document chunk'}
|
mock_point2.payload = {'chunk_id': 'second document chunk'}
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.points = [mock_point1, mock_point2]
|
mock_response.points = [mock_point1, mock_point2]
|
||||||
|
|
@ -132,11 +132,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock query responses for different vectors
|
# Mock query responses for different vectors
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'doc': 'document from vector 1'}
|
mock_point1.payload = {'chunk_id': 'document from vector 1'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'doc': 'document from vector 2'}
|
mock_point2.payload = {'chunk_id': 'document from vector 2'}
|
||||||
mock_point3 = MagicMock()
|
mock_point3 = MagicMock()
|
||||||
mock_point3.payload = {'doc': 'another document from vector 2'}
|
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
|
||||||
|
|
||||||
mock_response1 = MagicMock()
|
mock_response1 = MagicMock()
|
||||||
mock_response1.points = [mock_point1]
|
mock_response1.points = [mock_point1]
|
||||||
|
|
@ -192,7 +192,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
mock_points = []
|
mock_points = []
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
mock_point = MagicMock()
|
mock_point = MagicMock()
|
||||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
mock_point.payload = {'chunk_id': f'document chunk {i}'}
|
||||||
mock_points.append(mock_point)
|
mock_points.append(mock_point)
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
|
@ -270,9 +270,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock query responses
|
# Mock query responses
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'doc': 'document from 2D vector'}
|
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'doc': 'document from 3D vector'}
|
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
|
||||||
|
|
||||||
mock_response1 = MagicMock()
|
mock_response1 = MagicMock()
|
||||||
mock_response1.points = [mock_point1]
|
mock_response1.points = [mock_point1]
|
||||||
|
|
@ -326,9 +326,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock query response with UTF-8 content
|
# Mock query response with UTF-8 content
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'}
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.points = [mock_point1, mock_point2]
|
mock_response.points = [mock_point1, mock_point2]
|
||||||
|
|
@ -399,7 +399,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock query response
|
# Mock query response
|
||||||
mock_point = MagicMock()
|
mock_point = MagicMock()
|
||||||
mock_point.payload = {'doc': 'document chunk'}
|
mock_point.payload = {'chunk_id': 'document chunk'}
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.points = [mock_point]
|
mock_response.points = [mock_point]
|
||||||
mock_qdrant_instance.query_points.return_value = mock_response
|
mock_qdrant_instance.query_points.return_value = mock_response
|
||||||
|
|
@ -442,9 +442,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Mock query response with fewer results than limit
|
# Mock query response with fewer results than limit
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'doc': 'document 1'}
|
mock_point1.payload = {'chunk_id': 'document 1'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'doc': 'document 2'}
|
mock_point2.payload = {'chunk_id': 'document 2'}
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.points = [mock_point1, mock_point2]
|
mock_response.points = [mock_point1, mock_point2]
|
||||||
|
|
@ -487,11 +487,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
# Mock query response with missing 'doc' key
|
# Mock query response with missing 'chunk_id' key
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'doc': 'valid document'}
|
mock_point1.payload = {'chunk_id': 'valid document'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {} # Missing 'doc' key
|
mock_point2.payload = {} # Missing 'chunk_id' key
|
||||||
mock_point3 = MagicMock()
|
mock_point3 = MagicMock()
|
||||||
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
|
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
|
||||||
|
|
||||||
|
|
@ -514,7 +514,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
mock_message.collection = 'payload_collection'
|
mock_message.collection = 'payload_collection'
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
# This should raise a KeyError when trying to access payload['doc']
|
# This should raise a KeyError when trying to access payload['chunk_id']
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
await processor.query_document_embeddings(mock_message)
|
await processor.query_document_embeddings(mock_message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,48 +8,75 @@ from unittest.mock import MagicMock, AsyncMock
|
||||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||||
|
|
||||||
|
|
||||||
|
# Sample chunk content mapping for tests
|
||||||
|
CHUNK_CONTENT = {
|
||||||
|
"doc/c1": "Document 1 content",
|
||||||
|
"doc/c2": "Document 2 content",
|
||||||
|
"doc/c3": "Relevant document content",
|
||||||
|
"doc/c4": "Another document",
|
||||||
|
"doc/c5": "Default doc",
|
||||||
|
"doc/c6": "Verbose test doc",
|
||||||
|
"doc/c7": "Verbose doc content",
|
||||||
|
"doc/ml1": "Machine learning is a subset of artificial intelligence...",
|
||||||
|
"doc/ml2": "ML algorithms learn patterns from data to make predictions...",
|
||||||
|
"doc/ml3": "Common ML techniques include supervised and unsupervised learning...",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_fetch_chunk():
|
||||||
|
"""Create a mock fetch_chunk function"""
|
||||||
|
async def fetch(chunk_id, user):
|
||||||
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
|
return fetch
|
||||||
|
|
||||||
|
|
||||||
class TestDocumentRag:
|
class TestDocumentRag:
|
||||||
"""Test cases for DocumentRag class"""
|
"""Test cases for DocumentRag class"""
|
||||||
|
|
||||||
def test_document_rag_initialization_with_defaults(self):
|
def test_document_rag_initialization_with_defaults(self, mock_fetch_chunk):
|
||||||
"""Test DocumentRag initialization with default verbose setting"""
|
"""Test DocumentRag initialization with default verbose setting"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = MagicMock()
|
mock_prompt_client = MagicMock()
|
||||||
mock_embeddings_client = MagicMock()
|
mock_embeddings_client = MagicMock()
|
||||||
mock_doc_embeddings_client = MagicMock()
|
mock_doc_embeddings_client = MagicMock()
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify initialization
|
# Verify initialization
|
||||||
assert document_rag.prompt_client == mock_prompt_client
|
assert document_rag.prompt_client == mock_prompt_client
|
||||||
assert document_rag.embeddings_client == mock_embeddings_client
|
assert document_rag.embeddings_client == mock_embeddings_client
|
||||||
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
||||||
|
assert document_rag.fetch_chunk == mock_fetch_chunk
|
||||||
assert document_rag.verbose is False # Default value
|
assert document_rag.verbose is False # Default value
|
||||||
|
|
||||||
def test_document_rag_initialization_with_verbose(self):
|
def test_document_rag_initialization_with_verbose(self, mock_fetch_chunk):
|
||||||
"""Test DocumentRag initialization with verbose enabled"""
|
"""Test DocumentRag initialization with verbose enabled"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = MagicMock()
|
mock_prompt_client = MagicMock()
|
||||||
mock_embeddings_client = MagicMock()
|
mock_embeddings_client = MagicMock()
|
||||||
mock_doc_embeddings_client = MagicMock()
|
mock_doc_embeddings_client = MagicMock()
|
||||||
|
|
||||||
# Initialize DocumentRag with verbose=True
|
# Initialize DocumentRag with verbose=True
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=True
|
verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify initialization
|
# Verify initialization
|
||||||
assert document_rag.prompt_client == mock_prompt_client
|
assert document_rag.prompt_client == mock_prompt_client
|
||||||
assert document_rag.embeddings_client == mock_embeddings_client
|
assert document_rag.embeddings_client == mock_embeddings_client
|
||||||
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
||||||
|
assert document_rag.fetch_chunk == mock_fetch_chunk
|
||||||
assert document_rag.verbose is True
|
assert document_rag.verbose is True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -60,7 +87,7 @@ class TestQuery:
|
||||||
"""Test Query initialization with default parameters"""
|
"""Test Query initialization with default parameters"""
|
||||||
# Create mock DocumentRag
|
# Create mock DocumentRag
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
|
|
||||||
# Initialize Query with defaults
|
# Initialize Query with defaults
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -68,7 +95,7 @@ class TestQuery:
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify initialization
|
# Verify initialization
|
||||||
assert query.rag == mock_rag
|
assert query.rag == mock_rag
|
||||||
assert query.user == "test_user"
|
assert query.user == "test_user"
|
||||||
|
|
@ -80,7 +107,7 @@ class TestQuery:
|
||||||
"""Test Query initialization with custom doc_limit"""
|
"""Test Query initialization with custom doc_limit"""
|
||||||
# Create mock DocumentRag
|
# Create mock DocumentRag
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
|
|
||||||
# Initialize Query with custom doc_limit
|
# Initialize Query with custom doc_limit
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -89,7 +116,7 @@ class TestQuery:
|
||||||
verbose=True,
|
verbose=True,
|
||||||
doc_limit=50
|
doc_limit=50
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify initialization
|
# Verify initialization
|
||||||
assert query.rag == mock_rag
|
assert query.rag == mock_rag
|
||||||
assert query.user == "custom_user"
|
assert query.user == "custom_user"
|
||||||
|
|
@ -104,11 +131,11 @@ class TestQuery:
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
|
|
||||||
# Mock the embed method to return test vectors
|
# Mock the embed method to return test vectors
|
||||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
mock_embeddings_client.embed.return_value = expected_vectors
|
mock_embeddings_client.embed.return_value = expected_vectors
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -116,14 +143,14 @@ class TestQuery:
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_vector
|
# Call get_vector
|
||||||
test_query = "What documents are relevant?"
|
test_query = "What documents are relevant?"
|
||||||
result = await query.get_vector(test_query)
|
result = await query.get_vector(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called correctly
|
# Verify embeddings client was called correctly
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||||
|
|
||||||
# Verify result matches expected vectors
|
# Verify result matches expected vectors
|
||||||
assert result == expected_vectors
|
assert result == expected_vectors
|
||||||
|
|
||||||
|
|
@ -136,15 +163,20 @@ class TestQuery:
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
|
# Mock fetch_chunk function
|
||||||
|
async def mock_fetch(chunk_id, user):
|
||||||
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
# Mock the embedding and document query responses
|
# Mock the embedding and document query responses
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = test_vectors
|
mock_embeddings_client.embed.return_value = test_vectors
|
||||||
|
|
||||||
# Mock document results
|
# Mock document embeddings returns chunk_ids
|
||||||
test_docs = ["Document 1 content", "Document 2 content"]
|
test_chunk_ids = ["doc/c1", "doc/c2"]
|
||||||
mock_doc_embeddings_client.query.return_value = test_docs
|
mock_doc_embeddings_client.query.return_value = test_chunk_ids
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -153,14 +185,14 @@ class TestQuery:
|
||||||
verbose=False,
|
verbose=False,
|
||||||
doc_limit=15
|
doc_limit=15
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_docs
|
# Call get_docs
|
||||||
test_query = "Find relevant documents"
|
test_query = "Find relevant documents"
|
||||||
result = await query.get_docs(test_query)
|
result = await query.get_docs(test_query)
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called
|
||||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||||
|
|
||||||
# Verify doc embeddings client was called correctly
|
# Verify doc embeddings client was called correctly
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
test_vectors,
|
test_vectors,
|
||||||
|
|
@ -168,35 +200,37 @@ class TestQuery:
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result is list of documents
|
# Verify result is list of fetched document content
|
||||||
assert result == test_docs
|
assert "Document 1 content" in result
|
||||||
|
assert "Document 2 content" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_query_method(self):
|
async def test_document_rag_query_method(self, mock_fetch_chunk):
|
||||||
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
|
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = AsyncMock()
|
mock_prompt_client = AsyncMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock embeddings and document responses
|
# Mock embeddings and document embeddings responses
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
test_docs = ["Relevant document content", "Another document"]
|
test_chunk_ids = ["doc/c3", "doc/c4"]
|
||||||
expected_response = "This is the document RAG response"
|
expected_response = "This is the document RAG response"
|
||||||
|
|
||||||
mock_embeddings_client.embed.return_value = test_vectors
|
mock_embeddings_client.embed.return_value = test_vectors
|
||||||
mock_doc_embeddings_client.query.return_value = test_docs
|
mock_doc_embeddings_client.query.return_value = test_chunk_ids
|
||||||
mock_prompt_client.document_prompt.return_value = expected_response
|
mock_prompt_client.document_prompt.return_value = expected_response
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call DocumentRag.query
|
# Call DocumentRag.query
|
||||||
result = await document_rag.query(
|
result = await document_rag.query(
|
||||||
query="test query",
|
query="test query",
|
||||||
|
|
@ -204,10 +238,10 @@ class TestQuery:
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
doc_limit=10
|
doc_limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called
|
||||||
mock_embeddings_client.embed.assert_called_once_with("test query")
|
mock_embeddings_client.embed.assert_called_once_with("test query")
|
||||||
|
|
||||||
# Verify doc embeddings client was called
|
# Verify doc embeddings client was called
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
test_vectors,
|
test_vectors,
|
||||||
|
|
@ -215,39 +249,43 @@ class TestQuery:
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify prompt client was called with documents and query
|
# Verify prompt client was called with fetched documents and query
|
||||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
mock_prompt_client.document_prompt.assert_called_once()
|
||||||
query="test query",
|
call_args = mock_prompt_client.document_prompt.call_args
|
||||||
documents=test_docs
|
assert call_args.kwargs["query"] == "test query"
|
||||||
)
|
# Documents should be fetched content, not chunk_ids
|
||||||
|
docs = call_args.kwargs["documents"]
|
||||||
|
assert "Relevant document content" in docs
|
||||||
|
assert "Another document" in docs
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result == expected_response
|
assert result == expected_response
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_query_with_defaults(self):
|
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||||
"""Test DocumentRag.query method with default parameters"""
|
"""Test DocumentRag.query method with default parameters"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = AsyncMock()
|
mock_prompt_client = AsyncMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock responses
|
# Mock responses
|
||||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["Default doc"]
|
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
|
||||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call DocumentRag.query with minimal parameters
|
# Call DocumentRag.query with minimal parameters
|
||||||
result = await document_rag.query("simple query")
|
result = await document_rag.query("simple query")
|
||||||
|
|
||||||
# Verify default parameters were used
|
# Verify default parameters were used
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
[[0.1, 0.2]],
|
[[0.1, 0.2]],
|
||||||
|
|
@ -255,7 +293,7 @@ class TestQuery:
|
||||||
user="trustgraph", # Default user
|
user="trustgraph", # Default user
|
||||||
collection="default" # Default collection
|
collection="default" # Default collection
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == "Default response"
|
assert result == "Default response"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -267,11 +305,16 @@ class TestQuery:
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
|
# Mock fetch_chunk
|
||||||
|
async def mock_fetch(chunk_id, user):
|
||||||
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
# Mock responses
|
# Mock responses
|
||||||
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
|
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["Verbose test doc"]
|
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
|
||||||
|
|
||||||
# Initialize Query with verbose=True
|
# Initialize Query with verbose=True
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -280,49 +323,51 @@ class TestQuery:
|
||||||
verbose=True,
|
verbose=True,
|
||||||
doc_limit=5
|
doc_limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_docs
|
# Call get_docs
|
||||||
result = await query.get_docs("verbose test")
|
result = await query.get_docs("verbose test")
|
||||||
|
|
||||||
# Verify calls were made
|
# Verify calls were made
|
||||||
mock_embeddings_client.embed.assert_called_once_with("verbose test")
|
mock_embeddings_client.embed.assert_called_once_with("verbose test")
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
|
|
||||||
# Verify result
|
# Verify result contains fetched content
|
||||||
assert result == ["Verbose test doc"]
|
assert "Verbose test doc" in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_query_with_verbose(self):
|
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
|
||||||
"""Test DocumentRag.query method with verbose logging enabled"""
|
"""Test DocumentRag.query method with verbose logging enabled"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = AsyncMock()
|
mock_prompt_client = AsyncMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock responses
|
# Mock responses
|
||||||
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
|
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["Verbose doc content"]
|
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
|
||||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||||
|
|
||||||
# Initialize DocumentRag with verbose=True
|
# Initialize DocumentRag with verbose=True
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=True
|
verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call DocumentRag.query
|
# Call DocumentRag.query
|
||||||
result = await document_rag.query("verbose query test")
|
result = await document_rag.query("verbose query test")
|
||||||
|
|
||||||
# Verify all clients were called
|
# Verify all clients were called
|
||||||
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
|
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
|
||||||
query="verbose query test",
|
# Verify prompt client was called with fetched content
|
||||||
documents=["Verbose doc content"]
|
call_args = mock_prompt_client.document_prompt.call_args
|
||||||
)
|
assert call_args.kwargs["query"] == "verbose query test"
|
||||||
|
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||||
|
|
||||||
assert result == "Verbose RAG response"
|
assert result == "Verbose RAG response"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -334,11 +379,16 @@ class TestQuery:
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
# Mock responses - empty document list
|
# Mock fetch_chunk (won't be called if no chunk_ids)
|
||||||
|
async def mock_fetch(chunk_id, user):
|
||||||
|
return f"Content for {chunk_id}"
|
||||||
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
|
# Mock responses - empty chunk_id list
|
||||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -346,47 +396,48 @@ class TestQuery:
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_docs
|
# Call get_docs
|
||||||
result = await query.get_docs("query with no results")
|
result = await query.get_docs("query with no results")
|
||||||
|
|
||||||
# Verify calls were made
|
# Verify calls were made
|
||||||
mock_embeddings_client.embed.assert_called_once_with("query with no results")
|
mock_embeddings_client.embed.assert_called_once_with("query with no results")
|
||||||
mock_doc_embeddings_client.query.assert_called_once()
|
mock_doc_embeddings_client.query.assert_called_once()
|
||||||
|
|
||||||
# Verify empty result is returned
|
# Verify empty result is returned
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_query_with_empty_documents(self):
|
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
|
||||||
"""Test DocumentRag.query method when no documents are retrieved"""
|
"""Test DocumentRag.query method when no documents are retrieved"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = AsyncMock()
|
mock_prompt_client = AsyncMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock responses - no documents found
|
# Mock responses - no chunk_ids found
|
||||||
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
|
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
|
||||||
mock_doc_embeddings_client.query.return_value = [] # Empty document list
|
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
|
||||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call DocumentRag.query
|
# Call DocumentRag.query
|
||||||
result = await document_rag.query("query with no matching docs")
|
result = await document_rag.query("query with no matching docs")
|
||||||
|
|
||||||
# Verify prompt client was called with empty document list
|
# Verify prompt client was called with empty document list
|
||||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||||
query="query with no matching docs",
|
query="query with no matching docs",
|
||||||
documents=[]
|
documents=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result == "No documents found response"
|
assert result == "No documents found response"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -396,11 +447,11 @@ class TestQuery:
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
|
|
||||||
# Mock the embed method
|
# Mock the embed method
|
||||||
expected_vectors = [[0.9, 1.0, 1.1]]
|
expected_vectors = [[0.9, 1.0, 1.1]]
|
||||||
mock_embeddings_client.embed.return_value = expected_vectors
|
mock_embeddings_client.embed.return_value = expected_vectors
|
||||||
|
|
||||||
# Initialize Query with verbose=True
|
# Initialize Query with verbose=True
|
||||||
query = Query(
|
query = Query(
|
||||||
rag=mock_rag,
|
rag=mock_rag,
|
||||||
|
|
@ -408,68 +459,71 @@ class TestQuery:
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
verbose=True
|
verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call get_vector
|
# Call get_vector
|
||||||
result = await query.get_vector("verbose vector test")
|
result = await query.get_vector("verbose vector test")
|
||||||
|
|
||||||
# Verify embeddings client was called
|
# Verify embeddings client was called
|
||||||
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
|
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
|
||||||
|
|
||||||
# Verify result
|
# Verify result
|
||||||
assert result == expected_vectors
|
assert result == expected_vectors
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_document_rag_integration_flow(self):
|
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
|
||||||
"""Test complete DocumentRag integration with realistic data flow"""
|
"""Test complete DocumentRag integration with realistic data flow"""
|
||||||
# Create mock clients
|
# Create mock clients
|
||||||
mock_prompt_client = AsyncMock()
|
mock_prompt_client = AsyncMock()
|
||||||
mock_embeddings_client = AsyncMock()
|
mock_embeddings_client = AsyncMock()
|
||||||
mock_doc_embeddings_client = AsyncMock()
|
mock_doc_embeddings_client = AsyncMock()
|
||||||
|
|
||||||
# Mock realistic responses
|
# Mock realistic responses
|
||||||
query_text = "What is machine learning?"
|
query_text = "What is machine learning?"
|
||||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
retrieved_docs = [
|
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
|
||||||
"Machine learning is a subset of artificial intelligence...",
|
|
||||||
"ML algorithms learn patterns from data to make predictions...",
|
|
||||||
"Common ML techniques include supervised and unsupervised learning..."
|
|
||||||
]
|
|
||||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||||
|
|
||||||
mock_embeddings_client.embed.return_value = query_vectors
|
mock_embeddings_client.embed.return_value = query_vectors
|
||||||
mock_doc_embeddings_client.query.return_value = retrieved_docs
|
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
|
||||||
mock_prompt_client.document_prompt.return_value = final_response
|
mock_prompt_client.document_prompt.return_value = final_response
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
document_rag = DocumentRag(
|
document_rag = DocumentRag(
|
||||||
prompt_client=mock_prompt_client,
|
prompt_client=mock_prompt_client,
|
||||||
embeddings_client=mock_embeddings_client,
|
embeddings_client=mock_embeddings_client,
|
||||||
doc_embeddings_client=mock_doc_embeddings_client,
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
fetch_chunk=mock_fetch_chunk,
|
||||||
verbose=False
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute full pipeline
|
# Execute full pipeline
|
||||||
result = await document_rag.query(
|
result = await document_rag.query(
|
||||||
query=query_text,
|
query=query_text,
|
||||||
user="research_user",
|
user="research_user",
|
||||||
collection="ml_knowledge",
|
collection="ml_knowledge",
|
||||||
doc_limit=25
|
doc_limit=25
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify complete pipeline execution
|
# Verify complete pipeline execution
|
||||||
mock_embeddings_client.embed.assert_called_once_with(query_text)
|
mock_embeddings_client.embed.assert_called_once_with(query_text)
|
||||||
|
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
query_vectors,
|
query_vectors,
|
||||||
limit=25,
|
limit=25,
|
||||||
user="research_user",
|
user="research_user",
|
||||||
collection="ml_knowledge"
|
collection="ml_knowledge"
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
# Verify prompt client was called with fetched document content
|
||||||
query=query_text,
|
mock_prompt_client.document_prompt.assert_called_once()
|
||||||
documents=retrieved_docs
|
call_args = mock_prompt_client.document_prompt.call_args
|
||||||
)
|
assert call_args.kwargs["query"] == query_text
|
||||||
|
|
||||||
|
# Verify documents were fetched from chunk_ids
|
||||||
|
docs = call_args.kwargs["documents"]
|
||||||
|
assert "Machine learning is a subset of artificial intelligence..." in docs
|
||||||
|
assert "ML algorithms learn patterns from data to make predictions..." in docs
|
||||||
|
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||||
|
|
||||||
# Verify final result
|
# Verify final result
|
||||||
assert result == final_response
|
assert result == final_response
|
||||||
|
|
|
||||||
|
|
@ -22,11 +22,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
# Create test document embeddings
|
# Create test document embeddings
|
||||||
chunk1 = ChunkEmbeddings(
|
chunk1 = ChunkEmbeddings(
|
||||||
chunk=b"This is the first document chunk",
|
chunk_id="This is the first document chunk",
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
)
|
)
|
||||||
chunk2 = ChunkEmbeddings(
|
chunk2 = ChunkEmbeddings(
|
||||||
chunk=b"This is the second document chunk",
|
chunk_id="This is the second document chunk",
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vectors=[[0.7, 0.8, 0.9]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk1, chunk2]
|
message.chunks = [chunk1, chunk2]
|
||||||
|
|
@ -84,7 +84,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Test document content",
|
chunk_id="Test document content",
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
@ -136,7 +136,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"",
|
chunk_id="",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
@ -148,51 +148,62 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_none_chunk(self, processor):
|
async def test_store_document_embeddings_none_chunk(self, processor):
|
||||||
"""Test storing document embeddings with None chunk (should be skipped)"""
|
"""Test storing document embeddings with None chunk_id"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=None,
|
chunk_id=None,
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify no insert was called for None chunk
|
# Note: Implementation passes through None chunk_ids (only skips empty string "")
|
||||||
processor.vecstore.insert.assert_not_called()
|
processor.vecstore.insert.assert_called_once_with(
|
||||||
|
[0.1, 0.2, 0.3], None, 'test_user', 'test_collection'
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
|
async def test_store_document_embeddings_mixed_valid_invalid_chunks(self, processor):
|
||||||
"""Test storing document embeddings with mix of valid and invalid chunks"""
|
"""Test storing document embeddings with mix of valid and empty chunks"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
valid_chunk = ChunkEmbeddings(
|
valid_chunk = ChunkEmbeddings(
|
||||||
chunk=b"Valid document content",
|
chunk_id="Valid document content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
empty_chunk = ChunkEmbeddings(
|
empty_chunk = ChunkEmbeddings(
|
||||||
chunk=b"",
|
chunk_id="",
|
||||||
vectors=[[0.4, 0.5, 0.6]]
|
vectors=[[0.4, 0.5, 0.6]]
|
||||||
)
|
)
|
||||||
none_chunk = ChunkEmbeddings(
|
another_valid = ChunkEmbeddings(
|
||||||
chunk=None,
|
chunk_id="Another valid chunk",
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vectors=[[0.7, 0.8, 0.9]]
|
||||||
)
|
)
|
||||||
message.chunks = [valid_chunk, empty_chunk, none_chunk]
|
message.chunks = [valid_chunk, empty_chunk, another_valid]
|
||||||
|
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify only valid chunk was inserted with user/collection parameters
|
# Verify valid chunks were inserted, empty string chunk was skipped
|
||||||
processor.vecstore.insert.assert_called_once_with(
|
expected_calls = [
|
||||||
[0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'
|
([0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection'),
|
||||||
)
|
([0.7, 0.8, 0.9], "Another valid chunk", 'test_user', 'test_collection'),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert processor.vecstore.insert.call_count == 2
|
||||||
|
for i, (expected_vec, expected_chunk_id, expected_user, expected_collection) in enumerate(expected_calls):
|
||||||
|
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||||
|
assert actual_call[0][0] == expected_vec
|
||||||
|
assert actual_call[0][1] == expected_chunk_id
|
||||||
|
assert actual_call[0][2] == expected_user
|
||||||
|
assert actual_call[0][3] == expected_collection
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||||
|
|
@ -217,7 +228,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Document with no vectors",
|
chunk_id="Document with no vectors",
|
||||||
vectors=[]
|
vectors=[]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
@ -236,7 +247,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Document with mixed dimensions",
|
chunk_id="Document with mixed dimensions",
|
||||||
vectors=[
|
vectors=[
|
||||||
[0.1, 0.2], # 2D vector
|
[0.1, 0.2], # 2D vector
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||||
|
|
@ -264,46 +275,46 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||||
"""Test storing document embeddings with Unicode content"""
|
"""Test storing document embeddings with Unicode content in chunk_id"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
chunk_id="chunk/doc/unicode-éñ中文🚀",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify Unicode content was properly decoded and inserted with user/collection parameters
|
# Verify Unicode chunk_id was stored correctly with user/collection parameters
|
||||||
processor.vecstore.insert.assert_called_once_with(
|
processor.vecstore.insert.assert_called_once_with(
|
||||||
[0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀", 'test_user', 'test_collection'
|
[0.1, 0.2, 0.3], "chunk/doc/unicode-éñ中文🚀", 'test_user', 'test_collection'
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_large_chunks(self, processor):
|
async def test_store_document_embeddings_large_chunk_id(self, processor):
|
||||||
"""Test storing document embeddings with large document chunks"""
|
"""Test storing document embeddings with long chunk_id"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
# Create a large document chunk
|
# Create a long chunk_id
|
||||||
large_content = "A" * 10000 # 10KB of content
|
long_chunk_id = "chunk/doc/" + "a" * 200
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=large_content.encode('utf-8'),
|
chunk_id=long_chunk_id,
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify large content was inserted with user/collection parameters
|
# Verify long chunk_id was inserted with user/collection parameters
|
||||||
processor.vecstore.insert.assert_called_once_with(
|
processor.vecstore.insert.assert_called_once_with(
|
||||||
[0.1, 0.2, 0.3], large_content, 'test_user', 'test_collection'
|
[0.1, 0.2, 0.3], long_chunk_id, 'test_user', 'test_collection'
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -315,7 +326,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b" \n\t ",
|
chunk_id=" \n\t ",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
@ -346,7 +357,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = collection
|
message.metadata.collection = collection
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Test content",
|
chunk_id="Test content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
@ -367,7 +378,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message1.metadata.user = 'user1'
|
message1.metadata.user = 'user1'
|
||||||
message1.metadata.collection = 'collection1'
|
message1.metadata.collection = 'collection1'
|
||||||
chunk1 = ChunkEmbeddings(
|
chunk1 = ChunkEmbeddings(
|
||||||
chunk=b"User1 content",
|
chunk_id="User1 content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message1.chunks = [chunk1]
|
message1.chunks = [chunk1]
|
||||||
|
|
@ -378,7 +389,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message2.metadata.user = 'user2'
|
message2.metadata.user = 'user2'
|
||||||
message2.metadata.collection = 'collection2'
|
message2.metadata.collection = 'collection2'
|
||||||
chunk2 = ChunkEmbeddings(
|
chunk2 = ChunkEmbeddings(
|
||||||
chunk=b"User2 content",
|
chunk_id="User2 content",
|
||||||
vectors=[[0.4, 0.5, 0.6]]
|
vectors=[[0.4, 0.5, 0.6]]
|
||||||
)
|
)
|
||||||
message2.chunks = [chunk2]
|
message2.chunks = [chunk2]
|
||||||
|
|
@ -409,7 +420,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.collection = 'test-collection.v1' # Collection with special chars
|
message.metadata.collection = 'test-collection.v1' # Collection with special chars
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Special chars test",
|
chunk_id="Special chars test",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vectors=[[0.1, 0.2, 0.3]]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'store_uri': 'http://localhost:6333',
|
'store_uri': 'http://localhost:6333',
|
||||||
'api_key': 'test-api-key',
|
'api_key': 'test-api-key',
|
||||||
|
|
@ -34,7 +34,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Assert
|
# Assert
|
||||||
# Verify QdrantClient was created with correct parameters
|
# Verify QdrantClient was created with correct parameters
|
||||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||||
|
|
||||||
# Verify processor attributes
|
# Verify processor attributes
|
||||||
assert hasattr(processor, 'qdrant')
|
assert hasattr(processor, 'qdrant')
|
||||||
assert processor.qdrant == mock_qdrant_instance
|
assert processor.qdrant == mock_qdrant_instance
|
||||||
|
|
@ -45,7 +45,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'taskgroup': AsyncMock(),
|
'taskgroup': AsyncMock(),
|
||||||
'id': 'test-doc-qdrant-processor'
|
'id': 'test-doc-qdrant-processor'
|
||||||
|
|
@ -69,7 +69,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
mock_uuid.uuid4.return_value = MagicMock()
|
mock_uuid.uuid4.return_value = MagicMock()
|
||||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'store_uri': 'http://localhost:6333',
|
'store_uri': 'http://localhost:6333',
|
||||||
'api_key': 'test-api-key',
|
'api_key': 'test-api-key',
|
||||||
|
|
@ -86,13 +86,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'test_user'
|
mock_message.metadata.user = 'test_user'
|
||||||
mock_message.metadata.collection = 'test_collection'
|
mock_message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
|
||||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
|
|
@ -100,18 +100,18 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Verify collection existence was checked (with dimension suffix)
|
# Verify collection existence was checked (with dimension suffix)
|
||||||
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
|
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
|
||||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||||
|
|
||||||
# Verify upsert was called
|
# Verify upsert was called
|
||||||
mock_qdrant_instance.upsert.assert_called_once()
|
mock_qdrant_instance.upsert.assert_called_once()
|
||||||
|
|
||||||
# Verify upsert parameters
|
# Verify upsert parameters
|
||||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||||
assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3'
|
assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3'
|
||||||
assert len(upsert_call_args[1]['points']) == 1
|
assert len(upsert_call_args[1]['points']) == 1
|
||||||
|
|
||||||
point = upsert_call_args[1]['points'][0]
|
point = upsert_call_args[1]['points'][0]
|
||||||
assert point.vector == [0.1, 0.2, 0.3]
|
assert point.vector == [0.1, 0.2, 0.3]
|
||||||
assert point.payload['doc'] == 'test document chunk'
|
assert point.payload['chunk_id'] == 'doc/c1'
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||||
|
|
@ -123,7 +123,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
mock_uuid.uuid4.return_value = MagicMock()
|
mock_uuid.uuid4.return_value = MagicMock()
|
||||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'store_uri': 'http://localhost:6333',
|
'store_uri': 'http://localhost:6333',
|
||||||
'api_key': 'test-api-key',
|
'api_key': 'test-api-key',
|
||||||
|
|
@ -140,38 +140,38 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'multi_user'
|
mock_message.metadata.user = 'multi_user'
|
||||||
mock_message.metadata.collection = 'multi_collection'
|
mock_message.metadata.collection = 'multi_collection'
|
||||||
|
|
||||||
mock_chunk1 = MagicMock()
|
mock_chunk1 = MagicMock()
|
||||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
mock_chunk1.chunk_id = 'doc/c1'
|
||||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||||
|
|
||||||
mock_chunk2 = MagicMock()
|
mock_chunk2 = MagicMock()
|
||||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
mock_chunk2.chunk_id = 'doc/c2'
|
||||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Should be called twice (once per chunk)
|
# Should be called twice (once per chunk)
|
||||||
assert mock_qdrant_instance.upsert.call_count == 2
|
assert mock_qdrant_instance.upsert.call_count == 2
|
||||||
|
|
||||||
# Verify both chunks were processed
|
# Verify both chunks were processed
|
||||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||||
|
|
||||||
# First chunk
|
# First chunk
|
||||||
first_call = upsert_calls[0]
|
first_call = upsert_calls[0]
|
||||||
first_point = first_call[1]['points'][0]
|
first_point = first_call[1]['points'][0]
|
||||||
assert first_point.vector == [0.1, 0.2]
|
assert first_point.vector == [0.1, 0.2]
|
||||||
assert first_point.payload['doc'] == 'first document chunk'
|
assert first_point.payload['chunk_id'] == 'doc/c1'
|
||||||
|
|
||||||
# Second chunk
|
# Second chunk
|
||||||
second_call = upsert_calls[1]
|
second_call = upsert_calls[1]
|
||||||
second_point = second_call[1]['points'][0]
|
second_point = second_call[1]['points'][0]
|
||||||
assert second_point.vector == [0.3, 0.4]
|
assert second_point.vector == [0.3, 0.4]
|
||||||
assert second_point.payload['doc'] == 'second document chunk'
|
assert second_point.payload['chunk_id'] == 'doc/c2'
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||||
|
|
@ -183,7 +183,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
mock_uuid.uuid4.return_value = MagicMock()
|
mock_uuid.uuid4.return_value = MagicMock()
|
||||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'store_uri': 'http://localhost:6333',
|
'store_uri': 'http://localhost:6333',
|
||||||
'api_key': 'test-api-key',
|
'api_key': 'test-api-key',
|
||||||
|
|
@ -200,41 +200,41 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'vector_user'
|
mock_message.metadata.user = 'vector_user'
|
||||||
mock_message.metadata.collection = 'vector_collection'
|
mock_message.metadata.collection = 'vector_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
mock_chunk.chunk_id = 'doc/multi-vector'
|
||||||
mock_chunk.vectors = [
|
mock_chunk.vectors = [
|
||||||
[0.1, 0.2, 0.3],
|
[0.1, 0.2, 0.3],
|
||||||
[0.4, 0.5, 0.6],
|
[0.4, 0.5, 0.6],
|
||||||
[0.7, 0.8, 0.9]
|
[0.7, 0.8, 0.9]
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Should be called 3 times (once per vector)
|
# Should be called 3 times (once per vector)
|
||||||
assert mock_qdrant_instance.upsert.call_count == 3
|
assert mock_qdrant_instance.upsert.call_count == 3
|
||||||
|
|
||||||
# Verify all vectors were processed
|
# Verify all vectors were processed
|
||||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||||
|
|
||||||
expected_vectors = [
|
expected_vectors = [
|
||||||
[0.1, 0.2, 0.3],
|
[0.1, 0.2, 0.3],
|
||||||
[0.4, 0.5, 0.6],
|
[0.4, 0.5, 0.6],
|
||||||
[0.7, 0.8, 0.9]
|
[0.7, 0.8, 0.9]
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, call in enumerate(upsert_calls):
|
for i, call in enumerate(upsert_calls):
|
||||||
point = call[1]['points'][0]
|
point = call[1]['points'][0]
|
||||||
assert point.vector == expected_vectors[i]
|
assert point.vector == expected_vectors[i]
|
||||||
assert point.payload['doc'] == 'multi-vector document chunk'
|
assert point.payload['chunk_id'] == 'doc/multi-vector'
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client):
|
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
|
||||||
"""Test storing document embeddings skips empty chunks"""
|
"""Test storing document embeddings skips empty chunk_ids"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||||
|
|
@ -249,13 +249,13 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Create mock message with empty chunk
|
# Create mock message with empty chunk_id
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'empty_user'
|
mock_message.metadata.user = 'empty_user'
|
||||||
mock_message.metadata.collection = 'empty_collection'
|
mock_message.metadata.collection = 'empty_collection'
|
||||||
|
|
||||||
mock_chunk_empty = MagicMock()
|
mock_chunk_empty = MagicMock()
|
||||||
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
|
mock_chunk_empty.chunk_id = "" # Empty chunk_id
|
||||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk_empty]
|
mock_message.chunks = [mock_chunk_empty]
|
||||||
|
|
@ -264,9 +264,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Should not call upsert for empty chunks
|
# Should not call upsert for empty chunk_ids
|
||||||
mock_qdrant_instance.upsert.assert_not_called()
|
mock_qdrant_instance.upsert.assert_not_called()
|
||||||
# collection_exists should NOT be called since we return early for empty chunks
|
# collection_exists should NOT be called since we return early for empty chunk_ids
|
||||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
|
|
@ -298,7 +298,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message.metadata.collection = 'new_collection'
|
mock_message.metadata.collection = 'new_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
mock_chunk.chunk_id = 'doc/test-chunk'
|
||||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
@ -350,7 +350,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message.metadata.collection = 'error_collection'
|
mock_message.metadata.collection = 'error_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
mock_chunk.chunk_id = 'doc/test-chunk'
|
||||||
mock_chunk.vectors = [[0.1, 0.2]]
|
mock_chunk.vectors = [[0.1, 0.2]]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
@ -388,7 +388,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message1.metadata.collection = 'cache_collection'
|
mock_message1.metadata.collection = 'cache_collection'
|
||||||
|
|
||||||
mock_chunk1 = MagicMock()
|
mock_chunk1 = MagicMock()
|
||||||
mock_chunk1.chunk.decode.return_value = 'first chunk'
|
mock_chunk1.chunk_id = 'doc/c1'
|
||||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||||
|
|
||||||
mock_message1.chunks = [mock_chunk1]
|
mock_message1.chunks = [mock_chunk1]
|
||||||
|
|
@ -406,7 +406,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message2.metadata.collection = 'cache_collection'
|
mock_message2.metadata.collection = 'cache_collection'
|
||||||
|
|
||||||
mock_chunk2 = MagicMock()
|
mock_chunk2 = MagicMock()
|
||||||
mock_chunk2.chunk.decode.return_value = 'second chunk'
|
mock_chunk2.chunk_id = 'doc/c2'
|
||||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||||
|
|
||||||
mock_message2.chunks = [mock_chunk2]
|
mock_message2.chunks = [mock_chunk2]
|
||||||
|
|
@ -452,7 +452,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_message.metadata.collection = 'dim_collection'
|
mock_message.metadata.collection = 'dim_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
|
mock_chunk.chunk_id = 'doc/dim-test'
|
||||||
mock_chunk.vectors = [
|
mock_chunk.vectors = [
|
||||||
[0.1, 0.2], # 2 dimensions
|
[0.1, 0.2], # 2 dimensions
|
||||||
[0.3, 0.4, 0.5] # 3 dimensions
|
[0.3, 0.4, 0.5] # 3 dimensions
|
||||||
|
|
@ -485,28 +485,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_client.return_value = MagicMock()
|
mock_qdrant_client.return_value = MagicMock()
|
||||||
mock_parser = MagicMock()
|
mock_parser = MagicMock()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||||
Processor.add_args(mock_parser)
|
Processor.add_args(mock_parser)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||||
|
|
||||||
# Verify processor-specific arguments were added
|
# Verify processor-specific arguments were added
|
||||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||||
async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client):
|
async def test_chunk_id_with_special_characters(self, mock_uuid, mock_qdrant_client):
|
||||||
"""Test proper UTF-8 decoding of chunk text"""
|
"""Test storing chunk_id with special characters (URIs)"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_instance.collection_exists.return_value = True
|
mock_qdrant_instance.collection_exists.return_value = True
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
mock_uuid.uuid4.return_value = MagicMock()
|
mock_uuid.uuid4.return_value = MagicMock()
|
||||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'store_uri': 'http://localhost:6333',
|
'store_uri': 'http://localhost:6333',
|
||||||
'api_key': 'test-api-key',
|
'api_key': 'test-api-key',
|
||||||
|
|
@ -517,65 +517,28 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Add collection to known_collections (simulates config push)
|
# Add collection to known_collections (simulates config push)
|
||||||
processor.known_collections[('utf8_user', 'utf8_collection')] = {}
|
processor.known_collections[('uri_user', 'uri_collection')] = {}
|
||||||
|
|
||||||
# Create mock message with UTF-8 encoded text
|
# Create mock message with URI-style chunk_id
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'utf8_user'
|
mock_message.metadata.user = 'uri_user'
|
||||||
mock_message.metadata.collection = 'utf8_collection'
|
mock_message.metadata.collection = 'uri_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé'
|
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
|
||||||
mock_chunk.vectors = [[0.1, 0.2]]
|
mock_chunk.vectors = [[0.1, 0.2]]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify chunk.decode was called with 'utf-8'
|
# Verify the chunk_id was stored correctly
|
||||||
mock_chunk.chunk.decode.assert_called_with('utf-8')
|
|
||||||
|
|
||||||
# Verify the decoded text was stored in payload
|
|
||||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||||
point = upsert_call_args[1]['points'][0]
|
point = upsert_call_args[1]['points'][0]
|
||||||
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
|
assert point.payload['chunk_id'] == 'https://trustgraph.ai/doc/my-document/p1/c3'
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
|
||||||
async def test_chunk_decode_exception_handling(self, mock_qdrant_client):
|
|
||||||
"""Test handling of chunk decode exceptions"""
|
|
||||||
# Arrange
|
|
||||||
mock_qdrant_instance = MagicMock()
|
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
|
||||||
|
|
||||||
config = {
|
|
||||||
'store_uri': 'http://localhost:6333',
|
|
||||||
'api_key': 'test-api-key',
|
|
||||||
'taskgroup': AsyncMock(),
|
|
||||||
'id': 'test-doc-qdrant-processor'
|
|
||||||
}
|
|
||||||
|
|
||||||
processor = Processor(**config)
|
|
||||||
|
|
||||||
# Add collection to known_collections (simulates config push)
|
|
||||||
processor.known_collections[('decode_user', 'decode_collection')] = {}
|
|
||||||
|
|
||||||
# Create mock message with decode error
|
|
||||||
mock_message = MagicMock()
|
|
||||||
mock_message.metadata.user = 'decode_user'
|
|
||||||
mock_message.metadata.collection = 'decode_collection'
|
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
|
||||||
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
|
|
||||||
mock_chunk.vectors = [[0.1, 0.2]]
|
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
|
||||||
|
|
||||||
# Act & Assert
|
|
||||||
with pytest.raises(UnicodeDecodeError):
|
|
||||||
await processor.store_document_embeddings(mock_message)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue