Embeddings API scores (#671)

- Put scores in all responses
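- Remove unused 'middle' vector layer. A list of texts now maps to a list of embedding vectors (one vector per text), and query requests take a single `vector` instead of a list of `vectors`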
This commit is contained in:
cybermaggedon 2026-03-09 10:53:44 +00:00 committed by GitHub
parent 4fa7cc7d7c
commit f2ae0e8623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 1339 additions and 1292 deletions

View file

@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
# Mock the request method
client.request = AsyncMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = await client.query(
vectors=vectors,
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vectors == vectors
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(RuntimeError, match="Database connection failed"):
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = []
mock_response.chunks = []
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["test_chunk"]
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
client.request.assert_called_once()
@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["chunk1"]
mock_response.chunks = ["chunk1"]
client.request = AsyncMock(return_value=mock_response)
# Act
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=60
)
# Assert
assert client.request.call_args[1]["timeout"] == 60
@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["test_chunk"]
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
mock_logger.debug.assert_called_once()
assert "Document embeddings response" in str(mock_logger.debug.call_args)

View file

@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = client.request(
vectors=vectors,
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)
@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["test_chunk"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3]]
vector = [0.1, 0.2, 0.3]
# Act
result = client.request(vectors=vectors)
result = client.request(vector=vector)
# Assert
assert result == ["test_chunk"]
client.call.assert_called_once_with(
user="trustgraph",
collection="default",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)
@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = []
client.call = MagicMock(return_value=mock_response)
# Act
result = client.request(vectors=[[0.1, 0.2, 0.3]])
result = client.request(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = None
client.call = MagicMock(return_value=mock_response)
# Act
result = client.request(vectors=[[0.1, 0.2, 0.3]])
result = client.request(vector=[0.1, 0.2, 0.3])
# Assert
assert result is None
@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1"]
client.call = MagicMock(return_value=mock_response)
# Act
client.request(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=600
)
# Assert
assert client.call.call_args[1]["timeout"] == 600

View file

@ -98,7 +98,7 @@ def sample_graph_embeddings():
entities=[
EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/john"),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
]
)

View file

@ -108,7 +108,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
# Assert
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -60,7 +60,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
model="test-model",
input=["test text"]
)
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -86,7 +86,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
model="custom-model",
input=["test text"]
)
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
class TestMilvusDocEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
)
return query
@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
)
# Verify results are document chunks
# Verify results are ChunkMatch objects
assert len(result) == 3
assert result[0] == "First document chunk"
assert result[1] == "Second document chunk"
assert result[2] == "Third document chunk"
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == "First document chunk"
assert result[1].chunk_id == "Second document chunk"
assert result[2].chunk_id == "Third document chunk"
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor):
"""Test querying document embeddings with multiple vectors"""
async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
)
# Mock search results - different results for each vector
mock_results_1 = [
{"entity": {"chunk_id": "Document from first vector"}},
{"entity": {"chunk_id": "Another doc from first vector"}},
# Mock search results
mock_results = [
{"entity": {"chunk_id": "First document"}},
{"entity": {"chunk_id": "Second document"}},
{"entity": {"chunk_id": "Third document"}},
]
mock_results_2 = [
{"entity": {"chunk_id": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results from all vectors are combined
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
)
# Verify results are ChunkMatch objects
assert len(result) == 3
assert "Document from first vector" in result
assert "Another doc from first vector" in result
assert "Document from second vector" in result
chunk_ids = [r.chunk_id for r in result]
assert "First document" in chunk_ids
assert "Second document" in chunk_ids
assert "Third document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor):
@ -141,7 +135,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=2
)
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
vector=[],
limit=5
)
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -211,7 +205,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -225,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify Unicode content is preserved
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with Unicode: éñ中文🚀" in result
assert "Regular ASCII document" in result
assert "Document with émojis: 😀🎉" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with Unicode: éñ中文🚀" in chunk_ids
assert "Regular ASCII document" in chunk_ids
assert "Document with émojis: 😀🎉" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_large_documents(self, processor):
@ -237,7 +232,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -251,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify large content is preserved
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
assert large_doc in result
assert "Small document" in result
chunk_ids = [r.chunk_id for r in result]
assert large_doc in chunk_ids
assert "Small document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_special_characters(self, processor):
@ -262,7 +258,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -276,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify special characters are preserved
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with \"quotes\" and 'apostrophes'" in result
assert "Document with\nnewlines\tand\ttabs" in result
assert "Document with special chars: @#$%^&*()" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
assert "Document with special chars: @#$%^&*()" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=0
)
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=-1
)
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
# Mock search results
mock_results = [
{"entity": {"chunk_id": "Document 1"}},
{"entity": {"chunk_id": "Document 2"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
assert "Document from 2D vector" in result
assert "Document from 4D vector" in result
assert "Document from 3D vector" in result
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
# Verify results are ChunkMatch objects
assert len(result) == 2
chunk_ids = [r.chunk_id for r in result]
assert "Document 1" in chunk_ids
assert "Document 2" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_duplicate_documents(self, processor):
"""Test querying document embeddings with duplicate documents in results"""
async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results with duplicates across vectors
mock_results_1 = [
# Mock search results with multiple documents
mock_results = [
{"entity": {"chunk_id": "Document A"}},
{"entity": {"chunk_id": "Document B"}},
]
mock_results_2 = [
{"entity": {"chunk_id": "Document B"}}, # Duplicate
{"entity": {"chunk_id": "Document C"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
# This preserves ranking and allows multiple occurrences
assert len(result) == 4
assert result.count("Document B") == 2 # Should appear twice
assert "Document A" in result
assert "Document C" in result
# Verify results are ChunkMatch objects
assert len(result) == 3
chunk_ids = [r.chunk_id for r in result]
assert "Document A" in chunk_ids
assert "Document B" in chunk_ids
assert "Document C" in chunk_ids
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -103,7 +103,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -208,7 +208,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
@ -226,7 +226,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
@ -285,7 +285,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.vector = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -304,7 +304,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -325,7 +325,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_content(self, processor):
"""Test querying document embeddings with Unicode content results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -351,7 +351,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_content(self, processor):
"""Test querying document embeddings with large content results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 1
message.user = 'test_user'
message.collection = 'test_collection'
@ -377,7 +377,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_mixed_content_types(self, processor):
"""Test querying document embeddings with mixed content types"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -409,7 +409,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -425,7 +425,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_index_access_failure(self, processor):
"""Test handling of index access failure"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.doc_embeddings.qdrant.service import Processor
from trustgraph.schema import ChunkMatch
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
# Verify result contains expected documents
# Verify result contains expected ChunkMatch objects
assert len(result) == 2
# Results should be strings (document chunks)
assert isinstance(result[0], str)
assert isinstance(result[1], str)
# Results should be ChunkMatch objects
assert isinstance(result[0], ChunkMatch)
assert isinstance(result[1], ChunkMatch)
# Verify content
assert result[0] == 'first document chunk'
assert result[1] == 'second document chunk'
assert result[0].chunk_id == 'first document chunk'
assert result[1].chunk_id == 'second document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with multiple vectors"""
async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings returns multiple results"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses for different vectors
# Mock query response with multiple results
mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from vector 1'}
mock_point1.payload = {'chunk_id': 'document chunk 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from vector 2'}
mock_point2.payload = {'chunk_id': 'document chunk 2'}
mock_point3 = MagicMock()
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2, mock_point3]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
mock_point3.payload = {'chunk_id': 'document chunk 3'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2, mock_point3]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with multiple vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors)
# Verify collection was queried correctly
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify results from both vectors are combined
# Verify results are ChunkMatch objects
assert len(result) == 3
assert 'document from vector 1' in result
assert 'document from vector 2' in result
assert 'another document from vector 2' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document chunk 1' in chunk_ids
assert 'document chunk 2' in chunk_ids
assert 'document chunk 3' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with different vector dimensions"""
"""Test querying document embeddings with a higher dimension vector"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses
# Mock query response
mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
mock_point1.payload = {'chunk_id': 'document from 5D vector'}
mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
mock_point2.payload = {'chunk_id': 'another 5D document'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with different dimension vectors
# Create mock message with 5D vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once with correct collection
assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Call should use 5D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
# Verify results are ChunkMatch objects
assert len(result) == 2
assert 'document from 2D vector' in result
assert 'document from 3D vector' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document from 5D vector' in chunk_ids
assert 'another 5D document' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'utf8_user'
mock_message.collection = 'utf8_collection'
@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 2
# Verify UTF-8 content works correctly
assert 'Document with UTF-8: café, naïve, résumé' in result
assert 'Chinese text: 你好世界' in result
# Verify UTF-8 content works correctly in ChunkMatch objects
chunk_ids = [r.chunk_id for r in result]
assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
assert 'Chinese text: 你好世界' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0
# Result should contain all returned documents
# Result should contain all returned documents as ChunkMatch objects
assert len(result) == 1
assert result[0] == 'document chunk'
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == 'document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with large limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 1000 # Large limit
mock_message.user = 'large_user'
mock_message.collection = 'large_collection'
@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 1000
# Result should contain all available documents
# Result should contain all available documents as ChunkMatch objects
assert len(result) == 2
assert 'document 1' in result
assert 'document 2' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document 1' in chunk_ids
assert 'document 2' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -508,7 +504,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'payload_user'
mock_message.collection = 'payload_collection'

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
class TestMilvusGraphEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
)
return query
@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify results are converted to Term objects
# Verify results are converted to EntityMatch objects
assert len(result) == 3
assert isinstance(result[0], Term)
assert result[0].iri == "http://example.com/entity1"
assert result[0].type == IRI
assert isinstance(result[1], Term)
assert result[1].iri == "http://example.com/entity2"
assert result[1].type == IRI
assert isinstance(result[2], Term)
assert result[2].value == "literal entity"
assert result[2].type == LITERAL
assert isinstance(result[0], EntityMatch)
assert result[0].entity.iri == "http://example.com/entity1"
assert result[0].entity.type == IRI
assert isinstance(result[1], EntityMatch)
assert result[1].entity.iri == "http://example.com/entity2"
assert result[1].entity.type == IRI
assert isinstance(result[2], EntityMatch)
assert result[2].entity.value == "literal entity"
assert result[2].entity.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
"""Test querying graph embeddings with multiple vectors"""
async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=3
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results - different results for each vector
mock_results_1 = [
# Mock search results with multiple entities
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results are deduplicated and limited
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
)
# Verify results are EntityMatch objects
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=2
)
@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication(self, processor):
"""Test that duplicate entities are properly deduplicated"""
async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results with duplicates
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}}, # New
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_graph_embeddings(query)
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=2
)
# Mock search results - first vector returns enough results
mock_results_1 = [
# Mock search results in specific order
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results_1
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify only first vector was searched (limit reached)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
# Verify results are in the same order as returned by the store
assert len(result) == 3
assert result[0].entity.iri == "http://example.com/entity1"
assert result[1].entity.iri == "http://example.com/entity2"
assert result[2].entity.iri == "http://example.com/entity3"
@pytest.mark.asyncio
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
)
# Verify results are limited
# Mock search results - returns more results than limit
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
)
# Verify results are limited to requested amount
assert len(result) == 2
@pytest.mark.asyncio
@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
vector=[],
limit=5
)
@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify all results are properly typed
assert len(result) == 4
# Check URI entities
uri_results = [r for r in result if r.type == IRI]
uri_results = [r for r in result if r.entity.type == IRI]
assert len(uri_results) == 2
uri_values = [r.iri for r in uri_results]
uri_values = [r.entity.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
literal_results = [r for r in result if not r.type == IRI]
literal_results = [r for r in result if not r.entity.type == IRI]
assert len(literal_results) == 2
literal_values = [r.value for r in literal_results]
literal_values = [r.entity.value for r in literal_results]
assert "literal entity text" in literal_values
assert "another literal" in literal_values
@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=0
)
@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying graph embeddings with different vector dimensions"""
async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
# Mock search results
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "entity_2d" in entity_values
assert "entity_4d" in entity_values
assert "entity_3d" in entity_values
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once()
# Verify results
assert len(result) == 2
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values

View file

@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
class TestPineconeGraphEmbeddingsQueryProcessor:
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
def mock_query_message(self):
"""Create a mock query message for testing"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
include_metadata=True
)
# Verify results
# Verify results use EntityMatch structure
assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1'
assert entities[0].type == IRI
assert entities[1].value == 'entity2'
assert entities[1].type == LITERAL
assert entities[2].value == 'http://example.org/entity3'
assert entities[2].type == IRI
assert entities[0].entity.iri == 'http://example.org/entity1'
assert entities[0].entity.type == IRI
assert entities[1].entity.value == 'entity2'
assert entities[1].entity.type == LITERAL
assert entities[2].entity.iri == 'http://example.org/entity3'
assert entities[2].entity.type == IRI
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
"""Test querying graph embeddings with multiple vectors"""
async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
"""Test basic graph embeddings query"""
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query results
mock_results1 = MagicMock()
mock_results1.matches = [
# Query results with distinct entities
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'})
]
# Second query results
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify deduplication occurred
entity_values = [e.value for e in entities]
# Verify query was made once
assert mock_index.query.call_count == 1
# Verify results with EntityMatch structure
entity_values = [e.entity.value for e in entities]
assert len(entity_values) == 3
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions using same index"""
async def test_query_graph_embeddings_2d_vector(self, processor):
"""Test querying with a 2D vector"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.vector = [0.1, 0.2] # 2D vector
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
# Mock single index that handles all dimensions
# Mock index
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
# Mock results for 2D vector query
mock_results = MagicMock()
mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
index_calls = processor.pinecone.Index.call_args_list
index_names = [call[0][0] for call in index_calls]
assert "t-test_user-test_collection-2" in index_names # 2D vector
assert "t-test_user-test_collection-4" in index_names # 4D vector
# Verify correct index used for 2D vector
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify query was made
assert mock_index.query.call_count == 1
# Verify results from both dimensions
entity_values = [e.value for e in entities]
# Verify results with EntityMatch structure
entity_values = [e.entity.value for e in entities]
assert 'entity_2d' in entity_values
assert 'entity_4d' in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.vector = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
"""Test that deduplication works correctly across multiple vector queries"""
async def test_query_graph_embeddings_deduplication_in_results(self, processor):
"""Test that deduplication works correctly within query results"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Both queries return overlapping results
mock_results1 = MagicMock()
mock_results1.matches = [
# Query returns results with some duplicates
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}),
MagicMock(metadata={'entity': 'entity4'})
]
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
MagicMock(metadata={'entity': 'entity5'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
entity_values = [e.value for e in entities]
entity_values = [e.entity.value for e in entities]
assert len(set(entity_values)) == 3 # All unique
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
async def test_query_graph_embeddings_respects_limit(self, processor):
"""Test that query respects limit parameter"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query returns enough results to meet limit
mock_results1 = MagicMock()
mock_results1.matches = [
# Query returns more results than limit
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'})
]
mock_index.query.return_value = mock_results1
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Should only make one query since limit was reached
# Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once()
assert len(entities) == 2
@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'

View file

@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor
from trustgraph.schema import IRI, LITERAL
from trustgraph.schema import IRI, LITERAL, EntityMatch
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
# Verify result contains expected entities
# Verify result contains expected EntityMatch objects
assert len(result) == 2
assert all(hasattr(entity, 'value') for entity in result)
entity_values = [entity.value for entity in result]
assert all(isinstance(entity, EntityMatch) for entity in result)
entity_values = [entity.entity.value for entity in result]
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with multiple vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors)
# Verify collection was queried
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify deduplication - entity2 appears in both results but should only appear once
entity_values = [entity.value for entity in result]
# Verify results with EntityMatch structure
entity_values = [entity.entity.value for entity in result]
assert len(set(entity_values)) == len(entity_values) # All unique
assert 'entity1' in entity_values
assert 'entity2' in entity_values
assert 'entity3' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with different dimension vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2] # 2D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
# Call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
entity_values = [entity.value for entity in result]
# Verify results with EntityMatch structure
entity_values = [entity.entity.value for entity in result]
assert 'entity2d' in entity_values
assert 'entity3d' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'uri_user'
mock_message.collection = 'uri_collection'
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 3
# Check URI entities
uri_entities = [entity for entity in result if entity.type == IRI]
uri_entities = [entity for entity in result if entity.entity.type == IRI]
assert len(uri_entities) == 2
uri_values = [entity.iri for entity in uri_entities]
uri_values = [entity.entity.iri for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities
regular_entities = [entity for entity in result if entity.type == LITERAL]
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity'
assert regular_entities[0].entity.value == 'regular entity'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# With zero limit, the logic still adds one entity before checking the limit
# So it returns one result (current behavior, not ideal but actual)
assert len(result) == 1
assert result[0].value == 'entity1'
assert result[0].entity.value == 'entity1'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')

View file

@ -175,9 +175,14 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock document embeddings returns chunk_ids
test_chunk_ids = ["doc/c1", "doc/c2"]
mock_doc_embeddings_client.query.return_value = test_chunk_ids
# Mock document embeddings returns ChunkMatch objects
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c1"
mock_match1.score = 0.95
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c2"
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
@ -195,9 +200,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify doc embeddings client was called correctly (with extracted vectors)
# Verify doc embeddings client was called correctly (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
vector=test_vectors,
limit=15,
user="test_user",
collection="test_collection"
@ -218,11 +223,16 @@ class TestQuery:
# Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
test_chunk_ids = ["doc/c3", "doc/c4"]
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c4"
mock_match2.score = 0.8
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = test_chunk_ids
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
@ -245,9 +255,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify doc embeddings client was called (with extracted vectors)
# Verify doc embeddings client was called (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
vector=test_vectors,
limit=10,
user="test_user",
collection="test_collection"
@ -275,7 +285,10 @@ class TestQuery:
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
@ -289,9 +302,9 @@ class TestQuery:
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used (vectors extracted from batch)
# Verify default parameters were used (vector extracted from batch)
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2]],
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
user="trustgraph", # Default user
collection="default" # Default collection
@ -316,7 +329,10 @@ class TestQuery:
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True
query = Query(
@ -347,7 +363,10 @@ class TestQuery:
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
@ -487,7 +506,13 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors]
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
@ -511,7 +536,7 @@ class TestQuery:
mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with(
query_vectors,
vector=query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"

View file

@ -193,12 +193,20 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock entity objects that have string representation
# Mock EntityMatch objects with entity that has string representation
mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
@ -216,9 +224,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify graph embeddings client was called correctly (with extracted vectors)
# Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with(
vectors=test_vectors,
vector=test_vectors,
limit=25,
user="test_user",
collection="test_collection"

View file

@ -23,11 +23,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk_id="This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
chunk2 = ChunkEmbeddings(
chunk_id="This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.chunks = [chunk1, chunk2]
@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk_id="Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called for each vector with user/collection parameters
expected_calls = [
([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
# Verify insert was called once for the single chunk with its vector
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'
)
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message)
# Verify insert was called for each vector of each chunk with user/collection parameters
# Verify insert was called once per chunk with user/collection parameters
expected_calls = [
# Chunk 1 vectors
([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
# Chunk 2 vectors
# Chunk 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
# Chunk 2 - single vector
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 3
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
@ -137,7 +127,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -156,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id=None,
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -177,15 +167,15 @@ class TestMilvusDocEmbeddingsStorageProcessor:
valid_chunk = ChunkEmbeddings(
chunk_id="Valid document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
empty_chunk = ChunkEmbeddings(
chunk_id="",
vectors=[[0.4, 0.5, 0.6]]
vector=[0.4, 0.5, 0.6]
)
another_valid = ChunkEmbeddings(
chunk_id="Another valid chunk",
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.chunks = [valid_chunk, empty_chunk, another_valid]
@ -229,7 +219,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="Document with no vectors",
vectors=[]
vector=[]
)
message.chunks = [chunk]
@ -245,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk_id="Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each chunk has a single vector of different dimensions
chunk1 = ChunkEmbeddings(
chunk_id="chunk/doc/2d",
vector=[0.1, 0.2] # 2D vector
)
message.chunks = [chunk]
chunk2 = ChunkEmbeddings(
chunk_id="chunk/doc/4d",
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
chunk3 = ChunkEmbeddings(
chunk_id="chunk/doc/3d",
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [
([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'),
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'),
([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'),
([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
@ -283,7 +278,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="chunk/doc/unicode-éñ中文🚀",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -306,7 +301,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
long_chunk_id = "chunk/doc/" + "a" * 200
chunk = ChunkEmbeddings(
chunk_id=long_chunk_id,
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -327,7 +322,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id=" \n\t ",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -358,7 +353,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="Test content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -379,7 +374,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings(
chunk_id="User1 content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message1.chunks = [chunk1]
@ -390,7 +385,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings(
chunk_id="User2 content",
vectors=[[0.4, 0.5, 0.6]]
vector=[0.4, 0.5, 0.6]
)
message2.chunks = [chunk2]
@ -421,7 +416,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="Special chars test",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]

View file

@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor:
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.chunks = [chunk1, chunk2]
@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
message.chunks = [chunk]
@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"", # Empty bytes
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor:
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each chunk has a single vector of different dimensions
chunk1 = ChunkEmbeddings(
chunk=b"Document chunk 1",
vector=[0.1, 0.2] # 2D vector
)
message.chunks = [chunk]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
chunk2 = ChunkEmbeddings(
chunk=b"Document chunk 2",
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
chunk3 = ChunkEmbeddings(
chunk=b"Document chunk 3",
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.chunks = [chunk1, chunk2, chunk3]
mock_index = MagicMock()
def mock_index_side_effect(name):
# All dimensions now use the same index name pattern
# Different dimensions will be handled within the same index
if "test_user" in name and "test_collection" in name:
return mock_index_2d # Just return one mock for all
return mock_index
return MagicMock()
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(message)
# Verify all vectors are now stored in the same index
# (Pinecone can handle mixed dimensions in the same index)
assert processor.pinecone.Index.call_count == 3 # Called once per vector
mock_index_2d.upsert.call_count == 3 # All upserts go to same index
# (Each chunk has a single vector, called once per chunk)
assert processor.pinecone.Index.call_count == 3 # Called once per chunk
assert mock_index.upsert.call_count == 3 # All upserts go to same index
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
vector=[]
)
message.chunks = [chunk]
@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]

View file

@ -89,7 +89,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
mock_message.chunks = [mock_chunk]
@ -143,11 +143,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2]]
mock_chunk1.vector = [0.1, 0.2]
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.3, 0.4]]
mock_chunk2.vector = [0.3, 0.4]
mock_message.chunks = [mock_chunk1, mock_chunk2]
@ -175,8 +175,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple vectors per chunk"""
async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple chunks"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
@ -196,41 +196,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push)
processor.known_collections[('vector_user', 'vector_collection')] = {}
# Create mock message with chunk having multiple vectors
# Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/multi-vector'
mock_chunk.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vector = [0.1, 0.2, 0.3]
mock_message.chunks = [mock_chunk]
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vector = [0.4, 0.5, 0.6]
mock_chunk3 = MagicMock()
mock_chunk3.chunk_id = 'doc/c3'
mock_chunk3.vector = [0.7, 0.8, 0.9]
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
# Should be called 3 times (once per chunk)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
expected_data = [
([0.1, 0.2, 0.3], 'doc/c1'),
([0.4, 0.5, 0.6], 'doc/c2'),
([0.7, 0.8, 0.9], 'doc/c3')
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['chunk_id'] == 'doc/multi-vector'
assert point.vector == expected_data[i][0]
assert point.payload['chunk_id'] == expected_data[i][1]
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
@ -256,7 +260,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk_id = "" # Empty chunk_id
mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_chunk_empty.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk_empty]
@ -299,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions
mock_message.chunks = [mock_chunk]
@ -351,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2]]
mock_chunk.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk]
@ -389,7 +393,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_chunk1.vector = [0.1, 0.2, 0.3]
mock_message1.chunks = [mock_chunk1]
@ -407,7 +411,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3)
mock_message2.chunks = [mock_chunk2]
@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push)
processor.known_collections[('dim_user', 'dim_collection')] = {}
# Create mock message with different dimension vectors
# Create mock message with chunks of different dimensions
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/dim-test'
mock_chunk.vectors = [
[0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions
]
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vector = [0.1, 0.2] # 2 dimensions
mock_message.chunks = [mock_chunk]
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
@ -526,7 +531,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
mock_chunk.vectors = [[0.1, 0.2]]
mock_chunk.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk]

View file

@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity1'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
entity2 = EntityEmbeddings(
entity=Term(type=LITERAL, value='literal entity'),
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.entities = [entity1, entity2]
@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify insert was called for each vector with user/collection parameters
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
# Verify insert was called once with the full vector
processor.vecstore.insert.assert_called_once()
actual_call = processor.vecstore.insert.call_args_list[0]
assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
assert actual_call[0][1] == 'http://example.com/entity'
assert actual_call[0][2] == 'test_user'
assert actual_call[0][3] == 'test_collection'
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message)
# Verify insert was called for each vector of each entity with user/collection parameters
# Verify insert was called once per entity with user/collection parameters
expected_calls = [
# Entity 1 vectors
([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
# Entity 2 vectors
# Entity 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
# Entity 2 - single vector
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 3
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=''),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=None),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -175,17 +168,17 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
valid_entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/valid'),
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
chunk_id=''
)
empty_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=''),
vectors=[[0.4, 0.5, 0.6]],
vector=[0.4, 0.5, 0.6],
chunk_id=''
)
none_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=None),
vectors=[[0.7, 0.8, 0.9]],
vector=[0.7, 0.8, 0.9],
chunk_id=''
)
message.entities = [valid_entity, empty_entity, none_entity]
@ -222,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[]
vector=[]
)
message.entities = [entity]
@ -238,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each entity has a single vector of different dimensions
entity1 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity1'),
vector=[0.1, 0.2] # 2D vector
)
message.entities = [entity]
entity2 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity2'),
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
entity3 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity3'),
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
([0.1, 0.2], 'http://example.com/entity'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
([0.7, 0.8, 0.9], 'http://example.com/entity'),
([0.1, 0.2], 'http://example.com/entity1'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'),
([0.7, 0.8, 0.9], 'http://example.com/entity3'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
@ -274,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
uri_entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
literal_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value='literal entity text'),
vectors=[[0.4, 0.5, 0.6]]
vector=[0.4, 0.5, 0.6]
)
message.entities = [uri_entity, literal_entity]

View file

@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings
# Create test entity embeddings (each entity has a single vector)
entity1 = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3]
)
entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
entity=Value(value="http://example.org/entity2", is_uri=True),
vector=[0.4, 0.5, 0.6]
)
message.entities = [entity1, entity2]
entity3 = EntityEmbeddings(
entity=Value(value="entity3", is_uri=False),
vector=[0.7, 0.8, 0.9]
)
message.entities = [entity1, entity2, entity3]
return message
@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
assert mock_index.upsert.call_count == 2
# Verify upsert was called for the single vector
assert mock_index.upsert.call_count == 1
# Check first vector upsert
first_call = mock_index.upsert.call_args_list[0]
@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions to same index"""
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each entity has a single vector of different dimensions
entity1 = EntityEmbeddings(
entity=Value(value="entity1", is_uri=False),
vector=[0.1, 0.2] # 2D vector
)
message.entities = [entity]
entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
entity3 = EntityEmbeddings(
entity=Value(value="entity3", is_uri=False),
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.entities = [entity1, entity2, entity3]
# All vectors now use the same index (no dimension in name)
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[]
vector=[]
)
message.entities = [entity]
@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]

View file

@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity = MagicMock()
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
mock_message.entities = [mock_entity]
@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity1 = MagicMock()
mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]]
mock_entity1.vector = [0.1, 0.2]
mock_entity2 = MagicMock()
mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]]
mock_entity2.vector = [0.3, 0.4]
mock_message.entities = [mock_entity1, mock_entity2]
@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with multiple vectors per entity"""
async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with three entities"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push)
processor.known_collections[('vector_user', 'vector_collection')] = {}
# Create mock message with entity having multiple vectors
# Create mock message with three entities
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_entity = MagicMock()
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'multi_vector_entity'
mock_entity.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.entities = [mock_entity]
mock_entity1 = MagicMock()
mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one'
mock_entity1.vector = [0.1, 0.2, 0.3]
mock_entity2 = MagicMock()
mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two'
mock_entity2.vector = [0.4, 0.5, 0.6]
mock_entity3 = MagicMock()
mock_entity3.entity.type = IRI
mock_entity3.entity.iri = 'entity_three'
mock_entity3.vector = [0.7, 0.8, 0.9]
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act
await processor.store_graph_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
# Should be called 3 times (once per entity)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
# Verify all entities were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
expected_data = [
([0.1, 0.2, 0.3], 'entity_one'),
([0.4, 0.5, 0.6], 'entity_two'),
([0.7, 0.8, 0.9], 'entity_three')
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['entity'] == 'multi_vector_entity'
assert point.vector == expected_data[i][0]
assert point.payload['entity'] == expected_data[i][1]
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity_empty = MagicMock()
mock_entity_empty.entity.type = LITERAL
mock_entity_empty.entity.value = "" # Empty string
mock_entity_empty.vectors = [[0.1, 0.2]]
mock_entity_empty.vector = [0.1, 0.2]
mock_entity_none = MagicMock()
mock_entity_none.entity = None # None entity
mock_entity_none.vectors = [[0.3, 0.4]]
mock_entity_none.vector = [0.3, 0.4]
mock_message.entities = [mock_entity_empty, mock_entity_none]

View file

@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='customer_id',
index_value=['CUST001'],
text='CUST001',
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
embeddings_msg = RowEmbeddings(
@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with multiple vectors"""
async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with a single vector"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
# Embedding with multiple vectors
# Embedding with a single 6D vector
embedding = RowIndexEmbedding(
index_name='name',
index_value=['John Doe'],
text='John Doe',
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
embeddings_msg = RowEmbeddings(
@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
# Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 1
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='id',
index_value=['123'],
text='123',
vectors=[] # Empty vectors
vector=[] # Empty vector
)
embeddings_msg = RowEmbeddings(
@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='id',
index_value=['123'],
text='123',
vectors=[[0.1, 0.2]]
vector=[0.1, 0.2]
)
embeddings_msg = RowEmbeddings(