Embeddings API scores (#671)

- Put scores in all responses
- Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
cybermaggedon 2026-03-09 10:53:44 +00:00 committed by GitHub
parent 4fa7cc7d7c
commit f2ae0e8623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 1339 additions and 1292 deletions

View file

@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
from trustgraph.messaging.translators.embeddings_query import ( from trustgraph.messaging.translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsRequestTranslator,
DocumentEmbeddingsResponseTranslator DocumentEmbeddingsResponseTranslator
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
"""Test that DocumentEmbeddingsRequest has expected fields""" """Test that DocumentEmbeddingsRequest has expected fields"""
# Create a request # Create a request
request = DocumentEmbeddingsRequest( request = DocumentEmbeddingsRequest(
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3],
limit=10, limit=10,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
) )
# Verify all expected fields exist # Verify all expected fields exist
assert hasattr(request, 'vectors') assert hasattr(request, 'vector')
assert hasattr(request, 'limit') assert hasattr(request, 'limit')
assert hasattr(request, 'user') assert hasattr(request, 'user')
assert hasattr(request, 'collection') assert hasattr(request, 'collection')
# Verify field values # Verify field values
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10 assert request.limit == 10
assert request.user == "test_user" assert request.user == "test_user"
assert request.collection == "test_collection" assert request.collection == "test_collection"
@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator() translator = DocumentEmbeddingsRequestTranslator()
data = { data = {
"vectors": [[0.1, 0.2], [0.3, 0.4]], "vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5, "limit": 5,
"user": "custom_user", "user": "custom_user",
"collection": "custom_collection" "collection": "custom_collection"
@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.to_pulsar(data) result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest) assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5 assert result.limit == 5
assert result.user == "custom_user" assert result.user == "custom_user"
assert result.collection == "custom_collection" assert result.collection == "custom_collection"
@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator() translator = DocumentEmbeddingsRequestTranslator()
data = { data = {
"vectors": [[0.1, 0.2]] "vector": [0.1, 0.2]
# No limit, user, or collection provided # No limit, user, or collection provided
} }
result = translator.to_pulsar(data) result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest) assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2]] assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default assert result.limit == 10 # Default
assert result.user == "trustgraph" # Default assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default assert result.collection == "default" # Default
@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator() translator = DocumentEmbeddingsRequestTranslator()
request = DocumentEmbeddingsRequest( request = DocumentEmbeddingsRequest(
vectors=[[0.5, 0.6]], vector=[0.5, 0.6],
limit=20, limit=20,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.from_pulsar(request) result = translator.from_pulsar(request)
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["vectors"] == [[0.5, 0.6]] assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20 assert result["limit"] == 20
assert result["user"] == "test_user" assert result["user"] == "test_user"
assert result["collection"] == "test_collection" assert result["collection"] == "test_collection"
@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_schema_fields(self): def test_response_schema_fields(self):
"""Test that DocumentEmbeddingsResponse has expected fields""" """Test that DocumentEmbeddingsResponse has expected fields"""
# Create a response with chunk_ids # Create a response with chunks
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["chunk1", "chunk2", "chunk3"] chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8),
ChunkMatch(chunk_id="chunk3", score=0.7)
]
) )
# Verify all expected fields exist # Verify all expected fields exist
assert hasattr(response, 'error') assert hasattr(response, 'error')
assert hasattr(response, 'chunk_ids') assert hasattr(response, 'chunks')
# Verify field values # Verify field values
assert response.error is None assert response.error is None
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"] assert len(response.chunks) == 3
assert response.chunks[0].chunk_id == "chunk1"
assert response.chunks[0].score == 0.9
def test_response_schema_with_error(self): def test_response_schema_with_error(self):
"""Test response schema with error""" """Test response schema with error"""
@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=error, error=error,
chunk_ids=[] chunks=[]
) )
assert response.error == error assert response.error == error
assert response.chunk_ids == [] assert response.chunks == []
def test_response_translator_from_pulsar_with_chunk_ids(self): def test_response_translator_from_pulsar_with_chunks(self):
"""Test response translator converts Pulsar schema with chunk_ids to dict""" """Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"] chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85),
ChunkMatch(chunk_id="doc3/c3", score=0.75)
]
) )
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" in result assert "chunks" in result
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"] assert len(result["chunks"]) == 3
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunk_ids(self): def test_response_translator_from_pulsar_with_empty_chunks(self):
"""Test response translator handles empty chunk_ids list""" """Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=[] chunks=[]
) )
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" in result assert "chunks" in result
assert result["chunk_ids"] == [] assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunk_ids(self): def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunk_ids""" """Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock() response = MagicMock()
response.chunk_ids = None response.chunks = None
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" not in result or result.get("chunk_ids") is None assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self): def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag""" """Test response translator with completion flag"""
@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["chunk1", "chunk2"] chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8)
]
) )
result, is_final = translator.from_response_with_completion(response) result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" in result assert "chunks" in result
assert result["chunk_ids"] == ["chunk1", "chunk2"] assert len(result["chunks"]) == 2
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self): def test_response_translator_to_pulsar_not_implemented(self):
@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract:
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunk_ids": ["test"]}) translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility: class TestDocumentEmbeddingsMessageCompatibility:
@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
"""Test complete request-response flow maintains data integrity""" """Test complete request-response flow maintains data integrity"""
# Create request # Create request
request_data = { request_data = {
"vectors": [[0.1, 0.2, 0.3]], "vector": [0.1, 0.2, 0.3],
"limit": 5, "limit": 5,
"user": "test_user", "user": "test_user",
"collection": "test_collection" "collection": "test_collection"
@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Simulate service processing and creating response # Simulate service processing and creating response
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["doc1/c1", "doc2/c2"] chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85)
]
) )
# Convert response back to dict # Convert response back to dict
@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify data integrity # Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict) assert isinstance(response_data, dict)
assert "chunk_ids" in response_data assert "chunks" in response_data
assert len(response_data["chunk_ids"]) == 2 assert len(response_data["chunks"]) == 2
def test_error_response_flow(self): def test_error_response_flow(self):
"""Test error response flow""" """Test error response flow"""
@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=error, error=error,
chunk_ids=[] chunks=[]
) )
# Convert response to dict # Convert response to dict
@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify error handling # Verify error handling
assert isinstance(response_data, dict) assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunk_ids # The translator doesn't include error in the dict, only chunks
assert "chunk_ids" in response_data assert "chunks" in response_data
assert response_data["chunk_ids"] == [] assert response_data["chunks"] == []

View file

@ -285,11 +285,11 @@ class TestStructuredEmbeddingsContracts:
collection="test_collection", collection="test_collection",
metadata=[] metadata=[]
) )
# Act # Act
embedding = StructuredObjectEmbedding( embedding = StructuredObjectEmbedding(
metadata=metadata, metadata=metadata,
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3],
schema_name="customer_records", schema_name="customer_records",
object_id="customer_123", object_id="customer_123",
field_embeddings={ field_embeddings={
@ -301,7 +301,7 @@ class TestStructuredEmbeddingsContracts:
# Assert # Assert
assert embedding.schema_name == "customer_records" assert embedding.schema_name == "customer_records"
assert embedding.object_id == "customer_123" assert embedding.object_id == "customer_123"
assert len(embedding.vectors) == 2 assert len(embedding.vector) == 3
assert len(embedding.field_embeddings) == 2 assert len(embedding.field_embeddings) == 2
assert "name" in embedding.field_embeddings assert "name" in embedding.field_embeddings

View file

@ -9,6 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
# Sample chunk content for testing - maps chunk_id to content # Sample chunk content for testing - maps chunk_id to content
@ -39,10 +40,14 @@ class TestDocumentRagIntegration:
@pytest.fixture @pytest.fixture
def mock_doc_embeddings_client(self): def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk IDs""" """Mock document embeddings client that returns chunk matches"""
client = AsyncMock() client = AsyncMock()
# Now returns chunk_ids instead of actual content # Returns ChunkMatch objects with chunk_id and score
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client return client
@pytest.fixture @pytest.fixture
@ -97,7 +102,7 @@ class TestDocumentRagIntegration:
mock_embeddings_client.embed.assert_called_once_with([query]) mock_embeddings_client.embed.assert_called_once_with([query])
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit, limit=doc_limit,
user=user, user=user,
collection=collection collection=collection
@ -298,7 +303,7 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages assert "Computing embeddings..." in log_messages
assert "chunk_ids" in log_messages.lower() assert "chunks" in log_messages.lower()
assert "Invoking LLM..." in log_messages assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages assert "Query processing complete" in log_messages
@ -307,9 +312,9 @@ class TestDocumentRagIntegration:
async def test_document_rag_performance_with_large_document_set(self, document_rag, async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client): mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval""" """Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large chunk_id set (100 chunks) # Arrange - Mock large chunk match set (100 chunks)
large_chunk_ids = [f"doc/c{i}" for i in range(100)] large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_ids mock_doc_embeddings_client.query.return_value = large_chunk_matches
# Act # Act
import time import time

View file

@ -8,6 +8,7 @@ response delivery through the complete pipeline.
import pytest import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from tests.utils.streaming_assertions import ( from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
assert_callback_invoked, assert_callback_invoked,
@ -36,10 +37,14 @@ class TestDocumentRagStreaming:
@pytest.fixture @pytest.fixture
def mock_doc_embeddings_client(self): def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk IDs""" """Mock document embeddings client that returns chunk matches"""
client = AsyncMock() client = AsyncMock()
# Now returns chunk_ids instead of actual content # Returns ChunkMatch objects with chunk_id and score
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client return client
@pytest.fixture @pytest.fixture

View file

@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
@pytest.mark.integration @pytest.mark.integration
@ -35,9 +36,9 @@ class TestGraphRagIntegration:
"""Mock graph embeddings client that returns realistic entities""" """Mock graph embeddings client that returns realistic entities"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = [ client.query.return_value = [
"http://trustgraph.ai/e/machine-learning", EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
"http://trustgraph.ai/e/artificial-intelligence", EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
"http://trustgraph.ai/e/neural-networks" EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
] ]
return client return client
@ -130,7 +131,7 @@ class TestGraphRagIntegration:
# 2. Should query graph embeddings to find relevant entities # 2. Should query graph embeddings to find relevant entities
mock_graph_embeddings_client.query.assert_called_once() mock_graph_embeddings_client.query.assert_called_once()
call_args = mock_graph_embeddings_client.query.call_args call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection assert call_args.kwargs['collection'] == collection

View file

@ -8,6 +8,7 @@ response delivery through the complete pipeline.
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from tests.utils.streaming_assertions import ( from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid, assert_streaming_chunks_valid,
assert_rag_streaming_chunks, assert_rag_streaming_chunks,
@ -33,7 +34,7 @@ class TestGraphRagStreaming:
"""Mock graph embeddings client""" """Mock graph embeddings client"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = [ client.query.return_value = [
"http://trustgraph.ai/e/machine-learning", EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
] ]
return client return client

View file

@ -411,7 +411,7 @@ class TestKnowledgeGraphPipelineIntegration:
entities=[ entities=[
EntityEmbeddings( EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/entity"), entity=Term(type=IRI, iri="http://example.org/entity"),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
] ]
) )

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, call from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
class TestGraphRagStreamingProtocol: class TestGraphRagStreamingProtocol:
@ -25,7 +26,10 @@ class TestGraphRagStreamingProtocol:
def mock_graph_embeddings_client(self): def mock_graph_embeddings_client(self):
"""Mock graph embeddings client""" """Mock graph embeddings client"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = ["entity1", "entity2"] client.query.return_value = [
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
]
return client return client
@pytest.fixture @pytest.fixture
@ -202,9 +206,12 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture @pytest.fixture
def mock_doc_embeddings_client(self): def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk IDs""" """Mock document embeddings client that returns chunk matches"""
client = AsyncMock() client = AsyncMock()
client.query.return_value = ["doc/c1", "doc/c2"] client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90)
]
return client return client
@pytest.fixture @pytest.fixture

View file

@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"] mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
# Mock the request method # Mock the request method
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act # Act
result = await client.query( result = await client.query(
vectors=vectors, vector=vector,
limit=10, limit=10,
user="test_user", user="test_user",
collection="test_collection", collection="test_collection",
timeout=30 timeout=30
) )
# Assert # Assert
assert result == ["chunk1", "chunk2", "chunk3"] assert result == ["chunk1", "chunk2", "chunk3"]
client.request.assert_called_once() client.request.assert_called_once()
call_args = client.request.call_args[0][0] call_args = client.request.call_args[0][0]
assert isinstance(call_args, DocumentEmbeddingsRequest) assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vectors == vectors assert call_args.vector == vector
assert call_args.limit == 10 assert call_args.limit == 10
assert call_args.user == "test_user" assert call_args.user == "test_user"
assert call_args.collection == "test_collection" assert call_args.collection == "test_collection"
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
# Act & Assert # Act & Assert
with pytest.raises(RuntimeError, match="Database connection failed"): with pytest.raises(RuntimeError, match="Database connection failed"):
await client.query( await client.query(
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunk_ids = [] mock_response.chunks = []
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
# Act # Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]]) result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert # Assert
assert result == [] assert result == []
@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunk_ids = ["test_chunk"] mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
# Act # Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]]) result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert # Assert
client.request.assert_called_once() client.request.assert_called_once()
@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunk_ids = ["chunk1"] mock_response.chunks = ["chunk1"]
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
# Act # Act
await client.query( await client.query(
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
timeout=60 timeout=60
) )
# Assert # Assert
assert client.request.call_args[1]["timeout"] == 60 assert client.request.call_args[1]["timeout"] == 60
@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient() client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None mock_response.error = None
mock_response.chunk_ids = ["test_chunk"] mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response) client.request = AsyncMock(return_value=mock_response)
# Act # Act
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger: with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
result = await client.query(vectors=[[0.1, 0.2, 0.3]]) result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert # Assert
mock_logger.debug.assert_called_once() mock_logger.debug.assert_called_once()
assert "Document embeddings response" in str(mock_logger.debug.call_args) assert "Document embeddings response" in str(mock_logger.debug.call_args)

View file

@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = ["chunk1", "chunk2", "chunk3"] mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act # Act
result = client.request( result = client.request(
vectors=vectors, vector=vector,
user="test_user", user="test_user",
collection="test_collection", collection="test_collection",
limit=10, limit=10,
timeout=300 timeout=300
) )
# Assert # Assert
assert result == ["chunk1", "chunk2", "chunk3"] assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with( client.call.assert_called_once_with(
user="test_user", user="test_user",
collection="test_collection", collection="test_collection",
vectors=vectors, vector=vector,
limit=10, limit=10,
timeout=300 timeout=300
) )
@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = ["test_chunk"] mock_response.chunks = ["test_chunk"]
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3]] vector = [0.1, 0.2, 0.3]
# Act # Act
result = client.request(vectors=vectors) result = client.request(vector=vector)
# Assert # Assert
assert result == ["test_chunk"] assert result == ["test_chunk"]
client.call.assert_called_once_with( client.call.assert_called_once_with(
user="trustgraph", user="trustgraph",
collection="default", collection="default",
vectors=vectors, vector=vector,
limit=10, limit=10,
timeout=300 timeout=300
) )
@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = [] mock_response.chunks = []
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
# Act # Act
result = client.request(vectors=[[0.1, 0.2, 0.3]]) result = client.request(vector=[0.1, 0.2, 0.3])
# Assert # Assert
assert result == [] assert result == []
@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = None mock_response.chunks = None
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
# Act # Act
result = client.request(vectors=[[0.1, 0.2, 0.3]]) result = client.request(vector=[0.1, 0.2, 0.3])
# Assert # Assert
assert result is None assert result is None
@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = ["chunk1"] mock_response.chunks = ["chunk1"]
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
# Act # Act
client.request( client.request(
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
timeout=600 timeout=600
) )
# Assert # Assert
assert client.call.call_args[1]["timeout"] == 600 assert client.call.call_args[1]["timeout"] == 600

View file

@ -98,7 +98,7 @@ def sample_graph_embeddings():
entities=[ entities=[
EntityEmbeddings( EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/john"), entity=Term(type=IRI, iri="http://example.org/john"),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
] ]
) )

View file

@ -108,7 +108,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
# Assert # Assert
mock_fastembed_instance.embed.assert_called_once_with(["test text"]) mock_fastembed_instance.embed.assert_called_once_with(["test text"])
assert processor.cached_model_name == "test-model" # Still using default assert processor.cached_model_name == "test-model" # Still using default
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -60,7 +60,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
model="test-model", model="test-model",
input=["test text"] input=["test text"]
) )
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -86,7 +86,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
model="custom-model", model="custom-model",
input=["test text"] input=["test text"]
) )
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
class TestMilvusDocEmbeddingsQueryProcessor: class TestMilvusDocEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10 limit=10
) )
return query return query
@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5 [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
) )
# Verify results are document chunks # Verify results are ChunkMatch objects
assert len(result) == 3 assert len(result) == 3
assert result[0] == "First document chunk" assert isinstance(result[0], ChunkMatch)
assert result[1] == "Second document chunk" assert result[0].chunk_id == "First document chunk"
assert result[2] == "Third document chunk" assert result[1].chunk_id == "Second document chunk"
assert result[2].chunk_id == "Third document chunk"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor): async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with multiple vectors""" """Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3 limit=3
) )
# Mock search results - different results for each vector # Mock search results
mock_results_1 = [ mock_results = [
{"entity": {"chunk_id": "Document from first vector"}}, {"entity": {"chunk_id": "First document"}},
{"entity": {"chunk_id": "Another doc from first vector"}}, {"entity": {"chunk_id": "Second document"}},
{"entity": {"chunk_id": "Third document"}},
] ]
mock_results_2 = [ processor.vecstore.search.return_value = mock_results
{"entity": {"chunk_id": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters including user/collection # Verify search was called once with the full vector
expected_calls = [ processor.vecstore.search.assert_called_once_with(
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}), )
]
assert processor.vecstore.search.call_count == 2 # Verify results are ChunkMatch objects
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results from all vectors are combined
assert len(result) == 3 assert len(result) == 3
assert "Document from first vector" in result chunk_ids = [r.chunk_id for r in result]
assert "Another doc from first vector" in result assert "First document" in chunk_ids
assert "Document from second vector" in result assert "Second document" in chunk_ids
assert "Third document" in chunk_ids
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor): async def test_query_document_embeddings_with_limit(self, processor):
@ -141,7 +135,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=2 limit=2
) )
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[], vector=[],
limit=5 limit=5
) )
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -211,7 +205,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -225,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify Unicode content is preserved # Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3 assert len(result) == 3
assert "Document with Unicode: éñ中文🚀" in result chunk_ids = [r.chunk_id for r in result]
assert "Regular ASCII document" in result assert "Document with Unicode: éñ中文🚀" in chunk_ids
assert "Document with émojis: 😀🎉" in result assert "Regular ASCII document" in chunk_ids
assert "Document with émojis: 😀🎉" in chunk_ids
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_large_documents(self, processor): async def test_query_document_embeddings_large_documents(self, processor):
@ -237,7 +232,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -251,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify large content is preserved # Verify large content is preserved in ChunkMatch objects
assert len(result) == 2 assert len(result) == 2
assert large_doc in result chunk_ids = [r.chunk_id for r in result]
assert "Small document" in result assert large_doc in chunk_ids
assert "Small document" in chunk_ids
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_special_characters(self, processor): async def test_query_document_embeddings_special_characters(self, processor):
@ -262,7 +258,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -276,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify special characters are preserved # Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3 assert len(result) == 3
assert "Document with \"quotes\" and 'apostrophes'" in result chunk_ids = [r.chunk_id for r in result]
assert "Document with\nnewlines\tand\ttabs" in result assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
assert "Document with special chars: @#$%^&*()" in result assert "Document with\nnewlines\tand\ttabs" in chunk_ids
assert "Document with special chars: @#$%^&*()" in chunk_ids
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor): async def test_query_document_embeddings_zero_limit(self, processor):
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=0 limit=0
) )
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=-1 limit=-1
) )
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[ vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
limit=5 limit=5
) )
# Mock search results for each vector # Mock search results
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}] mock_results = [
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}] {"entity": {"chunk_id": "Document 1"}},
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}] {"entity": {"chunk_id": "Document 2"}},
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] ]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify all vectors were searched # Verify search was called with the vector
assert processor.vecstore.search.call_count == 3 processor.vecstore.search.assert_called_once()
# Verify results from all dimensions # Verify results are ChunkMatch objects
assert len(result) == 3 assert len(result) == 2
assert "Document from 2D vector" in result chunk_ids = [r.chunk_id for r in result]
assert "Document from 4D vector" in result assert "Document 1" in chunk_ids
assert "Document from 3D vector" in result assert "Document 2" in chunk_ids
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_duplicate_documents(self, processor): async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with duplicate documents in results""" """Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5 limit=5
) )
# Mock search results with duplicates across vectors # Mock search results with multiple documents
mock_results_1 = [ mock_results = [
{"entity": {"chunk_id": "Document A"}}, {"entity": {"chunk_id": "Document A"}},
{"entity": {"chunk_id": "Document B"}}, {"entity": {"chunk_id": "Document B"}},
]
mock_results_2 = [
{"entity": {"chunk_id": "Document B"}}, # Duplicate
{"entity": {"chunk_id": "Document C"}}, {"entity": {"chunk_id": "Document C"}},
] ]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Note: Unlike graph embeddings, doc embeddings don't deduplicate # Verify results are ChunkMatch objects
# This preserves ranking and allows multiple occurrences assert len(result) == 3
assert len(result) == 4 chunk_ids = [r.chunk_id for r in result]
assert result.count("Document B") == 2 # Should appear twice assert "Document A" in chunk_ids
assert "Document A" in result assert "Document B" in chunk_ids
assert "Document C" in result assert "Document C" in chunk_ids
def test_add_args_method(self): def test_add_args_method(self):
"""Test that add_args properly configures argument parser""" """Test that add_args properly configures argument parser"""

View file

@ -103,7 +103,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor): async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector""" """Test querying document embeddings with a single vector"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 3 message.limit = 3
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_limit_handling(self, processor): async def test_query_document_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter""" """Test that query respects the limit parameter"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 2 message.limit = 2
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -208,7 +208,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor): async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results""" """Test querying with zero limit returns empty results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 0 message.limit = 0
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -226,7 +226,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor): async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results""" """Test querying with negative limit returns empty results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = -1 message.limit = -1
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -285,7 +285,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors_list(self, processor): async def test_query_document_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list""" """Test querying with empty vectors list"""
message = MagicMock() message = MagicMock()
message.vectors = [] message.vector = []
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -304,7 +304,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_no_results(self, processor): async def test_query_document_embeddings_no_results(self, processor):
"""Test querying when index returns no results""" """Test querying when index returns no results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -325,7 +325,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_content(self, processor): async def test_query_document_embeddings_unicode_content(self, processor):
"""Test querying document embeddings with Unicode content results""" """Test querying document embeddings with Unicode content results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 2 message.limit = 2
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -351,7 +351,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_content(self, processor): async def test_query_document_embeddings_large_content(self, processor):
"""Test querying document embeddings with large content results""" """Test querying document embeddings with large content results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 1 message.limit = 1
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -377,7 +377,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_mixed_content_types(self, processor): async def test_query_document_embeddings_mixed_content_types(self, processor):
"""Test querying document embeddings with mixed content types""" """Test querying document embeddings with mixed content types"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -409,7 +409,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor): async def test_query_document_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised""" """Test that exceptions are properly raised"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -425,7 +425,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_index_access_failure(self, processor): async def test_query_document_embeddings_index_access_failure(self, processor):
"""Test handling of index access failure""" """Test handling of index access failure"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test # Import the service under test
from trustgraph.query.doc_embeddings.qdrant.service import Processor from trustgraph.query.doc_embeddings.qdrant.service import Processor
from trustgraph.schema import ChunkMatch
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]] mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'test_user' mock_message.user = 'test_user'
mock_message.collection = 'test_collection' mock_message.collection = 'test_collection'
@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True with_payload=True
) )
# Verify result contains expected documents # Verify result contains expected ChunkMatch objects
assert len(result) == 2 assert len(result) == 2
# Results should be strings (document chunks) # Results should be ChunkMatch objects
assert isinstance(result[0], str) assert isinstance(result[0], ChunkMatch)
assert isinstance(result[1], str) assert isinstance(result[1], ChunkMatch)
# Verify content # Verify content
assert result[0] == 'first document chunk' assert result[0].chunk_id == 'first document chunk'
assert result[1] == 'second document chunk' assert result[1].chunk_id == 'second document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client): async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with multiple vectors""" """Test querying document embeddings returns multiple results"""
# Arrange # Arrange
mock_base_init.return_value = None mock_base_init.return_value = None
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses for different vectors # Mock query response with multiple results
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from vector 1'} mock_point1.payload = {'chunk_id': 'document chunk 1'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from vector 2'} mock_point2.payload = {'chunk_id': 'document chunk 2'}
mock_point3 = MagicMock() mock_point3 = MagicMock()
mock_point3.payload = {'chunk_id': 'another document from vector 2'} mock_point3.payload = {'chunk_id': 'document chunk 3'}
mock_response1 = MagicMock() mock_response = MagicMock()
mock_response1.points = [mock_point1] mock_response.points = [mock_point1, mock_point2, mock_point3]
mock_response2 = MagicMock() mock_qdrant_instance.query_points.return_value = mock_response
mock_response2.points = [mock_point2, mock_point3]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
config = { config = {
'taskgroup': AsyncMock(), 'taskgroup': AsyncMock(),
'id': 'test-processor' 'id': 'test-processor'
} }
processor = Processor(**config) processor = Processor(**config)
# Create mock message with multiple vectors # Create mock message with single vector
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 3 mock_message.limit = 3
mock_message.user = 'multi_user' mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection' mock_message.collection = 'multi_collection'
# Act # Act
result = await processor.query_document_embeddings(mock_message) result = await processor.query_document_embeddings(mock_message)
# Assert # Assert
# Verify query was called twice # Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 2 assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors) # Verify collection was queried correctly
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2] assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify results are ChunkMatch objects
# Verify results from both vectors are combined
assert len(result) == 3 assert len(result) == 3
assert 'document from vector 1' in result chunk_ids = [r.chunk_id for r in result]
assert 'document from vector 2' in result assert 'document chunk 1' in chunk_ids
assert 'another document from vector 2' in result assert 'document chunk 2' in chunk_ids
assert 'document chunk 3' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit # Create mock message with limit
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]] mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user' mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection' mock_message.collection = 'limit_collection'
@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'empty_user' mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection' mock_message.collection = 'empty_collection'
@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client): async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with different vector dimensions""" """Test querying document embeddings with a higher dimension vector"""
# Arrange # Arrange
mock_base_init.return_value = None mock_base_init.return_value = None
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses # Mock query response
mock_point1 = MagicMock() mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from 2D vector'} mock_point1.payload = {'chunk_id': 'document from 5D vector'}
mock_point2 = MagicMock() mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from 3D vector'} mock_point2.payload = {'chunk_id': 'another 5D document'}
mock_response1 = MagicMock() mock_response = MagicMock()
mock_response1.points = [mock_point1] mock_response.points = [mock_point1, mock_point2]
mock_response2 = MagicMock() mock_qdrant_instance.query_points.return_value = mock_response
mock_response2.points = [mock_point2]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
config = { config = {
'taskgroup': AsyncMock(), 'taskgroup': AsyncMock(),
'id': 'test-processor' 'id': 'test-processor'
} }
processor = Processor(**config) processor = Processor(**config)
# Create mock message with different dimension vectors # Create mock message with 5D vector
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'dim_user' mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection' mock_message.collection = 'dim_collection'
# Act # Act
result = await processor.query_document_embeddings(mock_message) result = await processor.query_document_embeddings(mock_message)
# Assert # Assert
# Verify query was called twice with different collections # Verify query was called once with correct collection
assert mock_qdrant_instance.query_points.call_count == 2 assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection # Call should use 5D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
assert calls[0][1]['query'] == [0.1, 0.2] assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
# Second call should use 3D collection # Verify results are ChunkMatch objects
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
assert len(result) == 2 assert len(result) == 2
assert 'document from 2D vector' in result chunk_ids = [r.chunk_id for r in result]
assert 'document from 3D vector' in result assert 'document from 5D vector' in chunk_ids
assert 'another 5D document' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'utf8_user' mock_message.user = 'utf8_user'
mock_message.collection = 'utf8_collection' mock_message.collection = 'utf8_collection'
@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert # Assert
assert len(result) == 2 assert len(result) == 2
# Verify UTF-8 content works correctly # Verify UTF-8 content works correctly in ChunkMatch objects
assert 'Document with UTF-8: café, naïve, résumé' in result chunk_ids = [r.chunk_id for r in result]
assert 'Chinese text: 你好世界' in result assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
assert 'Chinese text: 你好世界' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'error_user' mock_message.user = 'error_user'
mock_message.collection = 'error_collection' mock_message.collection = 'error_collection'
@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit # Create mock message with zero limit
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 0 mock_message.limit = 0
mock_message.user = 'zero_user' mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection' mock_message.collection = 'zero_collection'
@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once() mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0 assert call_args[1]['limit'] == 0
# Result should contain all returned documents # Result should contain all returned documents as ChunkMatch objects
assert len(result) == 1 assert len(result) == 1
assert result[0] == 'document chunk' assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == 'document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with large limit # Create mock message with large limit
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 1000 # Large limit mock_message.limit = 1000 # Large limit
mock_message.user = 'large_user' mock_message.user = 'large_user'
mock_message.collection = 'large_collection' mock_message.collection = 'large_collection'
@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once() mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 1000 assert call_args[1]['limit'] == 1000
# Result should contain all available documents # Result should contain all available documents as ChunkMatch objects
assert len(result) == 2 assert len(result) == 2
assert 'document 1' in result chunk_ids = [r.chunk_id for r in result]
assert 'document 2' in result assert 'document 1' in chunk_ids
assert 'document 2' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -508,7 +504,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'payload_user' mock_message.user = 'payload_user'
mock_message.collection = 'payload_collection' mock_message.collection = 'payload_collection'

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
class TestMilvusGraphEmbeddingsQueryProcessor: class TestMilvusGraphEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10 limit=10
) )
return query return query
@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
) )
# Verify results are converted to Term objects # Verify results are converted to EntityMatch objects
assert len(result) == 3 assert len(result) == 3
assert isinstance(result[0], Term) assert isinstance(result[0], EntityMatch)
assert result[0].iri == "http://example.com/entity1" assert result[0].entity.iri == "http://example.com/entity1"
assert result[0].type == IRI assert result[0].entity.type == IRI
assert isinstance(result[1], Term) assert isinstance(result[1], EntityMatch)
assert result[1].iri == "http://example.com/entity2" assert result[1].entity.iri == "http://example.com/entity2"
assert result[1].type == IRI assert result[1].entity.type == IRI
assert isinstance(result[2], Term) assert isinstance(result[2], EntityMatch)
assert result[2].value == "literal entity" assert result[2].entity.value == "literal entity"
assert result[2].type == LITERAL assert result[2].entity.type == LITERAL
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor): async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings with multiple vectors""" """Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3 limit=5
) )
# Mock search results - different results for each vector # Mock search results with multiple entities
mock_results_1 = [ mock_results = [
{"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}}, {"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}}, {"entity": {"entity": "http://example.com/entity3"}},
] ]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query) result = await processor.query_graph_embeddings(query)
# Verify search was called twice with correct parameters including user/collection # Verify search was called once with the full vector
expected_calls = [ processor.vecstore.search.assert_called_once_with(
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}), )
]
assert processor.vecstore.search.call_count == 2 # Verify results are EntityMatch objects
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results are deduplicated and limited
assert len(result) == 3 assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result] entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values assert "http://example.com/entity3" in entity_values
@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=2 limit=2
) )
@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 2 assert len(result) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication(self, processor): async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that duplicate entities are properly deduplicated""" """Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5 limit=5
) )
# Mock search results with duplicates
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}}, # New
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_graph_embeddings(query)
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@pytest.mark.asyncio # Mock search results in specific order
async def test_query_graph_embeddings_early_termination_on_limit(self, processor): mock_results = [
"""Test that querying stops early when limit is reached"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=2
)
# Mock search results - first vector returns enough results
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}}, {"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}}, {"entity": {"entity": "http://example.com/entity3"}},
] ]
processor.vecstore.search.return_value = mock_results_1 processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query) result = await processor.query_graph_embeddings(query)
# Verify only first vector was searched (limit reached) # Verify results are in the same order as returned by the store
processor.vecstore.search.assert_called_once_with( assert len(result) == 3
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4 assert result[0].entity.iri == "http://example.com/entity1"
assert result[1].entity.iri == "http://example.com/entity2"
assert result[2].entity.iri == "http://example.com/entity3"
@pytest.mark.asyncio
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
) )
# Verify results are limited # Mock search results - returns more results than limit
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
)
# Verify results are limited to requested amount
assert len(result) == 2 assert len(result) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[], vector=[],
limit=5 limit=5
) )
@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify all results are properly typed # Verify all results are properly typed
assert len(result) == 4 assert len(result) == 4
# Check URI entities # Check URI entities
uri_results = [r for r in result if r.type == IRI] uri_results = [r for r in result if r.entity.type == IRI]
assert len(uri_results) == 2 assert len(uri_results) == 2
uri_values = [r.iri for r in uri_results] uri_values = [r.entity.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values assert "https://example.com/another_uri" in uri_values
# Check literal entities # Check literal entities
literal_results = [r for r in result if not r.type == IRI] literal_results = [r for r in result if not r.entity.type == IRI]
assert len(literal_results) == 2 assert len(literal_results) == 2
literal_values = [r.value for r in literal_results] literal_values = [r.entity.value for r in literal_results]
assert "literal entity text" in literal_values assert "literal entity text" in literal_values
assert "another literal" in literal_values assert "another literal" in literal_values
@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=0 limit=0
) )
@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 0 assert len(result) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor): async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with different vector dimensions""" """Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[ vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
limit=5 limit=5
) )
# Mock search results for each vector # Mock search results
mock_results_1 = [{"entity": {"entity": "entity_2d"}}] mock_results = [
mock_results_2 = [{"entity": {"entity": "entity_4d"}}] {"entity": {"entity": "http://example.com/entity1"}},
mock_results_3 = [{"entity": {"entity": "entity_3d"}}] {"entity": {"entity": "http://example.com/entity2"}},
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] ]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query) result = await processor.query_graph_embeddings(query)
# Verify all vectors were searched # Verify search was called once with the full vector
assert processor.vecstore.search.call_count == 3 processor.vecstore.search.assert_called_once()
# Verify results from all dimensions # Verify results
assert len(result) == 3 assert len(result) == 2
entity_values = [r.iri if r.type == IRI else r.value for r in result] entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "entity_2d" in entity_values assert "http://example.com/entity1" in entity_values
assert "entity_4d" in entity_values assert "http://example.com/entity2" in entity_values
assert "entity_3d" in entity_values

View file

@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True) pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Term, IRI, LITERAL from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
class TestPineconeGraphEmbeddingsQueryProcessor: class TestPineconeGraphEmbeddingsQueryProcessor:
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
def mock_query_message(self): def mock_query_message(self):
"""Create a mock query message for testing""" """Create a mock query message for testing"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor): async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector""" """Test querying graph embeddings with a single vector"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 3 message.limit = 3
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
include_metadata=True include_metadata=True
) )
# Verify results # Verify results use EntityMatch structure
assert len(entities) == 3 assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1' assert entities[0].entity.iri == 'http://example.org/entity1'
assert entities[0].type == IRI assert entities[0].entity.type == IRI
assert entities[1].value == 'entity2' assert entities[1].entity.value == 'entity2'
assert entities[1].type == LITERAL assert entities[1].entity.type == LITERAL
assert entities[2].value == 'http://example.org/entity3' assert entities[2].entity.iri == 'http://example.org/entity3'
assert entities[2].type == IRI assert entities[2].entity.type == IRI
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message): async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
"""Test querying graph embeddings with multiple vectors""" """Test basic graph embeddings query"""
# Mock index and query results # Mock index and query results
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# First query results # Query results with distinct entities
mock_results1 = MagicMock() mock_results = MagicMock()
mock_results1.matches = [ mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}) MagicMock(metadata={'entity': 'entity2'}),
]
# Second query results
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}) MagicMock(metadata={'entity': 'entity3'})
] ]
mock_index.query.side_effect = [mock_results1, mock_results2] mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message) entities = await processor.query_graph_embeddings(mock_query_message)
# Verify both queries were made # Verify query was made once
assert mock_index.query.call_count == 2 assert mock_index.query.call_count == 1
# Verify deduplication occurred # Verify results with EntityMatch structure
entity_values = [e.value for e in entities] entity_values = [e.entity.value for e in entities]
assert len(entity_values) == 3 assert len(entity_values) == 3
assert 'entity1' in entity_values assert 'entity1' in entity_values
assert 'entity2' in entity_values assert 'entity2' in entity_values
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_limit_handling(self, processor): async def test_query_graph_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter""" """Test that query respects the limit parameter"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 2 message.limit = 2
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor): async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results""" """Test querying with zero limit returns empty results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 0 message.limit = 0
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_negative_limit(self, processor): async def test_query_graph_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results""" """Test querying with negative limit returns empty results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = -1 message.limit = -1
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == [] assert entities == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor): async def test_query_graph_embeddings_2d_vector(self, processor):
"""Test querying with vectors of different dimensions using same index""" """Test querying with a 2D vector"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2] # 2D vector
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
# Mock single index that handles all dimensions # Mock index
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries # Mock results for 2D vector query
mock_results_2d = MagicMock() mock_results = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_results_4d = MagicMock() mock_index.query.return_value = mock_results
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
entities = await processor.query_graph_embeddings(message) entities = await processor.query_graph_embeddings(message)
# Verify different indexes used for different dimensions # Verify correct index used for 2D vector
assert processor.pinecone.Index.call_count == 2 processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
index_calls = processor.pinecone.Index.call_args_list
index_names = [call[0][0] for call in index_calls]
assert "t-test_user-test_collection-2" in index_names # 2D vector
assert "t-test_user-test_collection-4" in index_names # 4D vector
# Verify both queries were made # Verify query was made
assert mock_index.query.call_count == 2 assert mock_index.query.call_count == 1
# Verify results from both dimensions # Verify results with EntityMatch structure
entity_values = [e.value for e in entities] entity_values = [e.entity.value for e in entities]
assert 'entity_2d' in entity_values assert 'entity_2d' in entity_values
assert 'entity_4d' in entity_values
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors_list(self, processor): async def test_query_graph_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list""" """Test querying with empty vectors list"""
message = MagicMock() message = MagicMock()
message.vectors = [] message.vector = []
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_no_results(self, processor): async def test_query_graph_embeddings_no_results(self, processor):
"""Test querying when index returns no results""" """Test querying when index returns no results"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == [] assert entities == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor): async def test_query_graph_embeddings_deduplication_in_results(self, processor):
"""Test that deduplication works correctly across multiple vector queries""" """Test that deduplication works correctly within query results"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.limit = 3 message.limit = 3
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# Both queries return overlapping results # Query returns results with some duplicates
mock_results1 = MagicMock() mock_results = MagicMock()
mock_results1.matches = [ mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), MagicMock(metadata={'entity': 'entity3'}),
MagicMock(metadata={'entity': 'entity4'})
]
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
MagicMock(metadata={'entity': 'entity5'})
] ]
mock_index.query.side_effect = [mock_results1, mock_results2] mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message) entities = await processor.query_graph_embeddings(message)
# Should get exactly 3 unique entities (respecting limit) # Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3 assert len(entities) == 3
entity_values = [e.value for e in entities] entity_values = [e.entity.value for e in entities]
assert len(set(entity_values)) == 3 # All unique assert len(set(entity_values)) == 3 # All unique
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor): async def test_query_graph_embeddings_respects_limit(self, processor):
"""Test that querying stops early when limit is reached""" """Test that query respects limit parameter"""
message = MagicMock() message = MagicMock()
message.vectors = [ message.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
message.limit = 2 message.limit = 2
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
# First query returns enough results to meet limit # Query returns more results than limit
mock_results1 = MagicMock() mock_results = MagicMock()
mock_results1.matches = [ mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'}) MagicMock(metadata={'entity': 'entity3'})
] ]
mock_index.query.return_value = mock_results1 mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message) entities = await processor.query_graph_embeddings(message)
# Should only make one query since limit was reached # Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once() mock_index.query.assert_called_once()
assert len(entities) == 2 assert len(entities) == 2
@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor): async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised""" """Test that exceptions are properly raised"""
message = MagicMock() message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]] message.vector = [0.1, 0.2, 0.3]
message.limit = 5 message.limit = 5
message.user = 'test_user' message.user = 'test_user'
message.collection = 'test_collection' message.collection = 'test_collection'

View file

@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test # Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor from trustgraph.query.graph_embeddings.qdrant.service import Processor
from trustgraph.schema import IRI, LITERAL from trustgraph.schema import IRI, LITERAL, EntityMatch
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]] mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'test_user' mock_message.user = 'test_user'
mock_message.collection = 'test_collection' mock_message.collection = 'test_collection'
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True with_payload=True
) )
# Verify result contains expected entities # Verify result contains expected EntityMatch objects
assert len(result) == 2 assert len(result) == 2
assert all(hasattr(entity, 'value') for entity in result) assert all(isinstance(entity, EntityMatch) for entity in result)
entity_values = [entity.value for entity in result] entity_values = [entity.entity.value for entity in result]
assert 'entity1' in entity_values assert 'entity1' in entity_values
assert 'entity2' in entity_values assert 'entity2' in entity_values
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
# Create mock message with multiple vectors # Create mock message with single vector
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 3 mock_message.limit = 3
mock_message.user = 'multi_user' mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection' mock_message.collection = 'multi_collection'
# Act # Act
result = await processor.query_graph_embeddings(mock_message) result = await processor.query_graph_embeddings(mock_message)
# Assert # Assert
# Verify query was called twice # Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 2 assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors) # Verify collection was queried
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2] assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify results with EntityMatch structure
# Verify deduplication - entity2 appears in both results but should only appear once entity_values = [entity.entity.value for entity in result]
entity_values = [entity.value for entity in result]
assert len(set(entity_values)) == len(entity_values) # All unique assert len(set(entity_values)) == len(entity_values) # All unique
assert 'entity1' in entity_values assert 'entity1' in entity_values
assert 'entity2' in entity_values assert 'entity2' in entity_values
assert 'entity3' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit # Create mock message with limit
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]] mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user' mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection' mock_message.collection = 'limit_collection'
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'empty_user' mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection' mock_message.collection = 'empty_collection'
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
# Create mock message with different dimension vectors # Create mock message with single vector
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D mock_message.vector = [0.1, 0.2] # 2D vector
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'dim_user' mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection' mock_message.collection = 'dim_collection'
# Act # Act
result = await processor.query_graph_embeddings(mock_message) result = await processor.query_graph_embeddings(mock_message)
# Assert # Assert
# Verify query was called twice with different collections # Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 2 assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection # Call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2] assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection # Verify results with EntityMatch structure
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions entity_values = [entity.entity.value for entity in result]
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
entity_values = [entity.value for entity in result]
assert 'entity2d' in entity_values assert 'entity2d' in entity_values
assert 'entity3d' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'uri_user' mock_message.user = 'uri_user'
mock_message.collection = 'uri_collection' mock_message.collection = 'uri_collection'
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert # Assert
assert len(result) == 3 assert len(result) == 3
# Check URI entities # Check URI entities
uri_entities = [entity for entity in result if entity.type == IRI] uri_entities = [entity for entity in result if entity.entity.type == IRI]
assert len(uri_entities) == 2 assert len(uri_entities) == 2
uri_values = [entity.iri for entity in uri_entities] uri_values = [entity.entity.iri for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities # Check regular entities
regular_entities = [entity for entity in result if entity.type == LITERAL] regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
assert len(regular_entities) == 1 assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity' assert regular_entities[0].entity.value == 'regular entity'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 5 mock_message.limit = 5
mock_message.user = 'error_user' mock_message.user = 'error_user'
mock_message.collection = 'error_collection' mock_message.collection = 'error_collection'
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit # Create mock message with zero limit
mock_message = MagicMock() mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]] mock_message.vector = [0.1, 0.2]
mock_message.limit = 0 mock_message.limit = 0
mock_message.user = 'zero_user' mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection' mock_message.collection = 'zero_collection'
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# With zero limit, the logic still adds one entity before checking the limit # With zero limit, the logic still adds one entity before checking the limit
# So it returns one result (current behavior, not ideal but actual) # So it returns one result (current behavior, not ideal but actual)
assert len(result) == 1 assert len(result) == 1
assert result[0].value == 'entity1' assert result[0].entity.value == 'entity1'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')

View file

@ -175,9 +175,14 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors] mock_embeddings_client.embed.return_value = [test_vectors]
# Mock document embeddings returns chunk_ids # Mock document embeddings returns ChunkMatch objects
test_chunk_ids = ["doc/c1", "doc/c2"] mock_match1 = MagicMock()
mock_doc_embeddings_client.query.return_value = test_chunk_ids mock_match1.chunk_id = "doc/c1"
mock_match1.score = 0.95
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c2"
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query # Initialize Query
query = Query( query = Query(
@ -195,9 +200,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list) # Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query]) mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify doc embeddings client was called correctly (with extracted vectors) # Verify doc embeddings client was called correctly (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors, vector=test_vectors,
limit=15, limit=15,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -218,11 +223,16 @@ class TestQuery:
# Mock embeddings and document embeddings responses # Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0] # New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
test_chunk_ids = ["doc/c3", "doc/c4"] mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c4"
mock_match2.score = 0.8
expected_response = "This is the document RAG response" expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors] mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = test_chunk_ids mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag # Initialize DocumentRag
@ -245,9 +255,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list) # Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"]) mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify doc embeddings client was called (with extracted vectors) # Verify doc embeddings client was called (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors, vector=test_vectors,
limit=10, limit=10,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -275,7 +285,10 @@ class TestQuery:
# Mock responses (batch format) # Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = ["doc/c5"] mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response" mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag # Initialize DocumentRag
@ -289,9 +302,9 @@ class TestQuery:
# Call DocumentRag.query with minimal parameters # Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query") result = await document_rag.query("simple query")
# Verify default parameters were used (vectors extracted from batch) # Verify default parameters were used (vector extracted from batch)
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2]], vector=[[0.1, 0.2]],
limit=20, # Default doc_limit limit=20, # Default doc_limit
user="trustgraph", # Default user user="trustgraph", # Default user
collection="default" # Default collection collection="default" # Default collection
@ -316,7 +329,10 @@ class TestQuery:
# Mock responses (batch format) # Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_doc_embeddings_client.query.return_value = ["doc/c6"] mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True # Initialize Query with verbose=True
query = Query( query = Query(
@ -347,7 +363,10 @@ class TestQuery:
# Mock responses (batch format) # Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_doc_embeddings_client.query.return_value = ["doc/c7"] mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response" mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True # Initialize DocumentRag with verbose=True
@ -487,7 +506,13 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors] mock_embeddings_client.embed.return_value = [query_vectors]
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
mock_prompt_client.document_prompt.return_value = final_response mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag # Initialize DocumentRag
@ -511,7 +536,7 @@ class TestQuery:
mock_embeddings_client.embed.assert_called_once_with([query_text]) mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with( mock_doc_embeddings_client.query.assert_called_once_with(
query_vectors, vector=query_vectors,
limit=25, limit=25,
user="research_user", user="research_user",
collection="ml_knowledge" collection="ml_knowledge"

View file

@ -193,12 +193,20 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors] mock_embeddings_client.embed.return_value = [test_vectors]
# Mock entity objects that have string representation # Mock EntityMatch objects with entity that has string representation
mock_entity1 = MagicMock() mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1") mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock() mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2") mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2] mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query # Initialize Query
query = Query( query = Query(
@ -216,9 +224,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list) # Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query]) mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify graph embeddings client was called correctly (with extracted vectors) # Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with( mock_graph_embeddings_client.query.assert_called_once_with(
vectors=test_vectors, vector=test_vectors,
limit=25, limit=25,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"

View file

@ -23,11 +23,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Create test document embeddings # Create test document embeddings
chunk1 = ChunkEmbeddings( chunk1 = ChunkEmbeddings(
chunk_id="This is the first document chunk", chunk_id="This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
chunk2 = ChunkEmbeddings( chunk2 = ChunkEmbeddings(
chunk_id="This is the second document chunk", chunk_id="This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]] vector=[0.7, 0.8, 0.9]
) )
message.chunks = [chunk1, chunk2] message.chunks = [chunk1, chunk2]
@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id="Test document content", chunk_id="Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify insert was called for each vector with user/collection parameters # Verify insert was called once for the single chunk with its vector
expected_calls = [ processor.vecstore.insert.assert_called_once_with(
([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'
([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'), )
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks""" """Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings(mock_message)
# Verify insert was called for each vector of each chunk with user/collection parameters # Verify insert was called once per chunk with user/collection parameters
expected_calls = [ expected_calls = [
# Chunk 1 vectors # Chunk 1 - single vector
([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'), ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), # Chunk 2 - single vector
# Chunk 2 vectors
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
] ]
assert processor.vecstore.insert.call_count == 3 assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i] actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec assert actual_call[0][0] == expected_vec
@ -137,7 +127,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id="", chunk_id="",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -156,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id=None, chunk_id=None,
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -177,15 +167,15 @@ class TestMilvusDocEmbeddingsStorageProcessor:
valid_chunk = ChunkEmbeddings( valid_chunk = ChunkEmbeddings(
chunk_id="Valid document content", chunk_id="Valid document content",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
empty_chunk = ChunkEmbeddings( empty_chunk = ChunkEmbeddings(
chunk_id="", chunk_id="",
vectors=[[0.4, 0.5, 0.6]] vector=[0.4, 0.5, 0.6]
) )
another_valid = ChunkEmbeddings( another_valid = ChunkEmbeddings(
chunk_id="Another valid chunk", chunk_id="Another valid chunk",
vectors=[[0.7, 0.8, 0.9]] vector=[0.7, 0.8, 0.9]
) )
message.chunks = [valid_chunk, empty_chunk, another_valid] message.chunks = [valid_chunk, empty_chunk, another_valid]
@ -229,7 +219,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id="Document with no vectors", chunk_id="Document with no vectors",
vectors=[] vector=[]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -245,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( # Each chunk has a single vector of different dimensions
chunk_id="Document with mixed dimensions", chunk1 = ChunkEmbeddings(
vectors=[ chunk_id="chunk/doc/2d",
[0.1, 0.2], # 2D vector vector=[0.1, 0.2] # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
) )
message.chunks = [chunk] chunk2 = ChunkEmbeddings(
chunk_id="chunk/doc/4d",
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
chunk3 = ChunkEmbeddings(
chunk_id="chunk/doc/3d",
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters # Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [ expected_calls = [
([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'), ([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'),
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'), ([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'), ([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'),
] ]
assert processor.vecstore.insert.call_count == 3 assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i] actual_call = processor.vecstore.insert.call_args_list[i]
@ -283,7 +278,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id="chunk/doc/unicode-éñ中文🚀", chunk_id="chunk/doc/unicode-éñ中文🚀",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -306,7 +301,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
long_chunk_id = "chunk/doc/" + "a" * 200 long_chunk_id = "chunk/doc/" + "a" * 200
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id=long_chunk_id, chunk_id=long_chunk_id,
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -327,7 +322,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id=" \n\t ", chunk_id=" \n\t ",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -358,7 +353,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id="Test content", chunk_id="Test content",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -379,7 +374,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message1.metadata.collection = 'collection1' message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings( chunk1 = ChunkEmbeddings(
chunk_id="User1 content", chunk_id="User1 content",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message1.chunks = [chunk1] message1.chunks = [chunk1]
@ -390,7 +385,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message2.metadata.collection = 'collection2' message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings( chunk2 = ChunkEmbeddings(
chunk_id="User2 content", chunk_id="User2 content",
vectors=[[0.4, 0.5, 0.6]] vector=[0.4, 0.5, 0.6]
) )
message2.chunks = [chunk2] message2.chunks = [chunk2]
@ -421,7 +416,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk_id="Special chars test", chunk_id="Special chars test",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]

View file

@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor:
# Create test document embeddings # Create test document embeddings
chunk1 = ChunkEmbeddings( chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk", chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
chunk2 = ChunkEmbeddings( chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk", chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]] vector=[0.7, 0.8, 0.9]
) )
message.chunks = [chunk1, chunk2] message.chunks = [chunk1, chunk2]
@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Test document content", chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Test document content", chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"", chunk=b"",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=None, chunk=None,
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"", # Empty bytes chunk=b"", # Empty bytes
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor:
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( # Each chunk has a single vector of different dimensions
chunk=b"Document with mixed dimensions", chunk1 = ChunkEmbeddings(
vectors=[ chunk=b"Document chunk 1",
[0.1, 0.2], # 2D vector vector=[0.1, 0.2] # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
) )
message.chunks = [chunk] chunk2 = ChunkEmbeddings(
chunk=b"Document chunk 2",
mock_index_2d = MagicMock() vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
mock_index_4d = MagicMock() )
mock_index_3d = MagicMock() chunk3 = ChunkEmbeddings(
chunk=b"Document chunk 3",
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.chunks = [chunk1, chunk2, chunk3]
mock_index = MagicMock()
def mock_index_side_effect(name): def mock_index_side_effect(name):
# All dimensions now use the same index name pattern # All dimensions now use the same index name pattern
# Different dimensions will be handled within the same index
if "test_user" in name and "test_collection" in name: if "test_user" in name and "test_collection" in name:
return mock_index_2d # Just return one mock for all return mock_index
return MagicMock() return MagicMock()
processor.pinecone.Index.side_effect = mock_index_side_effect processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(message) await processor.store_document_embeddings(message)
# Verify all vectors are now stored in the same index # Verify all vectors are now stored in the same index
# (Pinecone can handle mixed dimensions in the same index) # (Each chunk has a single vector, called once per chunk)
assert processor.pinecone.Index.call_count == 3 # Called once per vector assert processor.pinecone.Index.call_count == 3 # Called once per chunk
mock_index_2d.upsert.call_count == 3 # All upserts go to same index assert mock_index.upsert.call_count == 3 # All upserts go to same index
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor): async def test_store_document_embeddings_empty_chunks_list(self, processor):
@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Document with no vectors", chunk=b"Document with no vectors",
vectors=[] vector=[]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Test document content", chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=b"Test document content", chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]
@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
large_content = "A" * 10000 # 10KB of content large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'), chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.chunks = [chunk] message.chunks = [chunk]

View file

@ -89,7 +89,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -143,11 +143,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1' mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2]] mock_chunk1.vector = [0.1, 0.2]
mock_chunk2 = MagicMock() mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2' mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.3, 0.4]] mock_chunk2.vector = [0.3, 0.4]
mock_message.chunks = [mock_chunk1, mock_chunk2] mock_message.chunks = [mock_chunk1, mock_chunk2]
@ -175,8 +175,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client): async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple vectors per chunk""" """Test storing document embeddings with multiple chunks"""
# Arrange # Arrange
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_instance.collection_exists.return_value = True
@ -196,41 +196,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push) # Add collection to known_collections (simulates config push)
processor.known_collections[('vector_user', 'vector_collection')] = {} processor.known_collections[('vector_user', 'vector_collection')] = {}
# Create mock message with chunk having multiple vectors # Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'vector_user' mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection' mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock() mock_chunk1 = MagicMock()
mock_chunk.chunk_id = 'doc/multi-vector' mock_chunk1.chunk_id = 'doc/c1'
mock_chunk.vectors = [ mock_chunk1.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.chunks = [mock_chunk] mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vector = [0.4, 0.5, 0.6]
mock_chunk3 = MagicMock()
mock_chunk3.chunk_id = 'doc/c3'
mock_chunk3.vector = [0.7, 0.8, 0.9]
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings(mock_message)
# Assert # Assert
# Should be called 3 times (once per vector) # Should be called 3 times (once per chunk)
assert mock_qdrant_instance.upsert.call_count == 3 assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed # Verify all vectors were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [ expected_data = [
[0.1, 0.2, 0.3], ([0.1, 0.2, 0.3], 'doc/c1'),
[0.4, 0.5, 0.6], ([0.4, 0.5, 0.6], 'doc/c2'),
[0.7, 0.8, 0.9] ([0.7, 0.8, 0.9], 'doc/c3')
] ]
for i, call in enumerate(upsert_calls): for i, call in enumerate(upsert_calls):
point = call[1]['points'][0] point = call[1]['points'][0]
assert point.vector == expected_vectors[i] assert point.vector == expected_data[i][0]
assert point.payload['chunk_id'] == 'doc/multi-vector' assert point.payload['chunk_id'] == expected_data[i][1]
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client): async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
@ -256,7 +260,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk_empty = MagicMock() mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk_id = "" # Empty chunk_id mock_chunk_empty.chunk_id = "" # Empty chunk_id
mock_chunk_empty.vectors = [[0.1, 0.2]] mock_chunk_empty.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk_empty] mock_message.chunks = [mock_chunk_empty]
@ -299,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/test-chunk' mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -351,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/test-chunk' mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2]] mock_chunk.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
@ -389,7 +393,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1' mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]] mock_chunk1.vector = [0.1, 0.2, 0.3]
mock_message1.chunks = [mock_chunk1] mock_message1.chunks = [mock_chunk1]
@ -407,7 +411,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk2 = MagicMock() mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2' mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3) mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3)
mock_message2.chunks = [mock_chunk2] mock_message2.chunks = [mock_chunk2]
@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push) # Add collection to known_collections (simulates config push)
processor.known_collections[('dim_user', 'dim_collection')] = {} processor.known_collections[('dim_user', 'dim_collection')] = {}
# Create mock message with different dimension vectors # Create mock message with chunks of different dimensions
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'dim_user' mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection' mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock() mock_chunk1 = MagicMock()
mock_chunk.chunk_id = 'doc/dim-test' mock_chunk1.chunk_id = 'doc/c1'
mock_chunk.vectors = [ mock_chunk1.vector = [0.1, 0.2] # 2 dimensions
[0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions
]
mock_message.chunks = [mock_chunk] mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings(mock_message)
@ -526,7 +531,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock() mock_chunk = MagicMock()
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3' mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
mock_chunk.vectors = [[0.1, 0.2]] mock_chunk.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]

View file

@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
# Create test entities with embeddings # Create test entities with embeddings
entity1 = EntityEmbeddings( entity1 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity1'), entity=Term(type=IRI, iri='http://example.com/entity1'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
entity2 = EntityEmbeddings( entity2 = EntityEmbeddings(
entity=Term(type=LITERAL, value='literal entity'), entity=Term(type=LITERAL, value='literal entity'),
vectors=[[0.7, 0.8, 0.9]] vector=[0.7, 0.8, 0.9]
) )
message.entities = [entity1, entity2] message.entities = [entity1, entity2]
@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'), entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
message.entities = [entity] message.entities = [entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings(message)
# Verify insert was called for each vector with user/collection parameters # Verify insert was called once with the full vector
expected_calls = [ processor.vecstore.insert.assert_called_once()
([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'), actual_call = processor.vecstore.insert.call_args_list[0]
([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'), assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
] assert actual_call[0][1] == 'http://example.com/entity'
assert actual_call[0][2] == 'test_user'
assert processor.vecstore.insert.call_count == 2 assert actual_call[0][3] == 'test_collection'
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities""" """Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message) await processor.store_graph_embeddings(mock_message)
# Verify insert was called for each vector of each entity with user/collection parameters # Verify insert was called once per entity with user/collection parameters
expected_calls = [ expected_calls = [
# Entity 1 vectors # Entity 1 - single vector
([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'), ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), # Entity 2 - single vector
# Entity 2 vectors
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
] ]
assert processor.vecstore.insert.call_count == 3 assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i] actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec assert actual_call[0][0] == expected_vec
@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=''), entity=Term(type=LITERAL, value=''),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=None), entity=Term(type=LITERAL, value=None),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
@ -175,17 +168,17 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
valid_entity = EntityEmbeddings( valid_entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/valid'), entity=Term(type=IRI, iri='http://example.com/valid'),
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
chunk_id='' chunk_id=''
) )
empty_entity = EntityEmbeddings( empty_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=''), entity=Term(type=LITERAL, value=''),
vectors=[[0.4, 0.5, 0.6]], vector=[0.4, 0.5, 0.6],
chunk_id='' chunk_id=''
) )
none_entity = EntityEmbeddings( none_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=None), entity=Term(type=LITERAL, value=None),
vectors=[[0.7, 0.8, 0.9]], vector=[0.7, 0.8, 0.9],
chunk_id='' chunk_id=''
) )
message.entities = [valid_entity, empty_entity, none_entity] message.entities = [valid_entity, empty_entity, none_entity]
@ -222,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'), entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[] vector=[]
) )
message.entities = [entity] message.entities = [entity]
@ -238,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( # Each entity has a single vector of different dimensions
entity=Term(type=IRI, iri='http://example.com/entity'), entity1 = EntityEmbeddings(
vectors=[ entity=Term(type=IRI, iri='http://example.com/entity1'),
[0.1, 0.2], # 2D vector vector=[0.1, 0.2] # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
) )
message.entities = [entity] entity2 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity2'),
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
entity3 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity3'),
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings(message)
# Verify all vectors were inserted regardless of dimension # Verify all vectors were inserted regardless of dimension
expected_calls = [ expected_calls = [
([0.1, 0.2], 'http://example.com/entity'), ([0.1, 0.2], 'http://example.com/entity1'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'), ([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'),
([0.7, 0.8, 0.9], 'http://example.com/entity'), ([0.7, 0.8, 0.9], 'http://example.com/entity3'),
] ]
assert processor.vecstore.insert.call_count == 3 assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls): for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i] actual_call = processor.vecstore.insert.call_args_list[i]
@ -274,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
uri_entity = EntityEmbeddings( uri_entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/uri_entity'), entity=Term(type=IRI, iri='http://example.com/uri_entity'),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
literal_entity = EntityEmbeddings( literal_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value='literal entity text'), entity=Term(type=LITERAL, value='literal entity text'),
vectors=[[0.4, 0.5, 0.6]] vector=[0.4, 0.5, 0.6]
) )
message.entities = [uri_entity, literal_entity] message.entities = [uri_entity, literal_entity]

View file

@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create test entity embeddings # Create test entity embeddings (each entity has a single vector)
entity1 = EntityEmbeddings( entity1 = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True), entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3]
) )
entity2 = EntityEmbeddings( entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False), entity=Value(value="http://example.org/entity2", is_uri=True),
vectors=[[0.7, 0.8, 0.9]] vector=[0.4, 0.5, 0.6]
) )
message.entities = [entity1, entity2] entity3 = EntityEmbeddings(
entity=Value(value="entity3", is_uri=False),
vector=[0.7, 0.8, 0.9]
)
message.entities = [entity1, entity2, entity3]
return message return message
@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True), entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
# Mock index operations # Mock index operations
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']): with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings(message)
# Verify index name and operations (with dimension suffix) # Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_with(expected_index_name) processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector # Verify upsert was called for the single vector
assert mock_index.upsert.call_count == 2 assert mock_index.upsert.call_count == 1
# Check first vector upsert # Check first vector upsert
first_call = mock_index.upsert.call_args_list[0] first_call = mock_index.upsert.call_args_list[0]
@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False), entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value="", is_uri=False), entity=Value(value="", is_uri=False),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False), entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor): async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions to same index""" """Test storing graph embeddings with different vector dimensions"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user' message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( # Each entity has a single vector of different dimensions
entity=Value(value="test_entity", is_uri=False), entity1 = EntityEmbeddings(
vectors=[ entity=Value(value="entity1", is_uri=False),
[0.1, 0.2], # 2D vector vector=[0.1, 0.2] # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
) )
message.entities = [entity] entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
entity3 = EntityEmbeddings(
entity=Value(value="entity3", is_uri=False),
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.entities = [entity1, entity2, entity3]
# All vectors now use the same index (no dimension in name)
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False), entity=Value(value="test_entity", is_uri=False),
vectors=[] vector=[]
) )
message.entities = [entity] message.entities = [entity]
@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False), entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]
@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings( entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False), entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
message.entities = [entity] message.entities = [entity]

View file

@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity = MagicMock() mock_entity = MagicMock()
mock_entity.entity.type = IRI mock_entity.entity.type = IRI
mock_entity.entity.iri = 'test_entity' mock_entity.entity.iri = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
mock_message.entities = [mock_entity] mock_message.entities = [mock_entity]
@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity1 = MagicMock() mock_entity1 = MagicMock()
mock_entity1.entity.type = IRI mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one' mock_entity1.entity.iri = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]] mock_entity1.vector = [0.1, 0.2]
mock_entity2 = MagicMock() mock_entity2 = MagicMock()
mock_entity2.entity.type = IRI mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two' mock_entity2.entity.iri = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]] mock_entity2.vector = [0.3, 0.4]
mock_message.entities = [mock_entity1, mock_entity2] mock_message.entities = [mock_entity1, mock_entity2]
@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client): async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with multiple vectors per entity""" """Test storing graph embeddings with three entities"""
# Arrange # Arrange
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value.return_value = 'test-uuid' mock_uuid.uuid4.return_value.return_value = 'test-uuid'
config = { config = {
'store_uri': 'http://localhost:6333', 'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key', 'api_key': 'test-api-key',
@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push) # Add collection to known_collections (simulates config push)
processor.known_collections[('vector_user', 'vector_collection')] = {} processor.known_collections[('vector_user', 'vector_collection')] = {}
# Create mock message with entity having multiple vectors # Create mock message with three entities
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'vector_user' mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection' mock_message.metadata.collection = 'vector_collection'
mock_entity = MagicMock() mock_entity1 = MagicMock()
mock_entity.entity.type = IRI mock_entity1.entity.type = IRI
mock_entity.entity.iri = 'multi_vector_entity' mock_entity1.entity.iri = 'entity_one'
mock_entity.vectors = [ mock_entity1.vector = [0.1, 0.2, 0.3]
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6], mock_entity2 = MagicMock()
[0.7, 0.8, 0.9] mock_entity2.entity.type = IRI
] mock_entity2.entity.iri = 'entity_two'
mock_entity2.vector = [0.4, 0.5, 0.6]
mock_message.entities = [mock_entity]
mock_entity3 = MagicMock()
mock_entity3.entity.type = IRI
mock_entity3.entity.iri = 'entity_three'
mock_entity3.vector = [0.7, 0.8, 0.9]
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act # Act
await processor.store_graph_embeddings(mock_message) await processor.store_graph_embeddings(mock_message)
# Assert # Assert
# Should be called 3 times (once per vector) # Should be called 3 times (once per entity)
assert mock_qdrant_instance.upsert.call_count == 3 assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed # Verify all entities were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [ expected_data = [
[0.1, 0.2, 0.3], ([0.1, 0.2, 0.3], 'entity_one'),
[0.4, 0.5, 0.6], ([0.4, 0.5, 0.6], 'entity_two'),
[0.7, 0.8, 0.9] ([0.7, 0.8, 0.9], 'entity_three')
] ]
for i, call in enumerate(upsert_calls): for i, call in enumerate(upsert_calls):
point = call[1]['points'][0] point = call[1]['points'][0]
assert point.vector == expected_vectors[i] assert point.vector == expected_data[i][0]
assert point.payload['entity'] == 'multi_vector_entity' assert point.payload['entity'] == expected_data[i][1]
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client): async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity_empty = MagicMock() mock_entity_empty = MagicMock()
mock_entity_empty.entity.type = LITERAL mock_entity_empty.entity.type = LITERAL
mock_entity_empty.entity.value = "" # Empty string mock_entity_empty.entity.value = "" # Empty string
mock_entity_empty.vectors = [[0.1, 0.2]] mock_entity_empty.vector = [0.1, 0.2]
mock_entity_none = MagicMock() mock_entity_none = MagicMock()
mock_entity_none.entity = None # None entity mock_entity_none.entity = None # None entity
mock_entity_none.vectors = [[0.3, 0.4]] mock_entity_none.vector = [0.3, 0.4]
mock_message.entities = [mock_entity_empty, mock_entity_none] mock_message.entities = [mock_entity_empty, mock_entity_none]

View file

@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='customer_id', index_name='customer_id',
index_value=['CUST001'], index_value=['CUST001'],
text='CUST001', text='CUST001',
vectors=[[0.1, 0.2, 0.3]] vector=[0.1, 0.2, 0.3]
) )
embeddings_msg = RowEmbeddings( embeddings_msg = RowEmbeddings(
@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client): async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with multiple vectors""" """Test processing embeddings with a single vector"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
metadata.collection = 'test_collection' metadata.collection = 'test_collection'
metadata.id = 'doc-123' metadata.id = 'doc-123'
# Embedding with multiple vectors # Embedding with a single 6D vector
embedding = RowIndexEmbedding( embedding = RowIndexEmbedding(
index_name='name', index_name='name',
index_value=['John Doe'], index_value=['John Doe'],
text='John Doe', text='John Doe',
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
) )
embeddings_msg = RowEmbeddings( embeddings_msg = RowEmbeddings(
@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should be called 3 times (once per vector) # Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 3 assert mock_qdrant_instance.upsert.call_count == 1
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client): async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='id', index_name='id',
index_value=['123'], index_value=['123'],
text='123', text='123',
vectors=[] # Empty vectors vector=[] # Empty vector
) )
embeddings_msg = RowEmbeddings( embeddings_msg = RowEmbeddings(
@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='id', index_name='id',
index_value=['123'], index_value=['123'],
text='123', text='123',
vectors=[[0.1, 0.2]] vector=[0.1, 0.2]
) )
embeddings_msg = RowEmbeddings( embeddings_msg = RowEmbeddings(

View file

@ -612,12 +612,12 @@ class AsyncFlowInstance:
print(f"{entity['name']}: {entity['score']}") print(f"{entity['name']}: {entity['score']}")
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request_data = { request_data = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -810,12 +810,12 @@ class AsyncFlowInstance:
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})") print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request_data = { request_data = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search""" """Query graph embeddings for semantic search"""
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs limit: int = 10, **kwargs
): ):
"""Query row embeddings for semantic search on structured data""" """Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -602,13 +602,13 @@ class FlowInstance:
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
# Query graph embeddings for semantic search # Query graph embeddings for semantic search
input = { input = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -648,13 +648,13 @@ class FlowInstance:
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
# Query document embeddings for semantic search # Query document embeddings for semantic search
input = { input = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -1362,13 +1362,13 @@ class FlowInstance:
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
# Query row embeddings for semantic search # Query row embeddings for semantic search
input = { input = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -649,12 +649,12 @@ class SocketFlowInstance:
) )
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -698,12 +698,12 @@ class SocketFlowInstance:
# results contains {"chunk_ids": ["doc1/p0/c0", ...]} # results contains {"chunk_ids": ["doc1/p0/c0", ...]}
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -936,12 +936,12 @@ class SocketFlowInstance:
) )
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse): class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph", async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30): collection="default", timeout=30):
resp = await self.request( resp = await self.request(
DocumentEmbeddingsRequest( DocumentEmbeddingsRequest(
vectors = vectors, vector = vector,
limit = limit, limit = limit,
user = user, user = user,
collection = collection collection = collection
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
return resp.chunk_ids # Return ChunkMatch objects with chunk_id and score
return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec): class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request) docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...") logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None) r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed") logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error", type = "document-embeddings-query-error",
message = str(e), message = str(e),
), ),
chunk_ids=[], chunks=[],
) )
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})

View file

@ -19,12 +19,12 @@ def to_value(x):
return Literal(x.value or x.iri) return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse): class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph", async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30): collection="default", timeout=30):
resp = await self.request( resp = await self.request(
GraphEmbeddingsRequest( GraphEmbeddingsRequest(
vectors = vectors, vector = vector,
limit = limit, limit = limit,
user = user, user = user,
collection = collection collection = collection
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
return [ # Return EntityMatch objects with entity and score
to_value(v) return resp.entities
for v in resp.entities
]
class GraphEmbeddingsClientSpec(RequestResponseSpec): class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse): class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query( async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default", self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600 index_name=None, limit=10, timeout=600
): ):
request = RowEmbeddingsRequest( request = RowEmbeddingsRequest(
vectors=vectors, vector=vector,
schema_name=schema_name, schema_name=schema_name,
user=user, user=user,
collection=collection, collection=collection,

View file

@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
) )
def request( def request(
self, vectors, user="trustgraph", collection="default", self, vector, user="trustgraph", collection="default",
limit=10, timeout=300 limit=10, timeout=300
): ):
return self.call( return self.call(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout vector=vector, limit=limit, timeout=timeout
).chunks ).chunks

View file

@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
) )
def request( def request(
self, vectors, user="trustgraph", collection="default", self, vector, user="trustgraph", collection="default",
limit=10, timeout=300 limit=10, timeout=300
): ):
return self.call( return self.call(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout vector=vector, limit=limit, timeout=timeout
).entities ).entities

View file

@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
) )
def request( def request(
self, vectors, schema_name, user="trustgraph", collection="default", self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=300 index_name=None, limit=10, timeout=300
): ):
kwargs = dict( kwargs = dict(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, schema_name=schema_name, vector=vector, schema_name=schema_name,
limit=limit, timeout=timeout limit=limit, timeout=timeout
) )
if index_name: if index_name:

View file

@ -10,18 +10,18 @@ from .primitives import ValueTranslator
class DocumentEmbeddingsRequestTranslator(MessageTranslator): class DocumentEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsRequest schema objects""" """Translator for DocumentEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
return DocumentEmbeddingsRequest( return DocumentEmbeddingsRequest(
vectors=data["vectors"], vector=data["vector"],
limit=int(data.get("limit", 10)), limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"), user=data.get("user", "trustgraph"),
collection=data.get("collection", "default") collection=data.get("collection", "default")
) )
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
return { return {
"vectors": obj.vectors, "vector": obj.vector,
"limit": obj.limit, "limit": obj.limit,
"user": obj.user, "user": obj.user,
"collection": obj.collection "collection": obj.collection
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
class DocumentEmbeddingsResponseTranslator(MessageTranslator): class DocumentEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsResponse schema objects""" """Translator for DocumentEmbeddingsResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed") raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.chunk_ids is not None: if obj.chunks is not None:
result["chunk_ids"] = list(obj.chunk_ids) result["chunks"] = [
{
"chunk_id": chunk.chunk_id,
"score": chunk.score
}
for chunk in obj.chunks
]
return result return result
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)""" """Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True return self.from_pulsar(obj), True
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
class GraphEmbeddingsRequestTranslator(MessageTranslator): class GraphEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsRequest schema objects""" """Translator for GraphEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
return GraphEmbeddingsRequest( return GraphEmbeddingsRequest(
vectors=data["vectors"], vector=data["vector"],
limit=int(data.get("limit", 10)), limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"), user=data.get("user", "trustgraph"),
collection=data.get("collection", "default") collection=data.get("collection", "default")
) )
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
return { return {
"vectors": obj.vectors, "vector": obj.vector,
"limit": obj.limit, "limit": obj.limit,
"user": obj.user, "user": obj.user,
"collection": obj.collection "collection": obj.collection
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
class GraphEmbeddingsResponseTranslator(MessageTranslator): class GraphEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsResponse schema objects""" """Translator for GraphEmbeddingsResponse schema objects"""
def __init__(self): def __init__(self):
self.value_translator = ValueTranslator() self.value_translator = ValueTranslator()
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed") raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.entities is not None: if obj.entities is not None:
result["entities"] = [ result["entities"] = [
self.value_translator.from_pulsar(entity) {
for entity in obj.entities "entity": self.value_translator.from_pulsar(match.entity),
"score": match.score
}
for match in obj.entities
] ]
return result return result
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)""" """Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True return self.from_pulsar(obj), True
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
return RowEmbeddingsRequest( return RowEmbeddingsRequest(
vectors=data["vectors"], vector=data["vector"],
limit=int(data.get("limit", 10)), limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"), user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"), collection=data.get("collection", "default"),
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
result = { result = {
"vectors": obj.vectors, "vector": obj.vector,
"limit": obj.limit, "limit": obj.limit,
"user": obj.user, "user": obj.user,
"collection": obj.collection, "collection": obj.collection,

View file

@ -11,7 +11,7 @@ from ..core.topic import topic
@dataclass @dataclass
class EntityEmbeddings: class EntityEmbeddings:
entity: Term | None = None entity: Term | None = None
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
# Provenance: which chunk this embedding was derived from # Provenance: which chunk this embedding was derived from
chunk_id: str = "" chunk_id: str = ""
@ -28,7 +28,7 @@ class GraphEmbeddings:
@dataclass @dataclass
class ChunkEmbeddings: class ChunkEmbeddings:
chunk_id: str = "" chunk_id: str = ""
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
# This is a 'batching' mechanism for the above data # This is a 'batching' mechanism for the above data
@dataclass @dataclass
@ -44,7 +44,7 @@ class DocumentEmbeddings:
@dataclass @dataclass
class ObjectEmbeddings: class ObjectEmbeddings:
metadata: Metadata | None = None metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
name: str = "" name: str = ""
key_name: str = "" key_name: str = ""
id: str = "" id: str = ""
@ -56,7 +56,7 @@ class ObjectEmbeddings:
@dataclass @dataclass
class StructuredObjectEmbedding: class StructuredObjectEmbedding:
metadata: Metadata | None = None metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
schema_name: str = "" schema_name: str = ""
object_id: str = "" # Primary key value object_id: str = "" # Primary key value
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
@ -72,7 +72,7 @@ class RowIndexEmbedding:
index_name: str = "" # The indexed field name(s) index_name: str = "" # The indexed field name(s)
index_value: list[str] = field(default_factory=list) # The field value(s) index_value: list[str] = field(default_factory=list) # The field value(s)
text: str = "" # Text that was embedded text: str = "" # Text that was embedded
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
@dataclass @dataclass
class RowEmbeddings: class RowEmbeddings:

View file

@ -34,7 +34,7 @@ class EmbeddingsRequest:
@dataclass @dataclass
class EmbeddingsResponse: class EmbeddingsResponse:
error: Error | None = None error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list) vectors: list[list[float]] = field(default_factory=list)
############################################################################ ############################################################################

View file

@ -9,15 +9,21 @@ from ..core.topic import topic
@dataclass @dataclass
class GraphEmbeddingsRequest: class GraphEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
limit: int = 0 limit: int = 0
user: str = "" user: str = ""
collection: str = "" collection: str = ""
@dataclass
class EntityMatch:
"""A matching entity from a semantic search with similarity score"""
entity: Term | None = None
score: float = 0.0
@dataclass @dataclass
class GraphEmbeddingsResponse: class GraphEmbeddingsResponse:
error: Error | None = None error: Error | None = None
entities: list[Term] = field(default_factory=list) entities: list[EntityMatch] = field(default_factory=list)
############################################################################ ############################################################################
@ -44,15 +50,21 @@ class TriplesQueryResponse:
@dataclass @dataclass
class DocumentEmbeddingsRequest: class DocumentEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
limit: int = 0 limit: int = 0
user: str = "" user: str = ""
collection: str = "" collection: str = ""
@dataclass
class ChunkMatch:
"""A matching chunk from a semantic search with similarity score"""
chunk_id: str = ""
score: float = 0.0
@dataclass @dataclass
class DocumentEmbeddingsResponse: class DocumentEmbeddingsResponse:
error: Error | None = None error: Error | None = None
chunk_ids: list[str] = field(default_factory=list) chunks: list[ChunkMatch] = field(default_factory=list)
document_embeddings_request_queue = topic( document_embeddings_request_queue = topic(
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
@ -76,7 +88,7 @@ class RowIndexMatch:
@dataclass @dataclass
class RowEmbeddingsRequest: class RowEmbeddingsRequest:
"""Request for row embeddings semantic search""" """Request for row embeddings semantic search"""
vectors: list[list[float]] = field(default_factory=list) # Query vectors vector: list[float] = field(default_factory=list) # Query vector
limit: int = 10 # Max results to return limit: int = 10 # Max results to return
user: str = "" # User/keyspace user: str = "" # User/keyspace
collection: str = "" # Collection name collection: str = "" # Collection name

View file

@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl:
query_text = arguments.get("query") query_text = arguments.get("query")
all_vectors = await embeddings_client.embed([query_text]) all_vectors = await embeddings_client.embed([query_text])
vectors = all_vectors[0] if all_vectors else [] vector = all_vectors[0] if all_vectors else []
# Now query row embeddings # Now query row embeddings
client = self.context("row-embeddings-query-request") client = self.context("row-embeddings-query-request")
@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl:
user = getattr(client, '_current_user', self.user or "trustgraph") user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query( matches = await client.row_embeddings_query(
vectors=vectors, vector=vector,
schema_name=self.schema_name, schema_name=self.schema_name,
user=user, user=user,
collection=self.collection or "default", collection=self.collection or "default",

View file

@ -66,13 +66,13 @@ class Processor(FlowProcessor):
) )
) )
# vectors[0] is the vector set for the first (only) text # vectors[0] is the vector for the first (only) text
vectors = resp.vectors[0] if resp.vectors else [] vector = resp.vectors[0] if resp.vectors else []
embeds = [ embeds = [
ChunkEmbeddings( ChunkEmbeddings(
chunk_id=v.document_id, chunk_id=v.document_id,
vectors=vectors, vector=vector,
) )
] ]

View file

@ -59,11 +59,8 @@ class Processor(EmbeddingsService):
# FastEmbed processes the full batch efficiently # FastEmbed processes the full batch efficiently
vecs = list(self.embeddings.embed(texts)) vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text # Return list of vectors, one per input text
return [ return [v.tolist() for v in vecs]
[v.tolist()]
for v in vecs
]
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -72,10 +72,10 @@ class Processor(FlowProcessor):
entities = [ entities = [
EntityEmbeddings( EntityEmbeddings(
entity=entity.entity, entity=entity.entity,
vectors=vectors, # Vector set for this entity vector=vector,
chunk_id=entity.chunk_id, # Provenance: source chunk chunk_id=entity.chunk_id, # Provenance: source chunk
) )
for entity, vectors in zip(v.entities, all_vectors) for entity, vector in zip(v.entities, all_vectors)
] ]
# Send in batches to avoid oversized messages # Send in batches to avoid oversized messages

View file

@ -43,11 +43,8 @@ class Processor(EmbeddingsService):
input = texts input = texts
) )
# Return list of vector sets, one per input text # Return list of vectors, one per input text
return [ return list(embeds.embeddings)
[embedding]
for embedding in embeds.embeddings
]
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
all_vectors = await flow("embeddings-request").embed(texts=texts) all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results with metadata # Pair results with metadata
for text, (index_name, index_value), vectors in zip( for text, (index_name, index_value), vector in zip(
texts, metadata, all_vectors texts, metadata, all_vectors
): ):
embeddings_list.append( embeddings_list.append(
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
index_name=index_name, index_name=index_name,
index_value=index_value, index_value=index_value,
text=text, text=text,
vectors=vectors # Vector set for this text vector=vector
) )
) )

View file

@ -7,7 +7,7 @@ of chunk_ids
import logging import logging
from .... direct.milvus_doc_embeddings import DocVectors from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsResponse from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error from .... schema import Error
from .... base import DocumentEmbeddingsQueryService from .... base import DocumentEmbeddingsQueryService
@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
chunk_ids = [] resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for vec in msg.vectors: chunks = []
for r in resp:
chunk_id = r["entity"]["chunk_id"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
resp = self.vecstore.search( return chunks
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for r in resp:
chunk_id = r["entity"]["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
except Exception as e: except Exception as e:

View file

@ -11,6 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import ChunkMatch
from .... base import DocumentEmbeddingsQueryService from .... base import DocumentEmbeddingsQueryService
# Module logger # Module logger
@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
chunk_ids = [] dim = len(vec)
for vec in msg.vectors: # Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
dim = len(vec) # Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
# Use dimension suffix in index name index = self.pinecone.Index(index_name)
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - skip if not results = index.query(
if not self.pinecone.has_index(index_name): vector=vec,
logger.info(f"Index {index_name} does not exist, skipping this vector") top_k=msg.limit,
continue include_values=False,
include_metadata=True
)
index = self.pinecone.Index(index_name) chunks = []
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
results = index.query( return chunks
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
except Exception as e: except Exception as e:

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error from .... schema import Error
from .... base import DocumentEmbeddingsQueryService from .... base import DocumentEmbeddingsQueryService
@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService):
try: try:
chunk_ids = [] vec = msg.vector
if not vec:
return []
for vec in msg.vectors: # Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Use dimension suffix in collection name # Check if collection exists - return empty if not
dim = len(vec) if not self.collection_exists(collection):
collection = f"d_{msg.user}_{msg.collection}_{dim}" logger.info(f"Collection {collection} does not exist, returning empty results")
return []
# Check if collection exists - return empty if not search_result = self.qdrant.query_points(
if not self.collection_exists(collection): collection_name=collection,
logger.info(f"Collection {collection} does not exist, returning empty results") query=vec,
continue limit=msg.limit,
with_payload=True,
).points
search_result = self.qdrant.query_points( chunks = []
collection_name=collection, for r in search_result:
query=vec, chunk_id = r.payload["chunk_id"]
limit=msg.limit, score = r.score if hasattr(r, 'score') else 0.0
with_payload=True, chunks.append(ChunkMatch(
).points chunk_id=chunk_id,
score=score,
))
for r in search_result: return chunks
chunk_id = r.payload["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
except Exception as e: except Exception as e:

View file

@ -7,7 +7,7 @@ entities
import logging import logging
from .... direct.milvus_graph_embeddings import EntityVectors from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsResponse from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService from .... base import GraphEmbeddingsQueryService
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
try: try:
entity_set = set() vec = msg.vector
entities = [] if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
for vec in msg.vectors: resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
resp = self.vecstore.search( entity_set = set()
vec, entities = []
msg.user,
msg.collection,
limit=msg.limit * 2
)
for r in resp: for r in resp:
ent = r["entity"]["entity"] ent = r["entity"]["entity"]
# Milvus returns distance, convert to similarity score
# De-dupe entities distance = r.get("distance", 0.0)
if ent not in entity_set: score = 1.0 - distance if distance else 0.0
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit # De-dupe entities, keep highest score
if len(entity_set) >= msg.limit: break if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit # Keep adding entities until limit
if len(entity_set) >= msg.limit: break if len(entities) >= msg.limit:
break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
logger.debug("Send response...") logger.debug("Send response...")
return entities return entities

View file

@ -11,7 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import GraphEmbeddingsResponse from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService from .... base import GraphEmbeddingsQueryService
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
entity_set = set() entity_set = set()
entities = [] entities = []
for vec in msg.vectors: for r in results.matches:
ent = r.metadata["entity"]
score = r.score if hasattr(r, 'score') else 0.0
dim = len(vec) # De-dupe entities, keep highest score
if ent not in entity_set:
# Use dimension suffix in index name entity_set.add(ent)
index_name = f"t-{msg.user}-{msg.collection}-{dim}" entities.append(EntityMatch(
entity=self.create_value(ent),
# Check if index exists - skip if not score=score,
if not self.pinecone.has_index(index_name): ))
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# Keep adding entities until limit # Keep adding entities until limit
if len(entity_set) >= msg.limit: break if len(entities) >= msg.limit:
break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
return entities return entities

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService from .... base import GraphEmbeddingsQueryService
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
entity_set = set() entity_set = set()
entities = [] entities = []
for vec in msg.vectors: for r in search_result:
ent = r.payload["entity"]
score = r.score if hasattr(r, 'score') else 0.0
# Use dimension suffix in collection name # De-dupe entities, keep highest score
dim = len(vec) if ent not in entity_set:
collection = f"t_{msg.user}_{msg.collection}_{dim}" entity_set.add(ent)
entities.append(EntityMatch(
# Check if collection exists - return empty if not entity=self.create_value(ent),
if not self.collection_exists(collection): score=score,
logger.info(f"Collection {collection} does not exist, skipping this vector") ))
continue
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
for r in search_result:
ent = r.payload["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# Keep adding entities until limit # Keep adding entities until limit
if len(entity_set) >= msg.limit: break if len(entities) >= msg.limit:
break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
logger.debug("Send response...") logger.debug("Send response...")
return entities return entities

View file

@ -93,7 +93,9 @@ class Processor(FlowProcessor):
async def query_row_embeddings(self, request: RowEmbeddingsRequest): async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query""" """Execute row embeddings query"""
matches = [] vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema # Find the collection for this user/collection/schema
qdrant_collection = self.find_collection( qdrant_collection = self.find_collection(
@ -105,47 +107,47 @@ class Processor(FlowProcessor):
f"No Qdrant collection found for " f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}" f"{request.user}/{request.collection}/{request.schema_name}"
) )
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
return matches return matches
for vec in request.vectors: except Exception as e:
try: logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
# Build optional filter for index_name raise
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
async def on_message(self, msg, consumer, flow): async def on_message(self, msg, consumer, flow):
"""Handle incoming query request""" """Handle incoming query request"""

View file

@ -37,26 +37,26 @@ class Query:
vectors = await self.get_vector(query) vectors = await self.get_vector(query)
if self.verbose: if self.verbose:
logger.debug("Getting chunk_ids from embeddings store...") logger.debug("Getting chunks from embeddings store...")
# Get chunk_ids from embeddings store # Get chunk matches from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query( chunk_matches = await self.rag.doc_embeddings_client.query(
vectors, limit=self.doc_limit, vector=vectors, limit=self.doc_limit,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
) )
if self.verbose: if self.verbose:
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...") logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
# Fetch chunk content from Garage # Fetch chunk content from Garage
docs = [] docs = []
for chunk_id in chunk_ids: for match in chunk_matches:
if chunk_id: if match.chunk_id:
try: try:
content = await self.rag.fetch_chunk(chunk_id, self.user) content = await self.rag.fetch_chunk(match.chunk_id, self.user)
docs.append(content) docs.append(content)
except Exception as e: except Exception as e:
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}") logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
if self.verbose: if self.verbose:
logger.debug("Documents fetched:") logger.debug("Documents fetched:")

View file

@ -87,14 +87,14 @@ class Query:
if self.verbose: if self.verbose:
logger.debug("Getting entities...") logger.debug("Getting entities...")
entities = await self.rag.graph_embeddings_client.query( entity_matches = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit, vector=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
) )
entities = [ entities = [
str(e) str(e.entity)
for e in entities for e in entity_matches
] ]
if self.verbose: if self.verbose:

View file

@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "": if chunk_id == "":
continue continue
for vec in emb.vectors: vec = emb.vector
if vec:
self.vecstore.insert( self.vecstore.insert(
vec, chunk_id, vec, chunk_id,
message.metadata.user, message.metadata.user,

View file

@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "": if chunk_id == "":
continue continue
for vec in emb.vectors: vec = emb.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation # Create index name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
index_name = ( index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
) )
# Lazily create index if it doesn't exist (but only if authorized in config) # Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name): if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim) self.create_index(index_name, dim)
index = self.pinecone.Index(index_name) index = self.pinecone.Index(index_name)
# Generate unique ID for each vector # Generate unique ID for each vector
vector_id = str(uuid.uuid4()) vector_id = str(uuid.uuid4())
records = [ records = [
{ {
"id": vector_id, "id": vector_id,
"values": vec, "values": vec,
"metadata": { "chunk_id": chunk_id }, "metadata": { "chunk_id": chunk_id },
} }
] ]
index.upsert( index.upsert(
vectors = records, vectors = records,
) )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "": if chunk_id == "":
continue continue
for vec in emb.vectors: vec = emb.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation # Create collection name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
collection = ( collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
) )
# Lazily create collection if it doesn't exist (but only if authorized in config) # Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection): if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection( self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection, collection_name=collection,
points=[ vectors_config=VectorParams(
PointStruct( size=dim,
id=str(uuid.uuid4()), distance=Distance.COSINE
vector=vec, )
payload={
"chunk_id": chunk_id,
}
)
]
) )
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
entity_value = get_term_value(entity.entity) entity_value = get_term_value(entity.entity)
if entity_value != "" and entity_value is not None: if entity_value != "" and entity_value is not None:
for vec in entity.vectors: vec = entity.vector
if vec:
self.vecstore.insert( self.vecstore.insert(
vec, entity_value, vec, entity_value,
message.metadata.user, message.metadata.user,

View file

@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None: if entity_value == "" or entity_value is None:
continue continue
for vec in entity.vectors: vec = entity.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation # Create index name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
index_name = ( index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
) )
# Lazily create index if it doesn't exist (but only if authorized in config) # Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name): if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim) self.create_index(index_name, dim)
index = self.pinecone.Index(index_name) index = self.pinecone.Index(index_name)
# Generate unique ID for each vector # Generate unique ID for each vector
vector_id = str(uuid.uuid4()) vector_id = str(uuid.uuid4())
metadata = {"entity": entity_value} metadata = {"entity": entity_value}
if entity.chunk_id: if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id metadata["chunk_id"] = entity.chunk_id
records = [ records = [
{ {
"id": vector_id, "id": vector_id,
"values": vec, "values": vec,
"metadata": metadata, "metadata": metadata,
} }
] ]
index.upsert( index.upsert(
vectors = records, vectors = records,
) )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None: if entity_value == "" or entity_value is None:
continue continue
for vec in entity.vectors: vec = entity.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation # Create collection name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
collection = ( collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
) )
# Lazily create collection if it doesn't exist (but only if authorized in config) # Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection): if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection( self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection, collection_name=collection,
points=[ vectors_config=VectorParams(
PointStruct( size=dim,
id=str(uuid.uuid4()), distance=Distance.COSINE
vector=vec, )
payload=payload,
)
]
) )
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
qdrant_collection = None qdrant_collection = None
for row_emb in embeddings.embeddings: for row_emb in embeddings.embeddings:
if not row_emb.vectors: vector = row_emb.vector
if not vector:
logger.warning( logger.warning(
f"No vectors for index {row_emb.index_name} - skipping" f"No vector for index {row_emb.index_name} - skipping"
) )
continue continue
# Use first vector (there may be multiple from different models) dimension = len(vector)
for vector in row_emb.vectors:
dimension = len(vector)
# Create/get collection name (lazily on first vector) # Create/get collection name (lazily on first vector)
if qdrant_collection is None: if qdrant_collection is None:
qdrant_collection = self.get_collection_name( qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
) )
embeddings_written += 1 self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")