mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Embeddings API scores (#671)
- Put scores in all responses - Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -23,11 +23,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
# Create test document embeddings
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk_id="This is the first document chunk",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk_id="This is the second document chunk",
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
vector=[0.7, 0.8, 0.9]
|
||||
)
|
||||
message.chunks = [chunk1, chunk2]
|
||||
|
||||
|
|
@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify insert was called for each vector with user/collection parameters
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'),
|
||||
([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_doc
|
||||
assert actual_call[0][2] == expected_user
|
||||
assert actual_call[0][3] == expected_collection
|
||||
|
||||
# Verify insert was called once for the single chunk with its vector
|
||||
processor.vecstore.insert.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||
"""Test storing document embeddings for multiple chunks"""
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Verify insert was called for each vector of each chunk with user/collection parameters
|
||||
|
||||
# Verify insert was called once per chunk with user/collection parameters
|
||||
expected_calls = [
|
||||
# Chunk 1 vectors
|
||||
([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'),
|
||||
([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
|
||||
# Chunk 2 vectors
|
||||
# Chunk 1 - single vector
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
|
||||
# Chunk 2 - single vector
|
||||
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
|
|
@ -137,7 +127,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -156,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id=None,
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -177,15 +167,15 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
valid_chunk = ChunkEmbeddings(
|
||||
chunk_id="Valid document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
empty_chunk = ChunkEmbeddings(
|
||||
chunk_id="",
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
vector=[0.4, 0.5, 0.6]
|
||||
)
|
||||
another_valid = ChunkEmbeddings(
|
||||
chunk_id="Another valid chunk",
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
vector=[0.7, 0.8, 0.9]
|
||||
)
|
||||
message.chunks = [valid_chunk, empty_chunk, another_valid]
|
||||
|
||||
|
|
@ -229,7 +219,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Document with no vectors",
|
||||
vectors=[]
|
||||
vector=[]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -245,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Document with mixed dimensions",
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
|
||||
# Each chunk has a single vector of different dimensions
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk_id="chunk/doc/2d",
|
||||
vector=[0.1, 0.2] # 2D vector
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk_id="chunk/doc/4d",
|
||||
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
)
|
||||
chunk3 = ChunkEmbeddings(
|
||||
chunk_id="chunk/doc/3d",
|
||||
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||
)
|
||||
message.chunks = [chunk1, chunk2, chunk3]
|
||||
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension with user/collection parameters
|
||||
expected_calls = [
|
||||
([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'),
|
||||
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'),
|
||||
([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'),
|
||||
([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'),
|
||||
([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'),
|
||||
([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'),
|
||||
]
|
||||
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
|
|
@ -283,7 +278,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="chunk/doc/unicode-éñ中文🚀",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -306,7 +301,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
long_chunk_id = "chunk/doc/" + "a" * 200
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id=long_chunk_id,
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -327,7 +322,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id=" \n\t ",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -358,7 +353,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Test content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -379,7 +374,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
message1.metadata.collection = 'collection1'
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk_id="User1 content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message1.chunks = [chunk1]
|
||||
|
||||
|
|
@ -390,7 +385,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
message2.metadata.collection = 'collection2'
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk_id="User2 content",
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
vector=[0.4, 0.5, 0.6]
|
||||
)
|
||||
message2.chunks = [chunk2]
|
||||
|
||||
|
|
@ -421,7 +416,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk_id="Special chars test",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
|
|||
|
|
@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
# Create test document embeddings
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk=b"This is the first document chunk",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk=b"This is the second document chunk",
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
vector=[0.7, 0.8, 0.9]
|
||||
)
|
||||
message.chunks = [chunk1, chunk2]
|
||||
|
||||
|
|
@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=None,
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"", # Empty bytes
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with mixed dimensions",
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
# Each chunk has a single vector of different dimensions
|
||||
chunk1 = ChunkEmbeddings(
|
||||
chunk=b"Document chunk 1",
|
||||
vector=[0.1, 0.2] # 2D vector
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
mock_index_2d = MagicMock()
|
||||
mock_index_4d = MagicMock()
|
||||
mock_index_3d = MagicMock()
|
||||
|
||||
chunk2 = ChunkEmbeddings(
|
||||
chunk=b"Document chunk 2",
|
||||
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
)
|
||||
chunk3 = ChunkEmbeddings(
|
||||
chunk=b"Document chunk 3",
|
||||
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||
)
|
||||
message.chunks = [chunk1, chunk2, chunk3]
|
||||
|
||||
mock_index = MagicMock()
|
||||
|
||||
def mock_index_side_effect(name):
|
||||
# All dimensions now use the same index name pattern
|
||||
# Different dimensions will be handled within the same index
|
||||
if "test_user" in name and "test_collection" in name:
|
||||
return mock_index_2d # Just return one mock for all
|
||||
return mock_index
|
||||
return MagicMock()
|
||||
|
||||
|
||||
processor.pinecone.Index.side_effect = mock_index_side_effect
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify all vectors are now stored in the same index
|
||||
# (Pinecone can handle mixed dimensions in the same index)
|
||||
assert processor.pinecone.Index.call_count == 3 # Called once per vector
|
||||
mock_index_2d.upsert.call_count == 3 # All upserts go to same index
|
||||
# (Each chunk has a single vector, called once per chunk)
|
||||
assert processor.pinecone.Index.call_count == 3 # Called once per chunk
|
||||
assert mock_index.upsert.call_count == 3 # All upserts go to same index
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||
|
|
@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Document with no vectors",
|
||||
vectors=[]
|
||||
vector=[]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
large_content = "A" * 10000 # 10KB of content
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=large_content.encode('utf-8'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
|
@ -143,11 +143,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk_id = 'doc/c1'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
mock_chunk1.vector = [0.1, 0.2]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk_id = 'doc/c2'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
mock_chunk2.vector = [0.3, 0.4]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
|
|
@ -175,8 +175,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple vectors per chunk"""
|
||||
async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple chunks"""
|
||||
# Arrange
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
|
|
@ -196,41 +196,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Add collection to known_collections (simulates config push)
|
||||
processor.known_collections[('vector_user', 'vector_collection')] = {}
|
||||
|
||||
# Create mock message with chunk having multiple vectors
|
||||
# Create mock message with multiple chunks, each having a single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk_id = 'doc/multi-vector'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk_id = 'doc/c1'
|
||||
mock_chunk1.vector = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk_id = 'doc/c2'
|
||||
mock_chunk2.vector = [0.4, 0.5, 0.6]
|
||||
|
||||
mock_chunk3 = MagicMock()
|
||||
mock_chunk3.chunk_id = 'doc/c3'
|
||||
mock_chunk3.vector = [0.7, 0.8, 0.9]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
# Should be called 3 times (once per chunk)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
expected_data = [
|
||||
([0.1, 0.2, 0.3], 'doc/c1'),
|
||||
([0.4, 0.5, 0.6], 'doc/c2'),
|
||||
([0.7, 0.8, 0.9], 'doc/c3')
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['chunk_id'] == 'doc/multi-vector'
|
||||
assert point.vector == expected_data[i][0]
|
||||
assert point.payload['chunk_id'] == expected_data[i][1]
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
|
||||
|
|
@ -256,7 +260,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk_empty = MagicMock()
|
||||
mock_chunk_empty.chunk_id = "" # Empty chunk_id
|
||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||
mock_chunk_empty.vector = [0.1, 0.2]
|
||||
|
||||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
|
|
@ -299,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk_id = 'doc/test-chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||
mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
|
@ -351,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk_id = 'doc/test-chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
mock_chunk.vector = [0.1, 0.2]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
|
@ -389,7 +393,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk_id = 'doc/c1'
|
||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_chunk1.vector = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
|
|
@ -407,7 +411,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk_id = 'doc/c2'
|
||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||
mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3)
|
||||
|
||||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
|
|
@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Add collection to known_collections (simulates config push)
|
||||
processor.known_collections[('dim_user', 'dim_collection')] = {}
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
# Create mock message with chunks of different dimensions
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk_id = 'doc/dim-test'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2], # 2 dimensions
|
||||
[0.3, 0.4, 0.5] # 3 dimensions
|
||||
]
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk_id = 'doc/c1'
|
||||
mock_chunk1.vector = [0.1, 0.2] # 2 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk_id = 'doc/c2'
|
||||
mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
|
@ -526,7 +531,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
mock_chunk.vector = [0.1, 0.2]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
# Create test entities with embeddings
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Term(type=LITERAL, value='literal entity'),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
vector=[0.7, 0.8, 0.9]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
|
||||
|
|
@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify insert was called for each vector with user/collection parameters
|
||||
expected_calls = [
|
||||
([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'),
|
||||
([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
assert actual_call[0][1] == expected_entity
|
||||
assert actual_call[0][2] == expected_user
|
||||
assert actual_call[0][3] == expected_collection
|
||||
|
||||
# Verify insert was called once with the full vector
|
||||
processor.vecstore.insert.assert_called_once()
|
||||
actual_call = processor.vecstore.insert.call_args_list[0]
|
||||
assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
assert actual_call[0][1] == 'http://example.com/entity'
|
||||
assert actual_call[0][2] == 'test_user'
|
||||
assert actual_call[0][3] == 'test_collection'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||
"""Test storing graph embeddings for multiple entities"""
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Verify insert was called for each vector of each entity with user/collection parameters
|
||||
|
||||
# Verify insert was called once per entity with user/collection parameters
|
||||
expected_calls = [
|
||||
# Entity 1 vectors
|
||||
([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
||||
([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
||||
# Entity 2 vectors
|
||||
# Entity 1 - single vector
|
||||
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
||||
# Entity 2 - single vector
|
||||
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
|
||||
]
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
|
||||
assert processor.vecstore.insert.call_count == 2
|
||||
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
assert actual_call[0][0] == expected_vec
|
||||
|
|
@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Term(type=LITERAL, value=''),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Term(type=LITERAL, value=None),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -175,17 +168,17 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
valid_entity = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/valid'),
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
chunk_id=''
|
||||
)
|
||||
empty_entity = EntityEmbeddings(
|
||||
entity=Term(type=LITERAL, value=''),
|
||||
vectors=[[0.4, 0.5, 0.6]],
|
||||
vector=[0.4, 0.5, 0.6],
|
||||
chunk_id=''
|
||||
)
|
||||
none_entity = EntityEmbeddings(
|
||||
entity=Term(type=LITERAL, value=None),
|
||||
vectors=[[0.7, 0.8, 0.9]],
|
||||
vector=[0.7, 0.8, 0.9],
|
||||
chunk_id=''
|
||||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
|
@ -222,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[]
|
||||
vector=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -238,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
|
||||
# Each entity has a single vector of different dimensions
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
||||
vector=[0.1, 0.2] # 2D vector
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity2'),
|
||||
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
)
|
||||
entity3 = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/entity3'),
|
||||
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||
)
|
||||
message.entities = [entity1, entity2, entity3]
|
||||
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
|
||||
# Verify all vectors were inserted regardless of dimension
|
||||
expected_calls = [
|
||||
([0.1, 0.2], 'http://example.com/entity'),
|
||||
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
|
||||
([0.7, 0.8, 0.9], 'http://example.com/entity'),
|
||||
([0.1, 0.2], 'http://example.com/entity1'),
|
||||
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'),
|
||||
([0.7, 0.8, 0.9], 'http://example.com/entity3'),
|
||||
]
|
||||
|
||||
|
||||
assert processor.vecstore.insert.call_count == 3
|
||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||
|
|
@ -274,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
uri_entity = EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
literal_entity = EntityEmbeddings(
|
||||
entity=Term(type=LITERAL, value='literal entity text'),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
vector=[0.4, 0.5, 0.6]
|
||||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
||||
|
|
|
|||
|
|
@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
# Create test entity embeddings
|
||||
# Create test entity embeddings (each entity has a single vector)
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value="http://example.org/entity1", is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value="entity2", is_uri=False),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
entity=Value(value="http://example.org/entity2", is_uri=True),
|
||||
vector=[0.4, 0.5, 0.6]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
entity3 = EntityEmbeddings(
|
||||
entity=Value(value="entity3", is_uri=False),
|
||||
vector=[0.7, 0.8, 0.9]
|
||||
)
|
||||
message.entities = [entity1, entity2, entity3]
|
||||
|
||||
return message
|
||||
|
||||
|
|
@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="http://example.org/entity1", is_uri=True),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
||||
# Mock index operations
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
||||
|
||||
with patch('uuid.uuid4', side_effect=['id1']):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
|
||||
# Verify index name and operations (with dimension suffix)
|
||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||
|
||||
# Verify upsert was called for each vector
|
||||
assert mock_index.upsert.call_count == 2
|
||||
|
||||
# Verify upsert was called for the single vector
|
||||
assert mock_index.upsert.call_count == 1
|
||||
|
||||
# Check first vector upsert
|
||||
first_call = mock_index.upsert.call_args_list[0]
|
||||
|
|
@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test storing graph embeddings with different vector dimensions to same index"""
|
||||
"""Test storing graph embeddings with different vector dimensions"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
]
|
||||
# Each entity has a single vector of different dimensions
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value="entity1", is_uri=False),
|
||||
vector=[0.1, 0.2] # 2D vector
|
||||
)
|
||||
message.entities = [entity]
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value="entity2", is_uri=False),
|
||||
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
)
|
||||
entity3 = EntityEmbeddings(
|
||||
entity=Value(value="entity3", is_uri=False),
|
||||
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||
)
|
||||
message.entities = [entity1, entity2, entity3]
|
||||
|
||||
# All vectors now use the same index (no dimension in name)
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
processor.pinecone.has_index.return_value = True
|
||||
|
|
@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[]
|
||||
vector=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_entity = MagicMock()
|
||||
mock_entity.entity.type = IRI
|
||||
mock_entity.entity.iri = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
|
|
@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.type = IRI
|
||||
mock_entity1.entity.iri = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
mock_entity1.vector = [0.1, 0.2]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.type = IRI
|
||||
mock_entity2.entity.iri = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
mock_entity2.vector = [0.3, 0.4]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
||||
|
|
@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple vectors per entity"""
|
||||
async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with three entities"""
|
||||
# Arrange
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Add collection to known_collections (simulates config push)
|
||||
processor.known_collections[('vector_user', 'vector_collection')] = {}
|
||||
|
||||
# Create mock message with entity having multiple vectors
|
||||
# Create mock message with three entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.type = IRI
|
||||
mock_entity.entity.iri = 'multi_vector_entity'
|
||||
mock_entity.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.type = IRI
|
||||
mock_entity1.entity.iri = 'entity_one'
|
||||
mock_entity1.vector = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.type = IRI
|
||||
mock_entity2.entity.iri = 'entity_two'
|
||||
mock_entity2.vector = [0.4, 0.5, 0.6]
|
||||
|
||||
mock_entity3 = MagicMock()
|
||||
mock_entity3.entity.type = IRI
|
||||
mock_entity3.entity.iri = 'entity_three'
|
||||
mock_entity3.vector = [0.7, 0.8, 0.9]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
# Should be called 3 times (once per entity)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
|
||||
# Verify all entities were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
|
||||
expected_data = [
|
||||
([0.1, 0.2, 0.3], 'entity_one'),
|
||||
([0.4, 0.5, 0.6], 'entity_two'),
|
||||
([0.7, 0.8, 0.9], 'entity_three')
|
||||
]
|
||||
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['entity'] == 'multi_vector_entity'
|
||||
assert point.vector == expected_data[i][0]
|
||||
assert point.payload['entity'] == expected_data[i][1]
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
|
||||
|
|
@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_entity_empty = MagicMock()
|
||||
mock_entity_empty.entity.type = LITERAL
|
||||
mock_entity_empty.entity.value = "" # Empty string
|
||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
||||
mock_entity_empty.vector = [0.1, 0.2]
|
||||
|
||||
mock_entity_none = MagicMock()
|
||||
mock_entity_none.entity = None # None entity
|
||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
||||
mock_entity_none.vector = [0.3, 0.4]
|
||||
|
||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
index_name='customer_id',
|
||||
index_value=['CUST001'],
|
||||
text='CUST001',
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
|
|
@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
|
||||
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test processing embeddings with multiple vectors"""
|
||||
async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client):
|
||||
"""Test processing embeddings with a single vector"""
|
||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||
|
||||
|
|
@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
# Embedding with multiple vectors
|
||||
# Embedding with a single 6D vector
|
||||
embedding = RowIndexEmbedding(
|
||||
index_name='name',
|
||||
index_value=['John Doe'],
|
||||
text='John Doe',
|
||||
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
|
|
@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
# Should be called once for the single embedding
|
||||
assert mock_qdrant_instance.upsert.call_count == 1
|
||||
|
||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
|
||||
|
|
@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
index_name='id',
|
||||
index_value=['123'],
|
||||
text='123',
|
||||
vectors=[] # Empty vectors
|
||||
vector=[] # Empty vector
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
|
|
@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
index_name='id',
|
||||
index_value=['123'],
|
||||
text='123',
|
||||
vectors=[[0.1, 0.2]]
|
||||
vector=[0.1, 0.2]
|
||||
)
|
||||
|
||||
embeddings_msg = RowEmbeddings(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue