Embeddings API scores (#671)

- Put scores in all responses
- Remove unused 'middle' vector layer: queries previously mapped a vector of texts to a vector of (vector embedding); each request now carries a single embedding vector (`vectors=[[...]]` -> `vector=[...]`)
This commit is contained in:
cybermaggedon 2026-03-09 10:53:44 +00:00 committed by GitHub
parent 4fa7cc7d7c
commit f2ae0e8623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 1339 additions and 1292 deletions

View file

@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
import pytest
from unittest.mock import MagicMock
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
from trustgraph.messaging.translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator,
DocumentEmbeddingsResponseTranslator
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
"""Test that DocumentEmbeddingsRequest has expected fields"""
# Create a request
request = DocumentEmbeddingsRequest(
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
# Verify all expected fields exist
assert hasattr(request, 'vectors')
assert hasattr(request, 'vector')
assert hasattr(request, 'limit')
assert hasattr(request, 'user')
assert hasattr(request, 'collection')
# Verify field values
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10
assert request.user == "test_user"
assert request.collection == "test_collection"
@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2], [0.3, 0.4]],
"vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5,
"user": "custom_user",
"collection": "custom_collection"
@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5
assert result.user == "custom_user"
assert result.collection == "custom_collection"
@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2]]
"vector": [0.1, 0.2]
# No limit, user, or collection provided
}
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2]]
assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default
assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator()
request = DocumentEmbeddingsRequest(
vectors=[[0.5, 0.6]],
vector=[0.5, 0.6],
limit=20,
user="test_user",
collection="test_collection"
@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.from_pulsar(request)
assert isinstance(result, dict)
assert result["vectors"] == [[0.5, 0.6]]
assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20
assert result["user"] == "test_user"
assert result["collection"] == "test_collection"
@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_schema_fields(self):
"""Test that DocumentEmbeddingsResponse has expected fields"""
# Create a response with chunk_ids
# Create a response with chunks
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["chunk1", "chunk2", "chunk3"]
chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8),
ChunkMatch(chunk_id="chunk3", score=0.7)
]
)
# Verify all expected fields exist
assert hasattr(response, 'error')
assert hasattr(response, 'chunk_ids')
assert hasattr(response, 'chunks')
# Verify field values
assert response.error is None
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"]
assert len(response.chunks) == 3
assert response.chunks[0].chunk_id == "chunk1"
assert response.chunks[0].score == 0.9
def test_response_schema_with_error(self):
"""Test response schema with error"""
@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse(
error=error,
chunk_ids=[]
chunks=[]
)
assert response.error == error
assert response.chunk_ids == []
assert response.chunks == []
def test_response_translator_from_pulsar_with_chunk_ids(self):
"""Test response translator converts Pulsar schema with chunk_ids to dict"""
def test_response_translator_from_pulsar_with_chunks(self):
"""Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"]
chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85),
ChunkMatch(chunk_id="doc3/c3", score=0.75)
]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" in result
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"]
assert "chunks" in result
assert len(result["chunks"]) == 3
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunk_ids(self):
"""Test response translator handles empty chunk_ids list"""
def test_response_translator_from_pulsar_with_empty_chunks(self):
"""Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=[]
chunks=[]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" in result
assert result["chunk_ids"] == []
assert "chunks" in result
assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunk_ids(self):
"""Test response translator handles None chunk_ids"""
def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunk_ids = None
response.chunks = None
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunk_ids" not in result or result.get("chunk_ids") is None
assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag"""
@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["chunk1", "chunk2"]
chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8)
]
)
result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict)
assert "chunk_ids" in result
assert result["chunk_ids"] == ["chunk1", "chunk2"]
assert "chunks" in result
assert len(result["chunks"]) == 2
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self):
@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract:
translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunk_ids": ["test"]})
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility:
@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
"""Test complete request-response flow maintains data integrity"""
# Create request
request_data = {
"vectors": [[0.1, 0.2, 0.3]],
"vector": [0.1, 0.2, 0.3],
"limit": 5,
"user": "test_user",
"collection": "test_collection"
@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Simulate service processing and creating response
response = DocumentEmbeddingsResponse(
error=None,
chunk_ids=["doc1/c1", "doc2/c2"]
chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85)
]
)
# Convert response back to dict
@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict)
assert "chunk_ids" in response_data
assert len(response_data["chunk_ids"]) == 2
assert "chunks" in response_data
assert len(response_data["chunks"]) == 2
def test_error_response_flow(self):
"""Test error response flow"""
@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
response = DocumentEmbeddingsResponse(
error=error,
chunk_ids=[]
chunks=[]
)
# Convert response to dict
@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify error handling
assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunk_ids
assert "chunk_ids" in response_data
assert response_data["chunk_ids"] == []
# The translator doesn't include error in the dict, only chunks
assert "chunks" in response_data
assert response_data["chunks"] == []

View file

@ -285,11 +285,11 @@ class TestStructuredEmbeddingsContracts:
collection="test_collection",
metadata=[]
)
# Act
embedding = StructuredObjectEmbedding(
metadata=metadata,
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3],
schema_name="customer_records",
object_id="customer_123",
field_embeddings={
@ -301,7 +301,7 @@ class TestStructuredEmbeddingsContracts:
# Assert
assert embedding.schema_name == "customer_records"
assert embedding.object_id == "customer_123"
assert len(embedding.vectors) == 2
assert len(embedding.vector) == 3
assert len(embedding.field_embeddings) == 2
assert "name" in embedding.field_embeddings

View file

@ -9,6 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
# Sample chunk content for testing - maps chunk_id to content
@ -39,10 +40,14 @@ class TestDocumentRagIntegration:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk IDs"""
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
# Now returns chunk_ids instead of actual content
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
# Returns ChunkMatch objects with chunk_id and score
client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client
@pytest.fixture
@ -97,7 +102,7 @@ class TestDocumentRagIntegration:
mock_embeddings_client.embed.assert_called_once_with([query])
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
user=user,
collection=collection
@ -298,7 +303,7 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages
assert "chunk_ids" in log_messages.lower()
assert "chunks" in log_messages.lower()
assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages
@ -307,9 +312,9 @@ class TestDocumentRagIntegration:
async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large chunk_id set (100 chunks)
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_ids
# Arrange - Mock large chunk match set (100 chunks)
large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_matches
# Act
import time

View file

@ -8,6 +8,7 @@ response delivery through the complete pipeline.
import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -36,10 +37,14 @@ class TestDocumentRagStreaming:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk IDs"""
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
# Now returns chunk_ids instead of actual content
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
# Returns ChunkMatch objects with chunk_id and score
client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client
@pytest.fixture

View file

@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
@pytest.mark.integration
@ -35,9 +36,9 @@ class TestGraphRagIntegration:
"""Mock graph embeddings client that returns realistic entities"""
client = AsyncMock()
client.query.return_value = [
"http://trustgraph.ai/e/machine-learning",
"http://trustgraph.ai/e/artificial-intelligence",
"http://trustgraph.ai/e/neural-networks"
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
]
return client
@ -130,7 +131,7 @@ class TestGraphRagIntegration:
# 2. Should query graph embeddings to find relevant entities
mock_graph_embeddings_client.query.assert_called_once()
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection

View file

@ -8,6 +8,7 @@ response delivery through the complete pipeline.
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -33,7 +34,7 @@ class TestGraphRagStreaming:
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = [
"http://trustgraph.ai/e/machine-learning",
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
]
return client

View file

@ -411,7 +411,7 @@ class TestKnowledgeGraphPipelineIntegration:
entities=[
EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/entity"),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
]
)

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
class TestGraphRagStreamingProtocol:
@ -25,7 +26,10 @@ class TestGraphRagStreamingProtocol:
def mock_graph_embeddings_client(self):
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = ["entity1", "entity2"]
client.query.return_value = [
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
]
return client
@pytest.fixture
@ -202,9 +206,12 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk IDs"""
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
client.query.return_value = ["doc/c1", "doc/c2"]
client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90)
]
return client
@pytest.fixture

View file

@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
# Mock the request method
client.request = AsyncMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = await client.query(
vectors=vectors,
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vectors == vectors
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(RuntimeError, match="Database connection failed"):
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = []
mock_response.chunks = []
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["test_chunk"]
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
client.request.assert_called_once()
@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["chunk1"]
mock_response.chunks = ["chunk1"]
client.request = AsyncMock(return_value=mock_response)
# Act
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=60
)
# Assert
assert client.request.call_args[1]["timeout"] == 60
@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
client = DocumentEmbeddingsClient()
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunk_ids = ["test_chunk"]
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
mock_logger.debug.assert_called_once()
assert "Document embeddings response" in str(mock_logger.debug.call_args)

View file

@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = client.request(
vectors=vectors,
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)
@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["test_chunk"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3]]
vector = [0.1, 0.2, 0.3]
# Act
result = client.request(vectors=vectors)
result = client.request(vector=vector)
# Assert
assert result == ["test_chunk"]
client.call.assert_called_once_with(
user="trustgraph",
collection="default",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)
@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = []
client.call = MagicMock(return_value=mock_response)
# Act
result = client.request(vectors=[[0.1, 0.2, 0.3]])
result = client.request(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = None
client.call = MagicMock(return_value=mock_response)
# Act
result = client.request(vectors=[[0.1, 0.2, 0.3]])
result = client.request(vector=[0.1, 0.2, 0.3])
# Assert
assert result is None
@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1"]
client.call = MagicMock(return_value=mock_response)
# Act
client.request(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=600
)
# Assert
assert client.call.call_args[1]["timeout"] == 600

View file

@ -98,7 +98,7 @@ def sample_graph_embeddings():
entities=[
EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/john"),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
]
)

View file

@ -108,7 +108,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
# Assert
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -60,7 +60,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
model="test-model",
input=["test text"]
)
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -86,7 +86,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
model="custom-model",
input=["test text"]
)
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@patch('trustgraph.embeddings.ollama.processor.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
class TestMilvusDocEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
)
return query
@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
)
# Verify results are document chunks
# Verify results are ChunkMatch objects
assert len(result) == 3
assert result[0] == "First document chunk"
assert result[1] == "Second document chunk"
assert result[2] == "Third document chunk"
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == "First document chunk"
assert result[1].chunk_id == "Second document chunk"
assert result[2].chunk_id == "Third document chunk"
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor):
"""Test querying document embeddings with multiple vectors"""
async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
)
# Mock search results - different results for each vector
mock_results_1 = [
{"entity": {"chunk_id": "Document from first vector"}},
{"entity": {"chunk_id": "Another doc from first vector"}},
# Mock search results
mock_results = [
{"entity": {"chunk_id": "First document"}},
{"entity": {"chunk_id": "Second document"}},
{"entity": {"chunk_id": "Third document"}},
]
mock_results_2 = [
{"entity": {"chunk_id": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results from all vectors are combined
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
)
# Verify results are ChunkMatch objects
assert len(result) == 3
assert "Document from first vector" in result
assert "Another doc from first vector" in result
assert "Document from second vector" in result
chunk_ids = [r.chunk_id for r in result]
assert "First document" in chunk_ids
assert "Second document" in chunk_ids
assert "Third document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor):
@ -141,7 +135,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=2
)
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
vector=[],
limit=5
)
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -211,7 +205,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -225,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify Unicode content is preserved
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with Unicode: éñ中文🚀" in result
assert "Regular ASCII document" in result
assert "Document with émojis: 😀🎉" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with Unicode: éñ中文🚀" in chunk_ids
assert "Regular ASCII document" in chunk_ids
assert "Document with émojis: 😀🎉" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_large_documents(self, processor):
@ -237,7 +232,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -251,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify large content is preserved
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
assert large_doc in result
assert "Small document" in result
chunk_ids = [r.chunk_id for r in result]
assert large_doc in chunk_ids
assert "Small document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_special_characters(self, processor):
@ -262,7 +258,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -276,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
result = await processor.query_document_embeddings(query)
# Verify special characters are preserved
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with \"quotes\" and 'apostrophes'" in result
assert "Document with\nnewlines\tand\ttabs" in result
assert "Document with special chars: @#$%^&*()" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
assert "Document with special chars: @#$%^&*()" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=0
)
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=-1
)
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
# Mock search results
mock_results = [
{"entity": {"chunk_id": "Document 1"}},
{"entity": {"chunk_id": "Document 2"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
assert "Document from 2D vector" in result
assert "Document from 4D vector" in result
assert "Document from 3D vector" in result
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
# Verify results are ChunkMatch objects
assert len(result) == 2
chunk_ids = [r.chunk_id for r in result]
assert "Document 1" in chunk_ids
assert "Document 2" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_duplicate_documents(self, processor):
"""Test querying document embeddings with duplicate documents in results"""
async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results with duplicates across vectors
mock_results_1 = [
# Mock search results with multiple documents
mock_results = [
{"entity": {"chunk_id": "Document A"}},
{"entity": {"chunk_id": "Document B"}},
]
mock_results_2 = [
{"entity": {"chunk_id": "Document B"}}, # Duplicate
{"entity": {"chunk_id": "Document C"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
# This preserves ranking and allows multiple occurrences
assert len(result) == 4
assert result.count("Document B") == 2 # Should appear twice
assert "Document A" in result
assert "Document C" in result
# Verify results are ChunkMatch objects
assert len(result) == 3
chunk_ids = [r.chunk_id for r in result]
assert "Document A" in chunk_ids
assert "Document B" in chunk_ids
assert "Document C" in chunk_ids
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -103,7 +103,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -208,7 +208,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
@ -226,7 +226,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
@ -285,7 +285,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.vector = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -304,7 +304,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -325,7 +325,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_content(self, processor):
"""Test querying document embeddings with Unicode content results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -351,7 +351,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_content(self, processor):
"""Test querying document embeddings with large content results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 1
message.user = 'test_user'
message.collection = 'test_collection'
@ -377,7 +377,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_mixed_content_types(self, processor):
"""Test querying document embeddings with mixed content types"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -409,7 +409,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -425,7 +425,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_index_access_failure(self, processor):
"""Test handling of index access failure"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.doc_embeddings.qdrant.service import Processor
from trustgraph.schema import ChunkMatch
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
# Verify result contains expected documents
# Verify result contains expected ChunkMatch objects
assert len(result) == 2
# Results should be strings (document chunks)
assert isinstance(result[0], str)
assert isinstance(result[1], str)
# Results should be ChunkMatch objects
assert isinstance(result[0], ChunkMatch)
assert isinstance(result[1], ChunkMatch)
# Verify content
assert result[0] == 'first document chunk'
assert result[1] == 'second document chunk'
assert result[0].chunk_id == 'first document chunk'
assert result[1].chunk_id == 'second document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with multiple vectors"""
async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings returns multiple results"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses for different vectors
# Mock query response with multiple results
mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from vector 1'}
mock_point1.payload = {'chunk_id': 'document chunk 1'}
mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from vector 2'}
mock_point2.payload = {'chunk_id': 'document chunk 2'}
mock_point3 = MagicMock()
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2, mock_point3]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
mock_point3.payload = {'chunk_id': 'document chunk 3'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2, mock_point3]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with multiple vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors)
# Verify collection was queried correctly
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify results from both vectors are combined
# Verify results are ChunkMatch objects
assert len(result) == 3
assert 'document from vector 1' in result
assert 'document from vector 2' in result
assert 'another document from vector 2' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document chunk 1' in chunk_ids
assert 'document chunk 2' in chunk_ids
assert 'document chunk 3' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
"""Test querying document embeddings with different vector dimensions"""
"""Test querying document embeddings with a higher dimension vector"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
# Mock query responses
# Mock query response
mock_point1 = MagicMock()
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
mock_point1.payload = {'chunk_id': 'document from 5D vector'}
mock_point2 = MagicMock()
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
mock_response1 = MagicMock()
mock_response1.points = [mock_point1]
mock_response2 = MagicMock()
mock_response2.points = [mock_point2]
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
mock_point2.payload = {'chunk_id': 'another 5D document'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Create mock message with different dimension vectors
# Create mock message with 5D vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once with correct collection
assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Call should use 5D collection
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
# Verify results are ChunkMatch objects
assert len(result) == 2
assert 'document from 2D vector' in result
assert 'document from 3D vector' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document from 5D vector' in chunk_ids
assert 'another 5D document' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'utf8_user'
mock_message.collection = 'utf8_collection'
@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 2
# Verify UTF-8 content works correctly
assert 'Document with UTF-8: café, naïve, résumé' in result
assert 'Chinese text: 你好世界' in result
# Verify UTF-8 content works correctly in ChunkMatch objects
chunk_ids = [r.chunk_id for r in result]
assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
assert 'Chinese text: 你好世界' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0
# Result should contain all returned documents
# Result should contain all returned documents as ChunkMatch objects
assert len(result) == 1
assert result[0] == 'document chunk'
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == 'document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with large limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 1000 # Large limit
mock_message.user = 'large_user'
mock_message.collection = 'large_collection'
@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 1000
# Result should contain all available documents
# Result should contain all available documents as ChunkMatch objects
assert len(result) == 2
assert 'document 1' in result
assert 'document 2' in result
chunk_ids = [r.chunk_id for r in result]
assert 'document 1' in chunk_ids
assert 'document 2' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@ -508,7 +504,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'payload_user'
mock_message.collection = 'payload_collection'

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
class TestMilvusGraphEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
)
return query
@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify results are converted to Term objects
# Verify results are converted to EntityMatch objects
assert len(result) == 3
assert isinstance(result[0], Term)
assert result[0].iri == "http://example.com/entity1"
assert result[0].type == IRI
assert isinstance(result[1], Term)
assert result[1].iri == "http://example.com/entity2"
assert result[1].type == IRI
assert isinstance(result[2], Term)
assert result[2].value == "literal entity"
assert result[2].type == LITERAL
assert isinstance(result[0], EntityMatch)
assert result[0].entity.iri == "http://example.com/entity1"
assert result[0].entity.type == IRI
assert isinstance(result[1], EntityMatch)
assert result[1].entity.iri == "http://example.com/entity2"
assert result[1].entity.type == IRI
assert isinstance(result[2], EntityMatch)
assert result[2].entity.value == "literal entity"
assert result[2].entity.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
"""Test querying graph embeddings with multiple vectors"""
async def test_query_graph_embeddings_multiple_results(self, processor):
"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=3
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results - different results for each vector
mock_results_1 = [
# Mock search results with multiple entities
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results are deduplicated and limited
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
)
# Verify results are EntityMatch objects
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=2
)
@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 2
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication(self, processor):
"""Test that duplicate entities are properly deduplicated"""
async def test_query_graph_embeddings_preserves_order(self, processor):
"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results with duplicates
mock_results_1 = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
mock_results_2 = [
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
{"entity": {"entity": "http://example.com/entity3"}}, # New
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_graph_embeddings(query)
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
limit=2
)
# Mock search results - first vector returns enough results
mock_results_1 = [
# Mock search results in specific order
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results_1
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify only first vector was searched (limit reached)
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
# Verify results are in the same order as returned by the store
assert len(result) == 3
assert result[0].entity.iri == "http://example.com/entity1"
assert result[1].entity.iri == "http://example.com/entity2"
assert result[2].entity.iri == "http://example.com/entity3"
@pytest.mark.asyncio
async def test_query_graph_embeddings_results_limited(self, processor):
"""Test that results are properly limited when store returns more than requested"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2
)
# Verify results are limited
# Mock search results - returns more results than limit
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify search was called with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
)
# Verify results are limited to requested amount
assert len(result) == 2
@pytest.mark.asyncio
@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
vector=[],
limit=5
)
@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify all results are properly typed
assert len(result) == 4
# Check URI entities
uri_results = [r for r in result if r.type == IRI]
uri_results = [r for r in result if r.entity.type == IRI]
assert len(uri_results) == 2
uri_values = [r.iri for r in uri_results]
uri_values = [r.entity.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
literal_results = [r for r in result if not r.type == IRI]
literal_results = [r for r in result if not r.entity.type == IRI]
assert len(literal_results) == 2
literal_values = [r.value for r in literal_results]
literal_values = [r.entity.value for r in literal_results]
assert "literal entity text" in literal_values
assert "another literal" in literal_values
@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=0
)
@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 0
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying graph embeddings with different vector dimensions"""
async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
# Mock search results
mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "entity_2d" in entity_values
assert "entity_4d" in entity_values
assert "entity_3d" in entity_values
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once()
# Verify results
assert len(result) == 2
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values

View file

@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
class TestPineconeGraphEmbeddingsQueryProcessor:
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
def mock_query_message(self):
"""Create a mock query message for testing"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
include_metadata=True
)
# Verify results
# Verify results use EntityMatch structure
assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1'
assert entities[0].type == IRI
assert entities[1].value == 'entity2'
assert entities[1].type == LITERAL
assert entities[2].value == 'http://example.org/entity3'
assert entities[2].type == IRI
assert entities[0].entity.iri == 'http://example.org/entity1'
assert entities[0].entity.type == IRI
assert entities[1].entity.value == 'entity2'
assert entities[1].entity.type == LITERAL
assert entities[2].entity.iri == 'http://example.org/entity3'
assert entities[2].entity.type == IRI
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
"""Test querying graph embeddings with multiple vectors"""
async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
"""Test basic graph embeddings query"""
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query results
mock_results1 = MagicMock()
mock_results1.matches = [
# Query results with distinct entities
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'})
]
# Second query results
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify deduplication occurred
entity_values = [e.value for e in entities]
# Verify query was made once
assert mock_index.query.call_count == 1
# Verify results with EntityMatch structure
entity_values = [e.entity.value for e in entities]
assert len(entity_values) == 3
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions using same index"""
async def test_query_graph_embeddings_2d_vector(self, processor):
"""Test querying with a 2D vector"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.vector = [0.1, 0.2] # 2D vector
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
# Mock single index that handles all dimensions
# Mock index
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
# Mock results for 2D vector query
mock_results = MagicMock()
mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Verify different indexes used for different dimensions
assert processor.pinecone.Index.call_count == 2
index_calls = processor.pinecone.Index.call_args_list
index_names = [call[0][0] for call in index_calls]
assert "t-test_user-test_collection-2" in index_names # 2D vector
assert "t-test_user-test_collection-4" in index_names # 4D vector
# Verify correct index used for 2D vector
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify query was made
assert mock_index.query.call_count == 1
# Verify results from both dimensions
entity_values = [e.value for e in entities]
# Verify results with EntityMatch structure
entity_values = [e.entity.value for e in entities]
assert 'entity_2d' in entity_values
assert 'entity_4d' in entity_values
@pytest.mark.asyncio
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.vector = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
assert entities == []
@pytest.mark.asyncio
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
"""Test that deduplication works correctly across multiple vector queries"""
async def test_query_graph_embeddings_deduplication_in_results(self, processor):
"""Test that deduplication works correctly within query results"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Both queries return overlapping results
mock_results1 = MagicMock()
mock_results1.matches = [
# Query returns results with some duplicates
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}),
MagicMock(metadata={'entity': 'entity4'})
]
mock_results2 = MagicMock()
mock_results2.matches = [
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
MagicMock(metadata={'entity': 'entity5'})
]
mock_index.query.side_effect = [mock_results1, mock_results2]
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Should get exactly 3 unique entities (respecting limit)
assert len(entities) == 3
entity_values = [e.value for e in entities]
entity_values = [e.entity.value for e in entities]
assert len(set(entity_values)) == 3 # All unique
@pytest.mark.asyncio
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
"""Test that querying stops early when limit is reached"""
async def test_query_graph_embeddings_respects_limit(self, processor):
"""Test that query respects limit parameter"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# First query returns enough results to meet limit
mock_results1 = MagicMock()
mock_results1.matches = [
# Query returns more results than limit
mock_results = MagicMock()
mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'})
]
mock_index.query.return_value = mock_results1
mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(message)
# Should only make one query since limit was reached
# Should only return 2 entities (respecting limit)
mock_index.query.assert_called_once()
assert len(entities) == 2
@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'

View file

@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor
from trustgraph.schema import IRI, LITERAL
from trustgraph.schema import IRI, LITERAL, EntityMatch
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
# Verify result contains expected entities
# Verify result contains expected EntityMatch objects
assert len(result) == 2
assert all(hasattr(entity, 'value') for entity in result)
entity_values = [entity.value for entity in result]
assert all(isinstance(entity, EntityMatch) for entity in result)
entity_values = [entity.entity.value for entity in result]
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with multiple vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors)
# Verify collection was queried
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify deduplication - entity2 appears in both results but should only appear once
entity_values = [entity.value for entity in result]
# Verify results with EntityMatch structure
entity_values = [entity.entity.value for entity in result]
assert len(set(entity_values)) == len(entity_values) # All unique
assert 'entity1' in entity_values
assert 'entity2' in entity_values
assert 'entity3' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with different dimension vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2] # 2D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
# Call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
entity_values = [entity.value for entity in result]
# Verify results with EntityMatch structure
entity_values = [entity.entity.value for entity in result]
assert 'entity2d' in entity_values
assert 'entity3d' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'uri_user'
mock_message.collection = 'uri_collection'
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 3
# Check URI entities
uri_entities = [entity for entity in result if entity.type == IRI]
uri_entities = [entity for entity in result if entity.entity.type == IRI]
assert len(uri_entities) == 2
uri_values = [entity.iri for entity in uri_entities]
uri_values = [entity.entity.iri for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities
regular_entities = [entity for entity in result if entity.type == LITERAL]
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity'
assert regular_entities[0].entity.value == 'regular entity'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# With zero limit, the logic still adds one entity before checking the limit
# So it returns one result (current behavior, not ideal but actual)
assert len(result) == 1
assert result[0].value == 'entity1'
assert result[0].entity.value == 'entity1'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')

View file

@ -175,9 +175,14 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock document embeddings returns chunk_ids
test_chunk_ids = ["doc/c1", "doc/c2"]
mock_doc_embeddings_client.query.return_value = test_chunk_ids
# Mock document embeddings returns ChunkMatch objects
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c1"
mock_match1.score = 0.95
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c2"
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
@ -195,9 +200,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify doc embeddings client was called correctly (with extracted vectors)
# Verify doc embeddings client was called correctly (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
vector=test_vectors,
limit=15,
user="test_user",
collection="test_collection"
@ -218,11 +223,16 @@ class TestQuery:
# Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
test_chunk_ids = ["doc/c3", "doc/c4"]
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
mock_match2 = MagicMock()
mock_match2.chunk_id = "doc/c4"
mock_match2.score = 0.8
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = test_chunk_ids
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
@ -245,9 +255,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify doc embeddings client was called (with extracted vectors)
# Verify doc embeddings client was called (with extracted vector)
mock_doc_embeddings_client.query.assert_called_once_with(
test_vectors,
vector=test_vectors,
limit=10,
user="test_user",
collection="test_collection"
@ -275,7 +285,10 @@ class TestQuery:
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
@ -289,9 +302,9 @@ class TestQuery:
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used (vectors extracted from batch)
# Verify default parameters were used (vector extracted from batch)
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2]],
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
user="trustgraph", # Default user
collection="default" # Default collection
@ -316,7 +329,10 @@ class TestQuery:
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True
query = Query(
@ -347,7 +363,10 @@ class TestQuery:
# Mock responses (batch format)
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
@ -487,7 +506,13 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors]
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
@ -511,7 +536,7 @@ class TestQuery:
mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with(
query_vectors,
vector=query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"

View file

@ -193,12 +193,20 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock entity objects that have string representation
# Mock EntityMatch objects with entity that has string representation
mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
@ -216,9 +224,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify graph embeddings client was called correctly (with extracted vectors)
# Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with(
vectors=test_vectors,
vector=test_vectors,
limit=25,
user="test_user",
collection="test_collection"

View file

@ -23,11 +23,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk_id="This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
chunk2 = ChunkEmbeddings(
chunk_id="This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.chunks = [chunk1, chunk2]
@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk_id="Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called for each vector with user/collection parameters
expected_calls = [
([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_doc
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
# Verify insert was called once for the single chunk with its vector
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'
)
@pytest.mark.asyncio
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
"""Test storing document embeddings for multiple chunks"""
await processor.store_document_embeddings(mock_message)
# Verify insert was called for each vector of each chunk with user/collection parameters
# Verify insert was called once per chunk with user/collection parameters
expected_calls = [
# Chunk 1 vectors
([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
# Chunk 2 vectors
# Chunk 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
# Chunk 2 - single vector
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 3
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
@ -137,7 +127,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -156,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id=None,
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -177,15 +167,15 @@ class TestMilvusDocEmbeddingsStorageProcessor:
valid_chunk = ChunkEmbeddings(
chunk_id="Valid document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
empty_chunk = ChunkEmbeddings(
chunk_id="",
vectors=[[0.4, 0.5, 0.6]]
vector=[0.4, 0.5, 0.6]
)
another_valid = ChunkEmbeddings(
chunk_id="Another valid chunk",
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.chunks = [valid_chunk, empty_chunk, another_valid]
@ -229,7 +219,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="Document with no vectors",
vectors=[]
vector=[]
)
message.chunks = [chunk]
@ -245,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk_id="Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each chunk has a single vector of different dimensions
chunk1 = ChunkEmbeddings(
chunk_id="chunk/doc/2d",
vector=[0.1, 0.2] # 2D vector
)
message.chunks = [chunk]
chunk2 = ChunkEmbeddings(
chunk_id="chunk/doc/4d",
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
chunk3 = ChunkEmbeddings(
chunk_id="chunk/doc/3d",
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [
([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'),
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'),
([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'),
([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'),
([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
@ -283,7 +278,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="chunk/doc/unicode-éñ中文🚀",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -306,7 +301,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
long_chunk_id = "chunk/doc/" + "a" * 200
chunk = ChunkEmbeddings(
chunk_id=long_chunk_id,
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -327,7 +322,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id=" \n\t ",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -358,7 +353,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="Test content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -379,7 +374,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings(
chunk_id="User1 content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message1.chunks = [chunk1]
@ -390,7 +385,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings(
chunk_id="User2 content",
vectors=[[0.4, 0.5, 0.6]]
vector=[0.4, 0.5, 0.6]
)
message2.chunks = [chunk2]
@ -421,7 +416,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk_id="Special chars test",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]

View file

@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor:
# Create test document embeddings
chunk1 = ChunkEmbeddings(
chunk=b"This is the first document chunk",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
chunk2 = ChunkEmbeddings(
chunk=b"This is the second document chunk",
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.chunks = [chunk1, chunk2]
@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
message.chunks = [chunk]
@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=None,
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"", # Empty bytes
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor:
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Document with mixed dimensions",
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each chunk has a single vector of different dimensions
chunk1 = ChunkEmbeddings(
chunk=b"Document chunk 1",
vector=[0.1, 0.2] # 2D vector
)
message.chunks = [chunk]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
chunk2 = ChunkEmbeddings(
chunk=b"Document chunk 2",
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
chunk3 = ChunkEmbeddings(
chunk=b"Document chunk 3",
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.chunks = [chunk1, chunk2, chunk3]
mock_index = MagicMock()
def mock_index_side_effect(name):
# All dimensions now use the same index name pattern
# Different dimensions will be handled within the same index
if "test_user" in name and "test_collection" in name:
return mock_index_2d # Just return one mock for all
return mock_index
return MagicMock()
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(message)
# Verify all vectors are now stored in the same index
# (Pinecone can handle mixed dimensions in the same index)
assert processor.pinecone.Index.call_count == 3 # Called once per vector
mock_index_2d.upsert.call_count == 3 # All upserts go to same index
# (Each chunk has a single vector, called once per chunk)
assert processor.pinecone.Index.call_count == 3 # Called once per chunk
assert mock_index.upsert.call_count == 3 # All upserts go to same index
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):
@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Document with no vectors",
vectors=[]
vector=[]
)
message.chunks = [chunk]
@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
chunk = ChunkEmbeddings(
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
large_content = "A" * 10000 # 10KB of content
chunk = ChunkEmbeddings(
chunk=large_content.encode('utf-8'),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]

View file

@ -89,7 +89,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
mock_message.chunks = [mock_chunk]
@ -143,11 +143,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2]]
mock_chunk1.vector = [0.1, 0.2]
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.3, 0.4]]
mock_chunk2.vector = [0.3, 0.4]
mock_message.chunks = [mock_chunk1, mock_chunk2]
@ -175,8 +175,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple vectors per chunk"""
async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
"""Test storing document embeddings with multiple chunks"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
@ -196,41 +196,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push)
processor.known_collections[('vector_user', 'vector_collection')] = {}
# Create mock message with chunk having multiple vectors
# Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/multi-vector'
mock_chunk.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vector = [0.1, 0.2, 0.3]
mock_message.chunks = [mock_chunk]
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vector = [0.4, 0.5, 0.6]
mock_chunk3 = MagicMock()
mock_chunk3.chunk_id = 'doc/c3'
mock_chunk3.vector = [0.7, 0.8, 0.9]
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
# Should be called 3 times (once per chunk)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
expected_data = [
([0.1, 0.2, 0.3], 'doc/c1'),
([0.4, 0.5, 0.6], 'doc/c2'),
([0.7, 0.8, 0.9], 'doc/c3')
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['chunk_id'] == 'doc/multi-vector'
assert point.vector == expected_data[i][0]
assert point.payload['chunk_id'] == expected_data[i][1]
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
@ -256,7 +260,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk_id = "" # Empty chunk_id
mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_chunk_empty.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk_empty]
@ -299,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions
mock_message.chunks = [mock_chunk]
@ -351,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/test-chunk'
mock_chunk.vectors = [[0.1, 0.2]]
mock_chunk.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk]
@ -389,7 +393,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_chunk1.vector = [0.1, 0.2, 0.3]
mock_message1.chunks = [mock_chunk1]
@ -407,7 +411,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3)
mock_message2.chunks = [mock_chunk2]
@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push)
processor.known_collections[('dim_user', 'dim_collection')] = {}
# Create mock message with different dimension vectors
# Create mock message with chunks of different dimensions
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'doc/dim-test'
mock_chunk.vectors = [
[0.1, 0.2], # 2 dimensions
[0.3, 0.4, 0.5] # 3 dimensions
]
mock_chunk1 = MagicMock()
mock_chunk1.chunk_id = 'doc/c1'
mock_chunk1.vector = [0.1, 0.2] # 2 dimensions
mock_message.chunks = [mock_chunk]
mock_chunk2 = MagicMock()
mock_chunk2.chunk_id = 'doc/c2'
mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
@ -526,7 +531,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_chunk = MagicMock()
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
mock_chunk.vectors = [[0.1, 0.2]]
mock_chunk.vector = [0.1, 0.2]
mock_message.chunks = [mock_chunk]

View file

@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity1'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
entity2 = EntityEmbeddings(
entity=Term(type=LITERAL, value='literal entity'),
vectors=[[0.7, 0.8, 0.9]]
vector=[0.7, 0.8, 0.9]
)
message.entities = [entity1, entity2]
@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
# Verify insert was called for each vector with user/collection parameters
expected_calls = [
([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
assert actual_call[0][1] == expected_entity
assert actual_call[0][2] == expected_user
assert actual_call[0][3] == expected_collection
# Verify insert was called once with the full vector
processor.vecstore.insert.assert_called_once()
actual_call = processor.vecstore.insert.call_args_list[0]
assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
assert actual_call[0][1] == 'http://example.com/entity'
assert actual_call[0][2] == 'test_user'
assert actual_call[0][3] == 'test_collection'
@pytest.mark.asyncio
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
"""Test storing graph embeddings for multiple entities"""
await processor.store_graph_embeddings(mock_message)
# Verify insert was called for each vector of each entity with user/collection parameters
# Verify insert was called once per entity with user/collection parameters
expected_calls = [
# Entity 1 vectors
([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'),
([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
# Entity 2 vectors
# Entity 1 - single vector
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
# Entity 2 - single vector
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
]
assert processor.vecstore.insert.call_count == 3
assert processor.vecstore.insert.call_count == 2
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
assert actual_call[0][0] == expected_vec
@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=''),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=None),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -175,17 +168,17 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
valid_entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/valid'),
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
chunk_id=''
)
empty_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=''),
vectors=[[0.4, 0.5, 0.6]],
vector=[0.4, 0.5, 0.6],
chunk_id=''
)
none_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value=None),
vectors=[[0.7, 0.8, 0.9]],
vector=[0.7, 0.8, 0.9],
chunk_id=''
)
message.entities = [valid_entity, empty_entity, none_entity]
@ -222,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[]
vector=[]
)
message.entities = [entity]
@ -238,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each entity has a single vector of different dimensions
entity1 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity1'),
vector=[0.1, 0.2] # 2D vector
)
message.entities = [entity]
entity2 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity2'),
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
entity3 = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/entity3'),
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
([0.1, 0.2], 'http://example.com/entity'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
([0.7, 0.8, 0.9], 'http://example.com/entity'),
([0.1, 0.2], 'http://example.com/entity1'),
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'),
([0.7, 0.8, 0.9], 'http://example.com/entity3'),
]
assert processor.vecstore.insert.call_count == 3
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
actual_call = processor.vecstore.insert.call_args_list[i]
@ -274,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
uri_entity = EntityEmbeddings(
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
literal_entity = EntityEmbeddings(
entity=Term(type=LITERAL, value='literal entity text'),
vectors=[[0.4, 0.5, 0.6]]
vector=[0.4, 0.5, 0.6]
)
message.entities = [uri_entity, literal_entity]

View file

@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings
# Create test entity embeddings (each entity has a single vector)
entity1 = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3]
)
entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vectors=[[0.7, 0.8, 0.9]]
entity=Value(value="http://example.org/entity2", is_uri=True),
vector=[0.4, 0.5, 0.6]
)
message.entities = [entity1, entity2]
entity3 = EntityEmbeddings(
entity=Value(value="entity3", is_uri=False),
vector=[0.7, 0.8, 0.9]
)
message.entities = [entity1, entity2, entity3]
return message
@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="http://example.org/entity1", is_uri=True),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
# Mock index operations
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
assert mock_index.upsert.call_count == 2
# Verify upsert was called for the single vector
assert mock_index.upsert.call_count == 1
# Check first vector upsert
first_call = mock_index.upsert.call_args_list[0]
@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions to same index"""
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
]
# Each entity has a single vector of different dimensions
entity1 = EntityEmbeddings(
entity=Value(value="entity1", is_uri=False),
vector=[0.1, 0.2] # 2D vector
)
message.entities = [entity]
entity2 = EntityEmbeddings(
entity=Value(value="entity2", is_uri=False),
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
)
entity3 = EntityEmbeddings(
entity=Value(value="entity3", is_uri=False),
vector=[0.7, 0.8, 0.9] # 3D vector
)
message.entities = [entity1, entity2, entity3]
# All vectors now use the same index (no dimension in name)
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[]
vector=[]
)
message.entities = [entity]
@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]
@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
message.entities = [entity]

View file

@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity = MagicMock()
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
mock_message.entities = [mock_entity]
@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity1 = MagicMock()
mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]]
mock_entity1.vector = [0.1, 0.2]
mock_entity2 = MagicMock()
mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]]
mock_entity2.vector = [0.3, 0.4]
mock_message.entities = [mock_entity1, mock_entity2]
@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with multiple vectors per entity"""
async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client):
"""Test storing graph embeddings with three entities"""
# Arrange
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Add collection to known_collections (simulates config push)
processor.known_collections[('vector_user', 'vector_collection')] = {}
# Create mock message with entity having multiple vectors
# Create mock message with three entities
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_entity = MagicMock()
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'multi_vector_entity'
mock_entity.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
]
mock_message.entities = [mock_entity]
mock_entity1 = MagicMock()
mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one'
mock_entity1.vector = [0.1, 0.2, 0.3]
mock_entity2 = MagicMock()
mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two'
mock_entity2.vector = [0.4, 0.5, 0.6]
mock_entity3 = MagicMock()
mock_entity3.entity.type = IRI
mock_entity3.entity.iri = 'entity_three'
mock_entity3.vector = [0.7, 0.8, 0.9]
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act
await processor.store_graph_embeddings(mock_message)
# Assert
# Should be called 3 times (once per vector)
# Should be called 3 times (once per entity)
assert mock_qdrant_instance.upsert.call_count == 3
# Verify all vectors were processed
# Verify all entities were processed
upsert_calls = mock_qdrant_instance.upsert.call_args_list
expected_vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]
expected_data = [
([0.1, 0.2, 0.3], 'entity_one'),
([0.4, 0.5, 0.6], 'entity_two'),
([0.7, 0.8, 0.9], 'entity_three')
]
for i, call in enumerate(upsert_calls):
point = call[1]['points'][0]
assert point.vector == expected_vectors[i]
assert point.payload['entity'] == 'multi_vector_entity'
assert point.vector == expected_data[i][0]
assert point.payload['entity'] == expected_data[i][1]
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_entity_empty = MagicMock()
mock_entity_empty.entity.type = LITERAL
mock_entity_empty.entity.value = "" # Empty string
mock_entity_empty.vectors = [[0.1, 0.2]]
mock_entity_empty.vector = [0.1, 0.2]
mock_entity_none = MagicMock()
mock_entity_none.entity = None # None entity
mock_entity_none.vectors = [[0.3, 0.4]]
mock_entity_none.vector = [0.3, 0.4]
mock_message.entities = [mock_entity_empty, mock_entity_none]

View file

@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='customer_id',
index_value=['CUST001'],
text='CUST001',
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
embeddings_msg = RowEmbeddings(
@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with multiple vectors"""
async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with a single vector"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
# Embedding with multiple vectors
# Embedding with a single 6D vector
embedding = RowIndexEmbedding(
index_name='name',
index_value=['John Doe'],
text='John Doe',
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
embeddings_msg = RowEmbeddings(
@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
# Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 1
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='id',
index_value=['123'],
text='123',
vectors=[] # Empty vectors
vector=[] # Empty vector
)
embeddings_msg = RowEmbeddings(
@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
index_name='id',
index_value=['123'],
text='123',
vectors=[[0.1, 0.2]]
vector=[0.1, 0.2]
)
embeddings_msg = RowEmbeddings(

View file

@ -612,12 +612,12 @@ class AsyncFlowInstance:
print(f"{entity['name']}: {entity['score']}")
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -810,12 +810,12 @@ class AsyncFlowInstance:
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs
):
"""Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -602,13 +602,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query graph embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -648,13 +648,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query document embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -1362,13 +1362,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query row embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -649,12 +649,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -698,12 +698,12 @@ class SocketFlowInstance:
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -936,12 +936,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return resp.chunk_ids
# Return ChunkMatch objects with chunk_id and score
return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error",
message = str(e),
),
chunk_ids=[],
chunks=[],
)
await flow("response").send(r, properties={"id": id})

View file

@ -19,12 +19,12 @@ def to_value(x):
return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return [
to_value(v)
for v in resp.entities
]
# Return EntityMatch objects with entity and score
return resp.entities
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600
):
request = RowEmbeddingsRequest(
vectors=vectors,
vector=vector,
schema_name=schema_name,
user=user,
collection=collection,

View file

@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
)
def request(
self, vectors, user="trustgraph", collection="default",
self, vector, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout
vector=vector, limit=limit, timeout=timeout
).chunks

View file

@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
)
def request(
self, vectors, user="trustgraph", collection="default",
self, vector, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout
vector=vector, limit=limit, timeout=timeout
).entities

View file

@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
)
def request(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=300
):
kwargs = dict(
user=user, collection=collection,
vectors=vectors, schema_name=schema_name,
vector=vector, schema_name=schema_name,
limit=limit, timeout=timeout
)
if index_name:

View file

@ -10,18 +10,18 @@ from .primitives import ValueTranslator
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
return DocumentEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
)
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
return {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.chunk_ids is not None:
result["chunk_ids"] = list(obj.chunk_ids)
if obj.chunks is not None:
result["chunks"] = [
{
"chunk_id": chunk.chunk_id,
"score": chunk.score
}
for chunk in obj.chunks
]
return result
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
class GraphEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
return GraphEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
)
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
return {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
class GraphEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsResponse schema objects"""
def __init__(self):
self.value_translator = ValueTranslator()
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.entities is not None:
result["entities"] = [
self.value_translator.from_pulsar(entity)
for entity in obj.entities
{
"entity": self.value_translator.from_pulsar(match.entity),
"score": match.score
}
for match in obj.entities
]
return result
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
return RowEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
result = {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection,

View file

@ -11,7 +11,7 @@ from ..core.topic import topic
@dataclass
class EntityEmbeddings:
entity: Term | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
# Provenance: which chunk this embedding was derived from
chunk_id: str = ""
@ -28,7 +28,7 @@ class GraphEmbeddings:
@dataclass
class ChunkEmbeddings:
chunk_id: str = ""
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
# This is a 'batching' mechanism for the above data
@dataclass
@ -44,7 +44,7 @@ class DocumentEmbeddings:
@dataclass
class ObjectEmbeddings:
metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
name: str = ""
key_name: str = ""
id: str = ""
@ -56,7 +56,7 @@ class ObjectEmbeddings:
@dataclass
class StructuredObjectEmbedding:
metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
schema_name: str = ""
object_id: str = "" # Primary key value
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
@ -72,7 +72,7 @@ class RowIndexEmbedding:
index_name: str = "" # The indexed field name(s)
index_value: list[str] = field(default_factory=list) # The field value(s)
text: str = "" # Text that was embedded
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
@dataclass
class RowEmbeddings:

View file

@ -34,7 +34,7 @@ class EmbeddingsRequest:
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
vectors: list[list[float]] = field(default_factory=list)
############################################################################

View file

@ -9,15 +9,21 @@ from ..core.topic import topic
@dataclass
class GraphEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
limit: int = 0
user: str = ""
collection: str = ""
@dataclass
class EntityMatch:
"""A matching entity from a semantic search with similarity score"""
entity: Term | None = None
score: float = 0.0
@dataclass
class GraphEmbeddingsResponse:
error: Error | None = None
entities: list[Term] = field(default_factory=list)
entities: list[EntityMatch] = field(default_factory=list)
############################################################################
@ -44,15 +50,21 @@ class TriplesQueryResponse:
@dataclass
class DocumentEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
limit: int = 0
user: str = ""
collection: str = ""
@dataclass
class ChunkMatch:
"""A matching chunk from a semantic search with similarity score"""
chunk_id: str = ""
score: float = 0.0
@dataclass
class DocumentEmbeddingsResponse:
error: Error | None = None
chunk_ids: list[str] = field(default_factory=list)
chunks: list[ChunkMatch] = field(default_factory=list)
document_embeddings_request_queue = topic(
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
@ -76,7 +88,7 @@ class RowIndexMatch:
@dataclass
class RowEmbeddingsRequest:
"""Request for row embeddings semantic search"""
vectors: list[list[float]] = field(default_factory=list) # Query vectors
vector: list[float] = field(default_factory=list) # Query vector
limit: int = 10 # Max results to return
user: str = "" # User/keyspace
collection: str = "" # Collection name

View file

@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl:
query_text = arguments.get("query")
all_vectors = await embeddings_client.embed([query_text])
vectors = all_vectors[0] if all_vectors else []
vector = all_vectors[0] if all_vectors else []
# Now query row embeddings
client = self.context("row-embeddings-query-request")
@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl:
user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query(
vectors=vectors,
vector=vector,
schema_name=self.schema_name,
user=user,
collection=self.collection or "default",

View file

@ -66,13 +66,13 @@ class Processor(FlowProcessor):
)
)
# vectors[0] is the vector set for the first (only) text
vectors = resp.vectors[0] if resp.vectors else []
# vectors[0] is the vector for the first (only) text
vector = resp.vectors[0] if resp.vectors else []
embeds = [
ChunkEmbeddings(
chunk_id=v.document_id,
vectors=vectors,
vector=vector,
)
]

View file

@ -59,11 +59,8 @@ class Processor(EmbeddingsService):
# FastEmbed processes the full batch efficiently
vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text
return [
[v.tolist()]
for v in vecs
]
# Return list of vectors, one per input text
return [v.tolist() for v in vecs]
@staticmethod
def add_args(parser):

View file

@ -72,10 +72,10 @@ class Processor(FlowProcessor):
entities = [
EntityEmbeddings(
entity=entity.entity,
vectors=vectors, # Vector set for this entity
vector=vector,
chunk_id=entity.chunk_id, # Provenance: source chunk
)
for entity, vectors in zip(v.entities, all_vectors)
for entity, vector in zip(v.entities, all_vectors)
]
# Send in batches to avoid oversized messages

View file

@ -43,11 +43,8 @@ class Processor(EmbeddingsService):
input = texts
)
# Return list of vector sets, one per input text
return [
[embedding]
for embedding in embeds.embeddings
]
# Return list of vectors, one per input text
return list(embeds.embeddings)
@staticmethod
def add_args(parser):

View file

@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results with metadata
for text, (index_name, index_value), vectors in zip(
for text, (index_name, index_value), vector in zip(
texts, metadata, all_vectors
):
embeddings_list.append(
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors # Vector set for this text
vector=vector
)
)

View file

@ -7,7 +7,7 @@ of chunk_ids
import logging
from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
chunk_ids = []
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for vec in msg.vectors:
chunks = []
for r in resp:
chunk_id = r["entity"]["chunk_id"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for r in resp:
chunk_id = r["entity"]["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
return chunks
except Exception as e:

View file

@ -11,6 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import ChunkMatch
from .... base import DocumentEmbeddingsQueryService
# Module logger
@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
chunk_ids = []
dim = len(vec)
for vec in msg.vectors:
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
dim = len(vec)
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
index = self.pinecone.Index(index_name)
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
results = index.query(
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
index = self.pinecone.Index(index_name)
chunks = []
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
results = index.query(
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
return chunks
except Exception as e:

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService):
try:
chunk_ids = []
vec = msg.vector
if not vec:
return []
for vec in msg.vectors:
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
continue
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
chunks = []
for r in search_result:
chunk_id = r.payload["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
for r in search_result:
chunk_id = r.payload["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
return chunks
except Exception as e:

View file

@ -7,7 +7,7 @@ entities
import logging
from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
try:
entity_set = set()
entities = []
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
for vec in msg.vectors:
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
entity_set = set()
entities = []
for r in resp:
ent = r["entity"]["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
for r in resp:
ent = r["entity"]["entity"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
logger.debug("Send response...")
return entities

View file

@ -11,7 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
entity_set = set()
entities = []
for vec in msg.vectors:
for r in results.matches:
ent = r.metadata["entity"]
score = r.score if hasattr(r, 'score') else 0.0
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
return entities

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
entity_set = set()
entities = []
for vec in msg.vectors:
for r in search_result:
ent = r.payload["entity"]
score = r.score if hasattr(r, 'score') else 0.0
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, skipping this vector")
continue
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
for r in search_result:
ent = r.payload["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
logger.debug("Send response...")
return entities

View file

@ -93,7 +93,9 @@ class Processor(FlowProcessor):
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
matches = []
vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema
qdrant_collection = self.find_collection(
@ -105,47 +107,47 @@ class Processor(FlowProcessor):
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
)
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
return matches
for vec in request.vectors:
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""

View file

@ -37,26 +37,26 @@ class Query:
vectors = await self.get_vector(query)
if self.verbose:
logger.debug("Getting chunk_ids from embeddings store...")
logger.debug("Getting chunks from embeddings store...")
# Get chunk_ids from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query(
vectors, limit=self.doc_limit,
# Get chunk matches from embeddings store
chunk_matches = await self.rag.doc_embeddings_client.query(
vector=vectors, limit=self.doc_limit,
user=self.user, collection=self.collection,
)
if self.verbose:
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
# Fetch chunk content from Garage
docs = []
for chunk_id in chunk_ids:
if chunk_id:
for match in chunk_matches:
if match.chunk_id:
try:
content = await self.rag.fetch_chunk(chunk_id, self.user)
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
docs.append(content)
except Exception as e:
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
if self.verbose:
logger.debug("Documents fetched:")

View file

@ -87,14 +87,14 @@ class Query:
if self.verbose:
logger.debug("Getting entities...")
entities = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit,
entity_matches = await self.rag.graph_embeddings_client.query(
vector=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
)
entities = [
str(e)
for e in entities
str(e.entity)
for e in entity_matches
]
if self.verbose:

View file

@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if vec:
self.vecstore.insert(
vec, chunk_id,
message.metadata.user,

View file

@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "chunk_id": chunk_id },
}
]
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "chunk_id": chunk_id },
}
]
index.upsert(
vectors = records,
)
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):

View file

@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
)
@staticmethod
def add_args(parser):

View file

@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
entity_value = get_term_value(entity.entity)
if entity_value != "" and entity_value is not None:
for vec in entity.vectors:
vec = entity.vector
if vec:
self.vecstore.insert(
vec, entity_value,
message.metadata.user,

View file

@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
vec = entity.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
metadata = {"entity": entity_value}
if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id
metadata = {"entity": entity_value}
if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id
records = [
{
"id": vector_id,
"values": vec,
"metadata": metadata,
}
]
records = [
{
"id": vector_id,
"values": vec,
"metadata": metadata,
}
]
index.upsert(
vectors = records,
)
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):

View file

@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
vec = entity.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
)
@staticmethod
def add_args(parser):

View file

@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
qdrant_collection = None
for row_emb in embeddings.embeddings:
if not row_emb.vectors:
vector = row_emb.vector
if not vector:
logger.warning(
f"No vectors for index {row_emb.index_name} - skipping"
f"No vector for index {row_emb.index_name} - skipping"
)
continue
# Use first vector (there may be multiple from different models)
for vector in row_emb.vectors:
dimension = len(vector)
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
embeddings_written += 1
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")