mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Embeddings API scores (#671)
- Put scores in all responses - Remove unused 'middle' vector layer: a vector of texts now maps to a vector of embeddings (one embedding vector per text)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.doc_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
|
||||
|
||||
|
||||
class TestMilvusDocEmbeddingsQueryProcessor:
|
||||
|
|
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
)
|
||||
return query
|
||||
|
|
@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
|
||||
)
|
||||
|
||||
# Verify results are document chunks
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert result[0] == "First document chunk"
|
||||
assert result[1] == "Second document chunk"
|
||||
assert result[2] == "Third document chunk"
|
||||
assert isinstance(result[0], ChunkMatch)
|
||||
assert result[0].chunk_id == "First document chunk"
|
||||
assert result[1].chunk_id == "Second document chunk"
|
||||
assert result[2].chunk_id == "Third document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_multiple_vectors(self, processor):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
async def test_query_document_embeddings_longer_vector(self, processor):
|
||||
"""Test querying document embeddings with a longer vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=3
|
||||
)
|
||||
|
||||
# Mock search results - different results for each vector
|
||||
mock_results_1 = [
|
||||
{"entity": {"chunk_id": "Document from first vector"}},
|
||||
{"entity": {"chunk_id": "Another doc from first vector"}},
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"chunk_id": "First document"}},
|
||||
{"entity": {"chunk_id": "Second document"}},
|
||||
{"entity": {"chunk_id": "Third document"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"chunk_id": "Document from second vector"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify search was called twice with correct parameters including user/collection
|
||||
expected_calls = [
|
||||
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
|
||||
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
|
||||
]
|
||||
assert processor.vecstore.search.call_count == 2
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.search.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
# Verify results from all vectors are combined
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
|
||||
)
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert "Document from first vector" in result
|
||||
assert "Another doc from first vector" in result
|
||||
assert "Document from second vector" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "First document" in chunk_ids
|
||||
assert "Second document" in chunk_ids
|
||||
assert "Third document" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_with_limit(self, processor):
|
||||
|
|
@ -141,7 +135,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
)
|
||||
|
||||
|
|
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[],
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -211,7 +205,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -225,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify Unicode content is preserved
|
||||
# Verify Unicode content is preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert "Document with Unicode: éñ中文🚀" in result
|
||||
assert "Regular ASCII document" in result
|
||||
assert "Document with émojis: 😀🎉" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document with Unicode: éñ中文🚀" in chunk_ids
|
||||
assert "Regular ASCII document" in chunk_ids
|
||||
assert "Document with émojis: 😀🎉" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_large_documents(self, processor):
|
||||
|
|
@ -237,7 +232,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -251,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify large content is preserved
|
||||
# Verify large content is preserved in ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
assert large_doc in result
|
||||
assert "Small document" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert large_doc in chunk_ids
|
||||
assert "Small document" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_special_characters(self, processor):
|
||||
|
|
@ -262,7 +258,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -276,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify special characters are preserved
|
||||
# Verify special characters are preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert "Document with \"quotes\" and 'apostrophes'" in result
|
||||
assert "Document with\nnewlines\tand\ttabs" in result
|
||||
assert "Document with special chars: @#$%^&*()" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
|
||||
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
|
||||
assert "Document with special chars: @#$%^&*()" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
|
|
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
|
|
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=-1
|
||||
)
|
||||
|
||||
|
|
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results for each vector
|
||||
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
|
||||
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
|
||||
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"chunk_id": "Document 1"}},
|
||||
{"entity": {"chunk_id": "Document 2"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify all vectors were searched
|
||||
assert processor.vecstore.search.call_count == 3
|
||||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
assert "Document from 2D vector" in result
|
||||
assert "Document from 4D vector" in result
|
||||
assert "Document from 3D vector" in result
|
||||
|
||||
# Verify search was called with the vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document 1" in chunk_ids
|
||||
assert "Document 2" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_duplicate_documents(self, processor):
|
||||
"""Test querying document embeddings with duplicate documents in results"""
|
||||
async def test_query_document_embeddings_multiple_results(self, processor):
|
||||
"""Test querying document embeddings with multiple results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with duplicates across vectors
|
||||
mock_results_1 = [
|
||||
|
||||
# Mock search results with multiple documents
|
||||
mock_results = [
|
||||
{"entity": {"chunk_id": "Document A"}},
|
||||
{"entity": {"chunk_id": "Document B"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"chunk_id": "Document B"}}, # Duplicate
|
||||
{"entity": {"chunk_id": "Document C"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
|
||||
# This preserves ranking and allows multiple occurrences
|
||||
assert len(result) == 4
|
||||
assert result.count("Document B") == 2 # Should appear twice
|
||||
assert "Document A" in result
|
||||
assert "Document C" in result
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document A" in chunk_ids
|
||||
assert "Document B" in chunk_ids
|
||||
assert "Document C" in chunk_ids
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_single_vector(self, processor):
|
||||
"""Test querying document embeddings with a single vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_limit_handling(self, processor):
|
||||
"""Test that query respects the limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -208,7 +208,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
"""Test querying with zero limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 0
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -226,7 +226,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||
"""Test querying with negative limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = -1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -285,7 +285,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_vectors_list(self, processor):
|
||||
"""Test querying with empty vectors list"""
|
||||
message = MagicMock()
|
||||
message.vectors = []
|
||||
message.vector = []
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -304,7 +304,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_no_results(self, processor):
|
||||
"""Test querying when index returns no results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -325,7 +325,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_unicode_content(self, processor):
|
||||
"""Test querying document embeddings with Unicode content results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -351,7 +351,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_large_content(self, processor):
|
||||
"""Test querying document embeddings with large content results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -377,7 +377,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_mixed_content_types(self, processor):
|
||||
"""Test querying document embeddings with mixed content types"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -409,7 +409,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||
"""Test that exceptions are properly raised"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -425,7 +425,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_index_access_failure(self, processor):
|
||||
"""Test handling of index access failure"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected documents
|
||||
# Verify result contains expected ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
# Results should be strings (document chunks)
|
||||
assert isinstance(result[0], str)
|
||||
assert isinstance(result[1], str)
|
||||
# Results should be ChunkMatch objects
|
||||
assert isinstance(result[0], ChunkMatch)
|
||||
assert isinstance(result[1], ChunkMatch)
|
||||
# Verify content
|
||||
assert result[0] == 'first document chunk'
|
||||
assert result[1] == 'second document chunk'
|
||||
assert result[0].chunk_id == 'first document chunk'
|
||||
assert result[1].chunk_id == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings returns multiple results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
|
||||
# Mock query response with multiple results
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'chunk_id': 'document from vector 1'}
|
||||
mock_point1.payload = {'chunk_id': 'document chunk 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'chunk_id': 'document from vector 2'}
|
||||
mock_point2.payload = {'chunk_id': 'document chunk 2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
mock_point3.payload = {'chunk_id': 'document chunk 3'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
# Verify collection was queried correctly
|
||||
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify results from both vectors are combined
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert 'document from vector 1' in result
|
||||
assert 'document from vector 2' in result
|
||||
assert 'another document from vector 2' in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'document chunk 1' in chunk_ids
|
||||
assert 'document chunk 2' in chunk_ids
|
||||
assert 'document chunk 3' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
|
@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
|
@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
"""Test querying document embeddings with a higher dimension vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
|
||||
mock_point1.payload = {'chunk_id': 'document from 5D vector'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
mock_point2.payload = {'chunk_id': 'another 5D document'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
|
||||
# Create mock message with 5D vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once with correct collection
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
# Call should use 5D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
assert 'document from 2D vector' in result
|
||||
assert 'document from 3D vector' in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'document from 5D vector' in chunk_ids
|
||||
assert 'another 5D document' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'utf8_user'
|
||||
mock_message.collection = 'utf8_collection'
|
||||
|
|
@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify UTF-8 content works correctly
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in result
|
||||
assert 'Chinese text: 你好世界' in result
|
||||
|
||||
# Verify UTF-8 content works correctly in ChunkMatch objects
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
|
||||
assert 'Chinese text: 你好世界' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
|
@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
|
@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0
|
||||
|
||||
# Result should contain all returned documents
|
||||
|
||||
# Result should contain all returned documents as ChunkMatch objects
|
||||
assert len(result) == 1
|
||||
assert result[0] == 'document chunk'
|
||||
assert isinstance(result[0], ChunkMatch)
|
||||
assert result[0].chunk_id == 'document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with large limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 1000 # Large limit
|
||||
mock_message.user = 'large_user'
|
||||
mock_message.collection = 'large_collection'
|
||||
|
|
@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 1000
|
||||
|
||||
# Result should contain all available documents
|
||||
|
||||
# Result should contain all available documents as ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
assert 'document 1' in result
|
||||
assert 'document 2' in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'document 1' in chunk_ids
|
||||
assert 'document 2' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -508,7 +504,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'payload_user'
|
||||
mock_message.collection = 'payload_collection'
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
|
||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
)
|
||||
return query
|
||||
|
|
@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are converted to Term objects
|
||||
# Verify results are converted to EntityMatch objects
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], Term)
|
||||
assert result[0].iri == "http://example.com/entity1"
|
||||
assert result[0].type == IRI
|
||||
assert isinstance(result[1], Term)
|
||||
assert result[1].iri == "http://example.com/entity2"
|
||||
assert result[1].type == IRI
|
||||
assert isinstance(result[2], Term)
|
||||
assert result[2].value == "literal entity"
|
||||
assert result[2].type == LITERAL
|
||||
assert isinstance(result[0], EntityMatch)
|
||||
assert result[0].entity.iri == "http://example.com/entity1"
|
||||
assert result[0].entity.type == IRI
|
||||
assert isinstance(result[1], EntityMatch)
|
||||
assert result[1].entity.iri == "http://example.com/entity2"
|
||||
assert result[1].entity.type == IRI
|
||||
assert isinstance(result[2], EntityMatch)
|
||||
assert result[2].entity.value == "literal entity"
|
||||
assert result[2].entity.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
async def test_query_graph_embeddings_multiple_results(self, processor):
|
||||
"""Test querying graph embeddings returns multiple results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=3
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results - different results for each vector
|
||||
mock_results_1 = [
|
||||
|
||||
# Mock search results with multiple entities
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called twice with correct parameters including user/collection
|
||||
expected_calls = [
|
||||
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
|
||||
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
|
||||
]
|
||||
assert processor.vecstore.search.call_count == 2
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.search.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
# Verify results are deduplicated and limited
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are EntityMatch objects
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
|
@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
)
|
||||
|
||||
|
|
@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication(self, processor):
|
||||
"""Test that duplicate entities are properly deduplicated"""
|
||||
async def test_query_graph_embeddings_preserves_order(self, processor):
|
||||
"""Test that query results preserve order from the vector store"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with duplicates
|
||||
mock_results_1 = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity3"}}, # New
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify duplicates are removed
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
||||
"""Test that querying stops early when limit is reached"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Mock search results - first vector returns enough results
|
||||
mock_results_1 = [
|
||||
# Mock search results in specific order
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results_1
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify only first vector was searched (limit reached)
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
|
||||
|
||||
# Verify results are in the same order as returned by the store
|
||||
assert len(result) == 3
|
||||
assert result[0].entity.iri == "http://example.com/entity1"
|
||||
assert result[1].entity.iri == "http://example.com/entity2"
|
||||
assert result[2].entity.iri == "http://example.com/entity3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_results_limited(self, processor):
|
||||
"""Test that results are properly limited when store returns more than requested"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Verify results are limited
|
||||
|
||||
# Mock search results - returns more results than limit
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
|
||||
)
|
||||
|
||||
# Verify results are limited to requested amount
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[],
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
||||
|
||||
# Check URI entities
|
||||
uri_results = [r for r in result if r.type == IRI]
|
||||
uri_results = [r for r in result if r.entity.type == IRI]
|
||||
assert len(uri_results) == 2
|
||||
uri_values = [r.iri for r in uri_results]
|
||||
uri_values = [r.entity.iri for r in uri_results]
|
||||
assert "http://example.com/uri_entity" in uri_values
|
||||
assert "https://example.com/another_uri" in uri_values
|
||||
|
||||
|
||||
# Check literal entities
|
||||
literal_results = [r for r in result if not r.type == IRI]
|
||||
literal_results = [r for r in result if not r.entity.type == IRI]
|
||||
assert len(literal_results) == 2
|
||||
literal_values = [r.value for r in literal_results]
|
||||
literal_values = [r.entity.value for r in literal_results]
|
||||
assert "literal entity text" in literal_values
|
||||
assert "another literal" in literal_values
|
||||
|
||||
|
|
@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
|
|
@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying graph embeddings with different vector dimensions"""
|
||||
async def test_query_graph_embeddings_longer_vector(self, processor):
|
||||
"""Test querying graph embeddings with a longer vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results for each vector
|
||||
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
|
||||
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
|
||||
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify all vectors were searched
|
||||
assert processor.vecstore.search.call_count == 3
|
||||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert "entity_2d" in entity_values
|
||||
assert "entity_4d" in entity_values
|
||||
assert "entity_3d" in entity_values
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
|
|
@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
def mock_query_message(self):
|
||||
"""Create a mock query message for testing"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
include_metadata=True
|
||||
)
|
||||
|
||||
# Verify results
|
||||
# Verify results use EntityMatch structure
|
||||
assert len(entities) == 3
|
||||
assert entities[0].value == 'http://example.org/entity1'
|
||||
assert entities[0].type == IRI
|
||||
assert entities[1].value == 'entity2'
|
||||
assert entities[1].type == LITERAL
|
||||
assert entities[2].value == 'http://example.org/entity3'
|
||||
assert entities[2].type == IRI
|
||||
assert entities[0].entity.iri == 'http://example.org/entity1'
|
||||
assert entities[0].entity.type == IRI
|
||||
assert entities[1].entity.value == 'entity2'
|
||||
assert entities[1].entity.type == LITERAL
|
||||
assert entities[2].entity.iri == 'http://example.org/entity3'
|
||||
assert entities[2].entity.type == IRI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
|
||||
"""Test basic graph embeddings query"""
|
||||
# Mock index and query results
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query results with distinct entities
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'})
|
||||
]
|
||||
|
||||
# Second query results
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
||||
# Verify deduplication occurred
|
||||
entity_values = [e.value for e in entities]
|
||||
|
||||
# Verify query was made once
|
||||
assert mock_index.query.call_count == 1
|
||||
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert len(entity_values) == 3
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
|
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_limit_handling(self, processor):
|
||||
"""Test that query respects the limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying with zero limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 0
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_negative_limit(self, processor):
|
||||
"""Test querying with negative limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = -1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying with vectors of different dimensions using same index"""
|
||||
async def test_query_graph_embeddings_2d_vector(self, processor):
|
||||
"""Test querying with a 2D vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
]
|
||||
message.vector = [0.1, 0.2] # 2D vector
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
# Mock single index that handles all dimensions
|
||||
# Mock index
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock results for different vector queries
|
||||
mock_results_2d = MagicMock()
|
||||
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
# Mock results for 2D vector query
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
|
||||
mock_results_4d = MagicMock()
|
||||
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
|
||||
|
||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
index_names = [call[0][0] for call in index_calls]
|
||||
assert "t-test_user-test_collection-2" in index_names # 2D vector
|
||||
assert "t-test_user-test_collection-4" in index_names # 4D vector
|
||||
# Verify correct index used for 2D vector
|
||||
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
# Verify query was made
|
||||
assert mock_index.query.call_count == 1
|
||||
|
||||
# Verify results from both dimensions
|
||||
entity_values = [e.value for e in entities]
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert 'entity_2d' in entity_values
|
||||
assert 'entity_4d' in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
|
||||
"""Test querying with empty vectors list"""
|
||||
message = MagicMock()
|
||||
message.vectors = []
|
||||
message.vector = []
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_no_results(self, processor):
|
||||
"""Test querying when index returns no results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
|
||||
"""Test that deduplication works correctly across multiple vector queries"""
|
||||
async def test_query_graph_embeddings_deduplication_in_results(self, processor):
|
||||
"""Test that deduplication works correctly within query results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Both queries return overlapping results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query returns results with some duplicates
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}),
|
||||
MagicMock(metadata={'entity': 'entity4'})
|
||||
]
|
||||
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity5'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
entity_values = [e.value for e in entities]
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
||||
"""Test that querying stops early when limit is reached"""
|
||||
async def test_query_graph_embeddings_respects_limit(self, processor):
|
||||
"""Test that query respects limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query returns enough results to meet limit
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query returns more results than limit
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
mock_index.query.return_value = mock_results1
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Should only make one query since limit was reached
|
||||
|
||||
# Should only return 2 entities (respecting limit)
|
||||
mock_index.query.assert_called_once()
|
||||
assert len(entities) == 2
|
||||
|
||||
|
|
@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test that exceptions are properly raised"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
from trustgraph.schema import IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected entities
|
||||
# Verify result contains expected EntityMatch objects
|
||||
assert len(result) == 2
|
||||
assert all(hasattr(entity, 'value') for entity in result)
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert all(isinstance(entity, EntityMatch) for entity in result)
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
||||
|
|
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
# Verify collection was queried
|
||||
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
||||
entity_values = [entity.value for entity in result]
|
||||
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
|
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
|
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.vector = [0.1, 0.2] # 2D vector
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
# Call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
entity_values = [entity.value for entity in result]
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert 'entity2d' in entity_values
|
||||
assert 'entity3d' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'uri_user'
|
||||
mock_message.collection = 'uri_collection'
|
||||
|
|
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if entity.type == IRI]
|
||||
uri_entities = [entity for entity in result if entity.entity.type == IRI]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.iri for entity in uri_entities]
|
||||
uri_values = [entity.entity.iri for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if entity.type == LITERAL]
|
||||
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
assert regular_entities[0].entity.value == 'regular entity'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
|
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
|
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# With zero limit, the logic still adds one entity before checking the limit
|
||||
# So it returns one result (current behavior, not ideal but actual)
|
||||
assert len(result) == 1
|
||||
assert result[0].value == 'entity1'
|
||||
assert result[0].entity.value == 'entity1'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue