mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-03 20:05:13 +02:00
Fix hard coded vector size (#555)
* Fixed hard-coded embeddings store size * Vector store lazy-creates collections, different collections for different dimension lengths. * Added tech spec for vector store lifecycle * Fixed some tests for the new spec
This commit is contained in:
parent
05b9063fea
commit
6129bb68c1
22 changed files with 793 additions and 572 deletions
|
|
@ -30,14 +30,16 @@ class TestMilvusUserCollectionIntegration:
|
|||
|
||||
for user, collection, vector in test_cases:
|
||||
doc_vectors.insert(vector, "test document", user, collection)
|
||||
|
||||
|
||||
expected_collection_name = make_safe_collection_name(
|
||||
user, collection, "doc"
|
||||
)
|
||||
|
||||
# Verify collection was created with correct name
|
||||
# Add dimension suffix to expected name
|
||||
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
|
||||
|
||||
# Verify collection was created with correct name (including dimension)
|
||||
assert (len(vector), user, collection) in doc_vectors.collections
|
||||
assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name
|
||||
assert doc_vectors.collections[(len(vector), user, collection)] == expected_collection_name_with_dim
|
||||
|
||||
@patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient')
|
||||
def test_entity_vectors_collection_creation_with_user_collection(self, mock_milvus_client):
|
||||
|
|
@ -56,14 +58,16 @@ class TestMilvusUserCollectionIntegration:
|
|||
|
||||
for user, collection, vector in test_cases:
|
||||
entity_vectors.insert(vector, "test entity", user, collection)
|
||||
|
||||
|
||||
expected_collection_name = make_safe_collection_name(
|
||||
user, collection, "entity"
|
||||
)
|
||||
|
||||
# Verify collection was created with correct name
|
||||
# Add dimension suffix to expected name
|
||||
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
|
||||
|
||||
# Verify collection was created with correct name (including dimension)
|
||||
assert (len(vector), user, collection) in entity_vectors.collections
|
||||
assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name
|
||||
assert entity_vectors.collections[(len(vector), user, collection)] == expected_collection_name_with_dim
|
||||
|
||||
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
|
||||
def test_doc_vectors_search_uses_correct_collection(self, mock_milvus_client):
|
||||
|
|
@ -88,11 +92,12 @@ class TestMilvusUserCollectionIntegration:
|
|||
# Now search
|
||||
result = doc_vectors.search(vector, user, collection, limit=5)
|
||||
|
||||
# Verify search was called with correct collection name
|
||||
# Verify search was called with correct collection name (including dimension)
|
||||
expected_collection_name = make_safe_collection_name(user, collection, "doc")
|
||||
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
|
||||
mock_client.search.assert_called_once()
|
||||
search_call = mock_client.search.call_args
|
||||
assert search_call[1]["collection_name"] == expected_collection_name
|
||||
assert search_call[1]["collection_name"] == expected_collection_name_with_dim
|
||||
|
||||
@patch('trustgraph.direct.milvus_graph_embeddings.MilvusClient')
|
||||
def test_entity_vectors_search_uses_correct_collection(self, mock_milvus_client):
|
||||
|
|
@ -117,11 +122,12 @@ class TestMilvusUserCollectionIntegration:
|
|||
# Now search
|
||||
result = entity_vectors.search(vector, user, collection, limit=5)
|
||||
|
||||
# Verify search was called with correct collection name
|
||||
# Verify search was called with correct collection name (including dimension)
|
||||
expected_collection_name = make_safe_collection_name(user, collection, "entity")
|
||||
expected_collection_name_with_dim = f"{expected_collection_name}_{len(vector)}"
|
||||
mock_client.search.assert_called_once()
|
||||
search_call = mock_client.search.call_args
|
||||
assert search_call[1]["collection_name"] == expected_collection_name
|
||||
assert search_call[1]["collection_name"] == expected_collection_name_with_dim
|
||||
|
||||
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
|
||||
def test_doc_vectors_collection_isolation(self, mock_milvus_client):
|
||||
|
|
@ -141,10 +147,11 @@ class TestMilvusUserCollectionIntegration:
|
|||
assert len(doc_vectors.collections) == 3
|
||||
|
||||
collection_names = set(doc_vectors.collections.values())
|
||||
# All vectors are 3-dimensional, so all names should have _3 suffix
|
||||
expected_names = {
|
||||
"doc_user1_collection1",
|
||||
"doc_user2_collection2",
|
||||
"doc_user1_collection2"
|
||||
"doc_user1_collection1_3",
|
||||
"doc_user2_collection2_3",
|
||||
"doc_user1_collection2_3"
|
||||
}
|
||||
assert collection_names == expected_names
|
||||
|
||||
|
|
@ -166,10 +173,11 @@ class TestMilvusUserCollectionIntegration:
|
|||
assert len(entity_vectors.collections) == 3
|
||||
|
||||
collection_names = set(entity_vectors.collections.values())
|
||||
# All vectors are 3-dimensional, so all names should have _3 suffix
|
||||
expected_names = {
|
||||
"entity_user1_collection1",
|
||||
"entity_user2_collection2",
|
||||
"entity_user1_collection2"
|
||||
"entity_user1_collection1_3",
|
||||
"entity_user2_collection2_3",
|
||||
"entity_user1_collection2_3"
|
||||
}
|
||||
assert collection_names == expected_names
|
||||
|
||||
|
|
@ -191,16 +199,16 @@ class TestMilvusUserCollectionIntegration:
|
|||
|
||||
# Verify three separate collections were created for different dimensions
|
||||
assert len(doc_vectors.collections) == 3
|
||||
|
||||
|
||||
collection_names = set(doc_vectors.collections.values())
|
||||
# Different dimensions now create different collections with dimension suffixes
|
||||
expected_names = {
|
||||
"doc_test_user_test_collection", # Same name for all dimensions
|
||||
"doc_test_user_test_collection", # now stored per dimension in key
|
||||
"doc_test_user_test_collection" # but collection name is the same
|
||||
"doc_test_user_test_collection_2", # 2D vector
|
||||
"doc_test_user_test_collection_3", # 3D vector
|
||||
"doc_test_user_test_collection_4" # 4D vector
|
||||
}
|
||||
# Note: Now all dimensions use the same collection name, they are differentiated by the key
|
||||
assert len(collection_names) == 1 # Only one unique collection name
|
||||
assert "doc_test_user_test_collection" in collection_names
|
||||
# Each dimension gets its own collection
|
||||
assert len(collection_names) == 3 # Three unique collection names
|
||||
assert collection_names == expected_names
|
||||
|
||||
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
|
||||
|
|
@ -222,8 +230,9 @@ class TestMilvusUserCollectionIntegration:
|
|||
|
||||
# Verify only one collection was created
|
||||
assert len(doc_vectors.collections) == 1
|
||||
|
||||
expected_collection_name = "doc_test_user_test_collection"
|
||||
|
||||
# Collection name now includes dimension suffix
|
||||
expected_collection_name = "doc_test_user_test_collection_3"
|
||||
assert doc_vectors.collections[(3, user, collection)] == expected_collection_name
|
||||
|
||||
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
|
||||
|
|
@ -235,19 +244,20 @@ class TestMilvusUserCollectionIntegration:
|
|||
doc_vectors = DocVectors(uri="http://test:19530", prefix="doc")
|
||||
|
||||
# Test various special character combinations
|
||||
# All expected names now include dimension suffix _3
|
||||
test_cases = [
|
||||
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1"),
|
||||
("user_123", "collection_456", "doc_user_123_collection_456"),
|
||||
("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces"),
|
||||
("user@@@test", "collection---test", "doc_user_test_collection_test"),
|
||||
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1_3"),
|
||||
("user_123", "collection_456", "doc_user_123_collection_456_3"),
|
||||
("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces_3"),
|
||||
("user@@@test", "collection---test", "doc_user_test_collection_test_3"),
|
||||
]
|
||||
|
||||
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
for user, collection, expected_name in test_cases:
|
||||
doc_vectors_instance = DocVectors(uri="http://test:19530", prefix="doc")
|
||||
doc_vectors_instance.insert(vector, "test doc", user, collection)
|
||||
|
||||
|
||||
assert doc_vectors_instance.collections[(3, user, collection)] == expected_name
|
||||
|
||||
def test_collection_name_backward_compatibility(self):
|
||||
|
|
|
|||
|
|
@ -119,8 +119,8 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
|
||||
# Verify index was accessed correctly
|
||||
expected_index_name = "d-test_user-test_collection"
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.Index.assert_called_once_with(expected_index_name)
|
||||
|
||||
# Verify query parameters
|
||||
|
|
@ -264,10 +264,12 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
chunks = await processor.query_document_embeddings(message)
|
||||
|
||||
# Verify same index used for both vectors
|
||||
expected_index_name = "d-test_user-test_collection"
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
index_names = [call[0][0] for call in index_calls]
|
||||
assert "d-test_user-test_collection-2" in index_names # 2D vector
|
||||
assert "d-test_user-test_collection-4" in index_names # 4D vector
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
|
|
|||
|
|
@ -103,8 +103,8 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 'd_test_user_test_collection'
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
|
|
@ -164,9 +164,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 'd_multi_user_multi_collection'
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
|
|
@ -301,13 +301,13 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection'
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection'
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
|
|
|
|||
|
|
@ -147,8 +147,8 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify index was accessed correctly
|
||||
expected_index_name = "t-test_user-test_collection"
|
||||
# Verify index was accessed correctly (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.Index.assert_called_once_with(expected_index_name)
|
||||
|
||||
# Verify query parameters
|
||||
|
|
@ -290,10 +290,12 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify same index used for both vectors
|
||||
expected_index_name = "t-test_user-test_collection"
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
index_names = [call[0][0] for call in index_calls]
|
||||
assert "t-test_user-test_collection-2" in index_names # 2D vector
|
||||
assert "t-test_user-test_collection-4" in index_names # 4D vector
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
|
|
|||
|
|
@ -175,8 +175,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 't_test_user_test_collection'
|
||||
# Verify query was called with correct parameters (with dimension suffix)
|
||||
expected_collection = 't_test_user_test_collection_3' # 3 dimensions
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
|
|
@ -234,9 +234,9 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 't_multi_user_multi_collection'
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
|
|
@ -372,13 +372,13 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection'
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection'
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
|
|
|
|||
|
|
@ -134,8 +134,8 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index name and operations
|
||||
expected_index_name = "d-test_user-test_collection"
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
|
||||
# Verify upsert was called for each vector
|
||||
|
|
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_validation(self, processor):
|
||||
"""Test that writing to non-existent index raises ValueError"""
|
||||
"""Test that writing to non-existent index creates it lazily"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
|
|
@ -191,12 +191,24 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist
|
||||
# Mock index doesn't exist initially
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index was created with correct dimension
|
||||
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
|
||||
# Verify upsert was still called
|
||||
mock_index.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunk(self, processor):
|
||||
"""Test storing document embeddings with empty chunk (should be skipped)"""
|
||||
|
|
@ -345,7 +357,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_validation_before_creation(self, processor):
|
||||
"""Test that validation error occurs before creation attempts"""
|
||||
"""Test that lazy creation happens when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
|
|
@ -359,13 +371,18 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index was created
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_validates_before_timeout(self, processor):
|
||||
"""Test that validation error occurs before timeout checks"""
|
||||
"""Test that lazy creation works correctly"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
|
|
@ -379,10 +396,16 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index was created and used
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
mock_index.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
"""Test storing document embeddings with Unicode content"""
|
||||
|
|
|
|||
|
|
@ -103,8 +103,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 'd_test_user_test_collection'
|
||||
# Verify collection existence was checked (with dimension suffix)
|
||||
expected_collection = 'd_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
|
|
@ -112,7 +112,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert upsert_call_args[1]['collection_name'] == 'd_test_user_test_collection_3'
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
|
|
@ -272,18 +272,21 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
# Should not call upsert for empty chunks
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
# But collection_exists should be called for validation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
# collection_exists should NOT be called since we return early for empty chunks
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that writing to non-existent collection raises ValueError"""
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test that writing to non-existent collection creates it lazily"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
|
|
@ -305,19 +308,36 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert - collection should be lazily created
|
||||
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call[1]['collection_name'] == expected_collection
|
||||
assert create_call[1]['vectors_config'].size == 5
|
||||
|
||||
# Verify upsert was still called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that validation error occurs before connection errors"""
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test that collection creation errors are propagated"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
# Simulate creation failure
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Connection error")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
|
|
@ -339,8 +359,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
# Act & Assert - should propagate the creation error
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
|
|
@ -398,7 +418,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
await processor.store_document_embeddings(mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection'
|
||||
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
|
||||
|
||||
# Verify collection existence is checked on each write
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
|
@ -407,15 +427,18 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client):
|
||||
async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test that different vector dimensions create different collections"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -424,35 +447,39 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2], # 2 dimensions
|
||||
[0.3, 0.4, 0.5] # 3 dimensions
|
||||
]
|
||||
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should check existence of the same collection (dimensions no longer create separate collections)
|
||||
expected_collection = 'd_dim_user_dim_collection'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
# Should check existence of DIFFERENT collections for each dimension
|
||||
assert mock_qdrant_instance.collection_exists.call_count == 2
|
||||
|
||||
# Should upsert to the same collection for both vectors
|
||||
# Verify the two different collection names were checked
|
||||
collection_exists_calls = [call[0][0] for call in mock_qdrant_instance.collection_exists.call_args_list]
|
||||
assert 'd_dim_user_dim_collection_2' in collection_exists_calls # 2-dim vector
|
||||
assert 'd_dim_user_dim_collection_3' in collection_exists_calls # 3-dim vector
|
||||
|
||||
# Should upsert to different collections for each vector
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
assert upsert_calls[0][1]['collection_name'] == expected_collection
|
||||
assert upsert_calls[1][1]['collection_name'] == expected_collection
|
||||
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
|
|
|
|||
|
|
@ -134,8 +134,8 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index name and operations
|
||||
expected_index_name = "t-test_user-test_collection"
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
|
||||
# Verify upsert was called for each vector
|
||||
|
|
@ -179,7 +179,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_validation(self, processor):
|
||||
"""Test that writing to non-existent index raises ValueError"""
|
||||
"""Test that writing to non-existent index creates it lazily"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
|
|
@ -193,10 +193,22 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index was created with correct dimension
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
|
||||
# Verify upsert was still called
|
||||
mock_index.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, processor):
|
||||
"""Test storing graph embeddings with empty entity value (should be skipped)"""
|
||||
|
|
@ -267,11 +279,16 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify same index was used for all dimensions
|
||||
expected_index_name = 't-test_user-test_collection'
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
# Verify different indexes were used for different dimensions
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
assert len(index_calls) == 3
|
||||
# Extract index names from calls
|
||||
index_names = [call[0][0] for call in index_calls]
|
||||
assert 't-test_user-test_collection-2' in index_names # 2D vector
|
||||
assert 't-test_user-test_collection-4' in index_names # 4D vector
|
||||
assert 't-test_user-test_collection-3' in index_names # 3D vector
|
||||
|
||||
# Verify all vectors were upserted to the same index
|
||||
# Verify all vectors were upserted (to their respective indexes)
|
||||
assert mock_index.upsert.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -316,7 +333,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_validation_before_creation(self, processor):
|
||||
"""Test that validation error occurs before any creation attempts"""
|
||||
"""Test that lazy creation happens when index doesn't exist"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
|
|
@ -330,13 +347,18 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index was created
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_validates_before_timeout(self, processor):
|
||||
"""Test that validation error occurs before timeout checks"""
|
||||
"""Test that lazy creation works correctly"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
|
|
@ -350,10 +372,16 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index was created and used
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
mock_index.upsert.assert_called_once()
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
|
|
|||
|
|
@ -44,29 +44,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection validates that collection exists"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
processor.get_collection(user='test_user', collection='test_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
|
|
@ -103,114 +80,22 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 't_test_user_test_collection'
|
||||
# Verify collection existence was checked (with dimension suffix)
|
||||
expected_collection = 't_test_user_test_collection_3' # 3 dimensions in vector [0.1, 0.2, 0.3]
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert upsert_call_args[1]['collection_name'] == 't_test_user_test_collection_3'
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['entity'] == 'test_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection uses existing collection without creating new one"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(user='existing_user', collection='existing_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_existing_user_existing_collection'
|
||||
assert collection_name == expected_name
|
||||
|
||||
# Verify collection existence check was performed
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
# Verify create_collection was NOT called
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection validates collection existence on each call"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# First call
|
||||
collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection')
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
|
||||
# Act - Second call with same parameters
|
||||
collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_cache_user_cache_collection'
|
||||
assert collection_name1 == expected_name
|
||||
assert collection_name2 == expected_name
|
||||
|
||||
# Verify collection existence check happens on each call
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection raises ValueError when collection doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
processor.get_collection(user='error_user', collection='error_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue