mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Embeddings API scores (#671)
- Put scores in all responses
- Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
|
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
|
||||||
from trustgraph.messaging.translators.embeddings_query import (
|
from trustgraph.messaging.translators.embeddings_query import (
|
||||||
DocumentEmbeddingsRequestTranslator,
|
DocumentEmbeddingsRequestTranslator,
|
||||||
DocumentEmbeddingsResponseTranslator
|
DocumentEmbeddingsResponseTranslator
|
||||||
|
|
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
"""Test that DocumentEmbeddingsRequest has expected fields"""
|
"""Test that DocumentEmbeddingsRequest has expected fields"""
|
||||||
# Create a request
|
# Create a request
|
||||||
request = DocumentEmbeddingsRequest(
|
request = DocumentEmbeddingsRequest(
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=10,
|
limit=10,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify all expected fields exist
|
# Verify all expected fields exist
|
||||||
assert hasattr(request, 'vectors')
|
assert hasattr(request, 'vector')
|
||||||
assert hasattr(request, 'limit')
|
assert hasattr(request, 'limit')
|
||||||
assert hasattr(request, 'user')
|
assert hasattr(request, 'user')
|
||||||
assert hasattr(request, 'collection')
|
assert hasattr(request, 'collection')
|
||||||
|
|
||||||
# Verify field values
|
# Verify field values
|
||||||
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
assert request.vector == [0.1, 0.2, 0.3]
|
||||||
assert request.limit == 10
|
assert request.limit == 10
|
||||||
assert request.user == "test_user"
|
assert request.user == "test_user"
|
||||||
assert request.collection == "test_collection"
|
assert request.collection == "test_collection"
|
||||||
|
|
@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
translator = DocumentEmbeddingsRequestTranslator()
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"vectors": [[0.1, 0.2], [0.3, 0.4]],
|
"vector": [0.1, 0.2, 0.3, 0.4],
|
||||||
"limit": 5,
|
"limit": 5,
|
||||||
"user": "custom_user",
|
"user": "custom_user",
|
||||||
"collection": "custom_collection"
|
"collection": "custom_collection"
|
||||||
|
|
@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
result = translator.to_pulsar(data)
|
result = translator.to_pulsar(data)
|
||||||
|
|
||||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||||
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
assert result.vector == [0.1, 0.2, 0.3, 0.4]
|
||||||
assert result.limit == 5
|
assert result.limit == 5
|
||||||
assert result.user == "custom_user"
|
assert result.user == "custom_user"
|
||||||
assert result.collection == "custom_collection"
|
assert result.collection == "custom_collection"
|
||||||
|
|
@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
translator = DocumentEmbeddingsRequestTranslator()
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"vectors": [[0.1, 0.2]]
|
"vector": [0.1, 0.2]
|
||||||
# No limit, user, or collection provided
|
# No limit, user, or collection provided
|
||||||
}
|
}
|
||||||
|
|
||||||
result = translator.to_pulsar(data)
|
result = translator.to_pulsar(data)
|
||||||
|
|
||||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||||
assert result.vectors == [[0.1, 0.2]]
|
assert result.vector == [0.1, 0.2]
|
||||||
assert result.limit == 10 # Default
|
assert result.limit == 10 # Default
|
||||||
assert result.user == "trustgraph" # Default
|
assert result.user == "trustgraph" # Default
|
||||||
assert result.collection == "default" # Default
|
assert result.collection == "default" # Default
|
||||||
|
|
@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
translator = DocumentEmbeddingsRequestTranslator()
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
request = DocumentEmbeddingsRequest(
|
request = DocumentEmbeddingsRequest(
|
||||||
vectors=[[0.5, 0.6]],
|
vector=[0.5, 0.6],
|
||||||
limit=20,
|
limit=20,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
|
|
@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract:
|
||||||
result = translator.from_pulsar(request)
|
result = translator.from_pulsar(request)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert result["vectors"] == [[0.5, 0.6]]
|
assert result["vector"] == [0.5, 0.6]
|
||||||
assert result["limit"] == 20
|
assert result["limit"] == 20
|
||||||
assert result["user"] == "test_user"
|
assert result["user"] == "test_user"
|
||||||
assert result["collection"] == "test_collection"
|
assert result["collection"] == "test_collection"
|
||||||
|
|
@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract:
|
||||||
|
|
||||||
def test_response_schema_fields(self):
|
def test_response_schema_fields(self):
|
||||||
"""Test that DocumentEmbeddingsResponse has expected fields"""
|
"""Test that DocumentEmbeddingsResponse has expected fields"""
|
||||||
# Create a response with chunk_ids
|
# Create a response with chunks
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunk_ids=["chunk1", "chunk2", "chunk3"]
|
chunks=[
|
||||||
|
ChunkMatch(chunk_id="chunk1", score=0.9),
|
||||||
|
ChunkMatch(chunk_id="chunk2", score=0.8),
|
||||||
|
ChunkMatch(chunk_id="chunk3", score=0.7)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify all expected fields exist
|
# Verify all expected fields exist
|
||||||
assert hasattr(response, 'error')
|
assert hasattr(response, 'error')
|
||||||
assert hasattr(response, 'chunk_ids')
|
assert hasattr(response, 'chunks')
|
||||||
|
|
||||||
# Verify field values
|
# Verify field values
|
||||||
assert response.error is None
|
assert response.error is None
|
||||||
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"]
|
assert len(response.chunks) == 3
|
||||||
|
assert response.chunks[0].chunk_id == "chunk1"
|
||||||
|
assert response.chunks[0].score == 0.9
|
||||||
|
|
||||||
def test_response_schema_with_error(self):
|
def test_response_schema_with_error(self):
|
||||||
"""Test response schema with error"""
|
"""Test response schema with error"""
|
||||||
|
|
@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract:
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=error,
|
error=error,
|
||||||
chunk_ids=[]
|
chunks=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.error == error
|
assert response.error == error
|
||||||
assert response.chunk_ids == []
|
assert response.chunks == []
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_chunk_ids(self):
|
def test_response_translator_from_pulsar_with_chunks(self):
|
||||||
"""Test response translator converts Pulsar schema with chunk_ids to dict"""
|
"""Test response translator converts Pulsar schema with chunks to dict"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"]
|
chunks=[
|
||||||
|
ChunkMatch(chunk_id="doc1/c1", score=0.95),
|
||||||
|
ChunkMatch(chunk_id="doc2/c2", score=0.85),
|
||||||
|
ChunkMatch(chunk_id="doc3/c3", score=0.75)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunk_ids" in result
|
assert "chunks" in result
|
||||||
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"]
|
assert len(result["chunks"]) == 3
|
||||||
|
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
|
||||||
|
assert result["chunks"][0]["score"] == 0.95
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_empty_chunk_ids(self):
|
def test_response_translator_from_pulsar_with_empty_chunks(self):
|
||||||
"""Test response translator handles empty chunk_ids list"""
|
"""Test response translator handles empty chunks list"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunk_ids=[]
|
chunks=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunk_ids" in result
|
assert "chunks" in result
|
||||||
assert result["chunk_ids"] == []
|
assert result["chunks"] == []
|
||||||
|
|
||||||
def test_response_translator_from_pulsar_with_none_chunk_ids(self):
|
def test_response_translator_from_pulsar_with_none_chunks(self):
|
||||||
"""Test response translator handles None chunk_ids"""
|
"""Test response translator handles None chunks"""
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.chunk_ids = None
|
response.chunks = None
|
||||||
|
|
||||||
result = translator.from_pulsar(response)
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunk_ids" not in result or result.get("chunk_ids") is None
|
assert "chunks" not in result or result.get("chunks") is None
|
||||||
|
|
||||||
def test_response_translator_from_response_with_completion(self):
|
def test_response_translator_from_response_with_completion(self):
|
||||||
"""Test response translator with completion flag"""
|
"""Test response translator with completion flag"""
|
||||||
|
|
@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract:
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunk_ids=["chunk1", "chunk2"]
|
chunks=[
|
||||||
|
ChunkMatch(chunk_id="chunk1", score=0.9),
|
||||||
|
ChunkMatch(chunk_id="chunk2", score=0.8)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
result, is_final = translator.from_response_with_completion(response)
|
result, is_final = translator.from_response_with_completion(response)
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert "chunk_ids" in result
|
assert "chunks" in result
|
||||||
assert result["chunk_ids"] == ["chunk1", "chunk2"]
|
assert len(result["chunks"]) == 2
|
||||||
|
assert result["chunks"][0]["chunk_id"] == "chunk1"
|
||||||
assert is_final is True # Document embeddings responses are always final
|
assert is_final is True # Document embeddings responses are always final
|
||||||
|
|
||||||
def test_response_translator_to_pulsar_not_implemented(self):
|
def test_response_translator_to_pulsar_not_implemented(self):
|
||||||
|
|
@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract:
|
||||||
translator = DocumentEmbeddingsResponseTranslator()
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
translator.to_pulsar({"chunk_ids": ["test"]})
|
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
|
||||||
|
|
||||||
|
|
||||||
class TestDocumentEmbeddingsMessageCompatibility:
|
class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
|
|
@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
"""Test complete request-response flow maintains data integrity"""
|
"""Test complete request-response flow maintains data integrity"""
|
||||||
# Create request
|
# Create request
|
||||||
request_data = {
|
request_data = {
|
||||||
"vectors": [[0.1, 0.2, 0.3]],
|
"vector": [0.1, 0.2, 0.3],
|
||||||
"limit": 5,
|
"limit": 5,
|
||||||
"user": "test_user",
|
"user": "test_user",
|
||||||
"collection": "test_collection"
|
"collection": "test_collection"
|
||||||
|
|
@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
# Simulate service processing and creating response
|
# Simulate service processing and creating response
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=None,
|
error=None,
|
||||||
chunk_ids=["doc1/c1", "doc2/c2"]
|
chunks=[
|
||||||
|
ChunkMatch(chunk_id="doc1/c1", score=0.95),
|
||||||
|
ChunkMatch(chunk_id="doc2/c2", score=0.85)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert response back to dict
|
# Convert response back to dict
|
||||||
|
|
@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
# Verify data integrity
|
# Verify data integrity
|
||||||
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
||||||
assert isinstance(response_data, dict)
|
assert isinstance(response_data, dict)
|
||||||
assert "chunk_ids" in response_data
|
assert "chunks" in response_data
|
||||||
assert len(response_data["chunk_ids"]) == 2
|
assert len(response_data["chunks"]) == 2
|
||||||
|
|
||||||
def test_error_response_flow(self):
|
def test_error_response_flow(self):
|
||||||
"""Test error response flow"""
|
"""Test error response flow"""
|
||||||
|
|
@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
|
|
||||||
response = DocumentEmbeddingsResponse(
|
response = DocumentEmbeddingsResponse(
|
||||||
error=error,
|
error=error,
|
||||||
chunk_ids=[]
|
chunks=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert response to dict
|
# Convert response to dict
|
||||||
|
|
@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
|
|
||||||
# Verify error handling
|
# Verify error handling
|
||||||
assert isinstance(response_data, dict)
|
assert isinstance(response_data, dict)
|
||||||
# The translator doesn't include error in the dict, only chunk_ids
|
# The translator doesn't include error in the dict, only chunks
|
||||||
assert "chunk_ids" in response_data
|
assert "chunks" in response_data
|
||||||
assert response_data["chunk_ids"] == []
|
assert response_data["chunks"] == []
|
||||||
|
|
|
||||||
|
|
@ -285,11 +285,11 @@ class TestStructuredEmbeddingsContracts:
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
metadata=[]
|
metadata=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
embedding = StructuredObjectEmbedding(
|
embedding = StructuredObjectEmbedding(
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
schema_name="customer_records",
|
schema_name="customer_records",
|
||||||
object_id="customer_123",
|
object_id="customer_123",
|
||||||
field_embeddings={
|
field_embeddings={
|
||||||
|
|
@ -301,7 +301,7 @@ class TestStructuredEmbeddingsContracts:
|
||||||
# Assert
|
# Assert
|
||||||
assert embedding.schema_name == "customer_records"
|
assert embedding.schema_name == "customer_records"
|
||||||
assert embedding.object_id == "customer_123"
|
assert embedding.object_id == "customer_123"
|
||||||
assert len(embedding.vectors) == 2
|
assert len(embedding.vector) == 3
|
||||||
assert len(embedding.field_embeddings) == 2
|
assert len(embedding.field_embeddings) == 2
|
||||||
assert "name" in embedding.field_embeddings
|
assert "name" in embedding.field_embeddings
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing.
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||||
|
from trustgraph.schema import ChunkMatch
|
||||||
|
|
||||||
|
|
||||||
# Sample chunk content for testing - maps chunk_id to content
|
# Sample chunk content for testing - maps chunk_id to content
|
||||||
|
|
@ -39,10 +40,14 @@ class TestDocumentRagIntegration:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_doc_embeddings_client(self):
|
def mock_doc_embeddings_client(self):
|
||||||
"""Mock document embeddings client that returns chunk IDs"""
|
"""Mock document embeddings client that returns chunk matches"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
# Now returns chunk_ids instead of actual content
|
# Returns ChunkMatch objects with chunk_id and score
|
||||||
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
client.query.return_value = [
|
||||||
|
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||||
|
ChunkMatch(chunk_id="doc/c2", score=0.90),
|
||||||
|
ChunkMatch(chunk_id="doc/c3", score=0.85)
|
||||||
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -97,7 +102,7 @@ class TestDocumentRagIntegration:
|
||||||
mock_embeddings_client.embed.assert_called_once_with([query])
|
mock_embeddings_client.embed.assert_called_once_with([query])
|
||||||
|
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||||
limit=doc_limit,
|
limit=doc_limit,
|
||||||
user=user,
|
user=user,
|
||||||
collection=collection
|
collection=collection
|
||||||
|
|
@ -298,7 +303,7 @@ class TestDocumentRagIntegration:
|
||||||
assert "DocumentRag initialized" in log_messages
|
assert "DocumentRag initialized" in log_messages
|
||||||
assert "Constructing prompt..." in log_messages
|
assert "Constructing prompt..." in log_messages
|
||||||
assert "Computing embeddings..." in log_messages
|
assert "Computing embeddings..." in log_messages
|
||||||
assert "chunk_ids" in log_messages.lower()
|
assert "chunks" in log_messages.lower()
|
||||||
assert "Invoking LLM..." in log_messages
|
assert "Invoking LLM..." in log_messages
|
||||||
assert "Query processing complete" in log_messages
|
assert "Query processing complete" in log_messages
|
||||||
|
|
||||||
|
|
@ -307,9 +312,9 @@ class TestDocumentRagIntegration:
|
||||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||||
mock_doc_embeddings_client):
|
mock_doc_embeddings_client):
|
||||||
"""Test DocumentRAG performance with large document retrieval"""
|
"""Test DocumentRAG performance with large document retrieval"""
|
||||||
# Arrange - Mock large chunk_id set (100 chunks)
|
# Arrange - Mock large chunk match set (100 chunks)
|
||||||
large_chunk_ids = [f"doc/c{i}" for i in range(100)]
|
large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
|
||||||
mock_doc_embeddings_client.query.return_value = large_chunk_ids
|
mock_doc_embeddings_client.query.return_value = large_chunk_matches
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
import time
|
import time
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ response delivery through the complete pipeline.
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||||
|
from trustgraph.schema import ChunkMatch
|
||||||
from tests.utils.streaming_assertions import (
|
from tests.utils.streaming_assertions import (
|
||||||
assert_streaming_chunks_valid,
|
assert_streaming_chunks_valid,
|
||||||
assert_callback_invoked,
|
assert_callback_invoked,
|
||||||
|
|
@ -36,10 +37,14 @@ class TestDocumentRagStreaming:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_doc_embeddings_client(self):
|
def mock_doc_embeddings_client(self):
|
||||||
"""Mock document embeddings client that returns chunk IDs"""
|
"""Mock document embeddings client that returns chunk matches"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
# Now returns chunk_ids instead of actual content
|
# Returns ChunkMatch objects with chunk_id and score
|
||||||
client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"]
|
client.query.return_value = [
|
||||||
|
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||||
|
ChunkMatch(chunk_id="doc/c2", score=0.90),
|
||||||
|
ChunkMatch(chunk_id="doc/c3", score=0.85)
|
||||||
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||||
|
from trustgraph.schema import EntityMatch, Term, IRI
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
|
|
@ -35,9 +36,9 @@ class TestGraphRagIntegration:
|
||||||
"""Mock graph embeddings client that returns realistic entities"""
|
"""Mock graph embeddings client that returns realistic entities"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = [
|
client.query.return_value = [
|
||||||
"http://trustgraph.ai/e/machine-learning",
|
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||||
"http://trustgraph.ai/e/artificial-intelligence",
|
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
|
||||||
"http://trustgraph.ai/e/neural-networks"
|
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
|
||||||
]
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
@ -130,7 +131,7 @@ class TestGraphRagIntegration:
|
||||||
# 2. Should query graph embeddings to find relevant entities
|
# 2. Should query graph embeddings to find relevant entities
|
||||||
mock_graph_embeddings_client.query.assert_called_once()
|
mock_graph_embeddings_client.query.assert_called_once()
|
||||||
call_args = mock_graph_embeddings_client.query.call_args
|
call_args = mock_graph_embeddings_client.query.call_args
|
||||||
assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
assert call_args.kwargs['limit'] == entity_limit
|
assert call_args.kwargs['limit'] == entity_limit
|
||||||
assert call_args.kwargs['user'] == user
|
assert call_args.kwargs['user'] == user
|
||||||
assert call_args.kwargs['collection'] == collection
|
assert call_args.kwargs['collection'] == collection
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ response delivery through the complete pipeline.
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||||
|
from trustgraph.schema import EntityMatch, Term, IRI
|
||||||
from tests.utils.streaming_assertions import (
|
from tests.utils.streaming_assertions import (
|
||||||
assert_streaming_chunks_valid,
|
assert_streaming_chunks_valid,
|
||||||
assert_rag_streaming_chunks,
|
assert_rag_streaming_chunks,
|
||||||
|
|
@ -33,7 +34,7 @@ class TestGraphRagStreaming:
|
||||||
"""Mock graph embeddings client"""
|
"""Mock graph embeddings client"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = [
|
client.query.return_value = [
|
||||||
"http://trustgraph.ai/e/machine-learning",
|
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||||
]
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -411,7 +411,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
||||||
entities=[
|
entities=[
|
||||||
EntityEmbeddings(
|
EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri="http://example.org/entity"),
|
entity=Term(type=IRI, iri="http://example.org/entity"),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, call
|
from unittest.mock import AsyncMock, MagicMock, call
|
||||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||||
|
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||||
|
|
||||||
|
|
||||||
class TestGraphRagStreamingProtocol:
|
class TestGraphRagStreamingProtocol:
|
||||||
|
|
@ -25,7 +26,10 @@ class TestGraphRagStreamingProtocol:
|
||||||
def mock_graph_embeddings_client(self):
|
def mock_graph_embeddings_client(self):
|
||||||
"""Mock graph embeddings client"""
|
"""Mock graph embeddings client"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = ["entity1", "entity2"]
|
client.query.return_value = [
|
||||||
|
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
|
||||||
|
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
|
||||||
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -202,9 +206,12 @@ class TestDocumentRagStreamingProtocol:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_doc_embeddings_client(self):
|
def mock_doc_embeddings_client(self):
|
||||||
"""Mock document embeddings client that returns chunk IDs"""
|
"""Mock document embeddings client that returns chunk matches"""
|
||||||
client = AsyncMock()
|
client = AsyncMock()
|
||||||
client.query.return_value = ["doc/c1", "doc/c2"]
|
client.query.return_value = [
|
||||||
|
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||||
|
ChunkMatch(chunk_id="doc/c2", score=0.90)
|
||||||
|
]
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"]
|
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
# Mock the request method
|
# Mock the request method
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await client.query(
|
result = await client.query(
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
limit=10,
|
limit=10,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
timeout=30
|
timeout=30
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||||
client.request.assert_called_once()
|
client.request.assert_called_once()
|
||||||
call_args = client.request.call_args[0][0]
|
call_args = client.request.call_args[0][0]
|
||||||
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||||
assert call_args.vectors == vectors
|
assert call_args.vector == vector
|
||||||
assert call_args.limit == 10
|
assert call_args.limit == 10
|
||||||
assert call_args.user == "test_user"
|
assert call_args.user == "test_user"
|
||||||
assert call_args.collection == "test_collection"
|
assert call_args.collection == "test_collection"
|
||||||
|
|
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
with pytest.raises(RuntimeError, match="Database connection failed"):
|
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||||
await client.query(
|
await client.query(
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunk_ids = []
|
mock_response.chunks = []
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
|
|
@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunk_ids = ["test_chunk"]
|
mock_response.chunks = ["test_chunk"]
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
client.request.assert_called_once()
|
client.request.assert_called_once()
|
||||||
|
|
@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunk_ids = ["chunk1"]
|
mock_response.chunks = ["chunk1"]
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await client.query(
|
await client.query(
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
timeout=60
|
timeout=60
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert client.request.call_args[1]["timeout"] == 60
|
assert client.request.call_args[1]["timeout"] == 60
|
||||||
|
|
||||||
|
|
@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
client = DocumentEmbeddingsClient()
|
client = DocumentEmbeddingsClient()
|
||||||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
mock_response.error = None
|
mock_response.error = None
|
||||||
mock_response.chunk_ids = ["test_chunk"]
|
mock_response.chunks = ["test_chunk"]
|
||||||
|
|
||||||
client.request = AsyncMock(return_value=mock_response)
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
|
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
|
||||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
mock_logger.debug.assert_called_once()
|
mock_logger.debug.assert_called_once()
|
||||||
assert "Document embeddings response" in str(mock_logger.debug.call_args)
|
assert "Document embeddings response" in str(mock_logger.debug.call_args)
|
||||||
|
|
|
||||||
|
|
@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||||
client.call = MagicMock(return_value=mock_response)
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = client.request(
|
result = client.request(
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
limit=10,
|
limit=10,
|
||||||
timeout=300
|
timeout=300
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||||
client.call.assert_called_once_with(
|
client.call.assert_called_once_with(
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
limit=10,
|
limit=10,
|
||||||
timeout=300
|
timeout=300
|
||||||
)
|
)
|
||||||
|
|
@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.chunks = ["test_chunk"]
|
mock_response.chunks = ["test_chunk"]
|
||||||
client.call = MagicMock(return_value=mock_response)
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
vectors = [[0.1, 0.2, 0.3]]
|
vector = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = client.request(vectors=vectors)
|
result = client.request(vector=vector)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == ["test_chunk"]
|
assert result == ["test_chunk"]
|
||||||
client.call.assert_called_once_with(
|
client.call.assert_called_once_with(
|
||||||
user="trustgraph",
|
user="trustgraph",
|
||||||
collection="default",
|
collection="default",
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
limit=10,
|
limit=10,
|
||||||
timeout=300
|
timeout=300
|
||||||
)
|
)
|
||||||
|
|
@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.chunks = []
|
mock_response.chunks = []
|
||||||
client.call = MagicMock(return_value=mock_response)
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = client.request(vectors=[[0.1, 0.2, 0.3]])
|
result = client.request(vector=[0.1, 0.2, 0.3])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
|
|
@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.chunks = None
|
mock_response.chunks = None
|
||||||
client.call = MagicMock(return_value=mock_response)
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = client.request(vectors=[[0.1, 0.2, 0.3]])
|
result = client.request(vector=[0.1, 0.2, 0.3])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.chunks = ["chunk1"]
|
mock_response.chunks = ["chunk1"]
|
||||||
client.call = MagicMock(return_value=mock_response)
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
client.request(
|
client.request(
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
timeout=600
|
timeout=600
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert client.call.call_args[1]["timeout"] == 600
|
assert client.call.call_args[1]["timeout"] == 600
|
||||||
|
|
@ -98,7 +98,7 @@ def sample_graph_embeddings():
|
||||||
entities=[
|
entities=[
|
||||||
EntityEmbeddings(
|
EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri="http://example.org/john"),
|
entity=Term(type=IRI, iri="http://example.org/john"),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
# Assert
|
# Assert
|
||||||
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
|
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
|
||||||
assert processor.cached_model_name == "test-model" # Still using default
|
assert processor.cached_model_name == "test-model" # Still using default
|
||||||
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
|
@patch('trustgraph.embeddings.fastembed.processor.TextEmbedding')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
model="test-model",
|
model="test-model",
|
||||||
input=["test text"]
|
input=["test text"]
|
||||||
)
|
)
|
||||||
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
|
|
@ -86,7 +86,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
model="custom-model",
|
model="custom-model",
|
||||||
input=["test text"]
|
input=["test text"]
|
||||||
)
|
)
|
||||||
assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from trustgraph.query.doc_embeddings.milvus.service import Processor
|
from trustgraph.query.doc_embeddings.milvus.service import Processor
|
||||||
from trustgraph.schema import DocumentEmbeddingsRequest
|
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
|
||||||
|
|
||||||
|
|
||||||
class TestMilvusDocEmbeddingsQueryProcessor:
|
class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
return query
|
return query
|
||||||
|
|
@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
|
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify results are document chunks
|
# Verify results are ChunkMatch objects
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert result[0] == "First document chunk"
|
assert isinstance(result[0], ChunkMatch)
|
||||||
assert result[1] == "Second document chunk"
|
assert result[0].chunk_id == "First document chunk"
|
||||||
assert result[2] == "Third document chunk"
|
assert result[1].chunk_id == "Second document chunk"
|
||||||
|
assert result[2].chunk_id == "Third document chunk"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_document_embeddings_multiple_vectors(self, processor):
|
async def test_query_document_embeddings_longer_vector(self, processor):
|
||||||
"""Test querying document embeddings with multiple vectors"""
|
"""Test querying document embeddings with a longer vector"""
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
limit=3
|
limit=3
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results - different results for each vector
|
# Mock search results
|
||||||
mock_results_1 = [
|
mock_results = [
|
||||||
{"entity": {"chunk_id": "Document from first vector"}},
|
{"entity": {"chunk_id": "First document"}},
|
||||||
{"entity": {"chunk_id": "Another doc from first vector"}},
|
{"entity": {"chunk_id": "Second document"}},
|
||||||
|
{"entity": {"chunk_id": "Third document"}},
|
||||||
]
|
]
|
||||||
mock_results_2 = [
|
processor.vecstore.search.return_value = mock_results
|
||||||
{"entity": {"chunk_id": "Document from second vector"}},
|
|
||||||
]
|
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
||||||
# Verify search was called twice with correct parameters including user/collection
|
# Verify search was called once with the full vector
|
||||||
expected_calls = [
|
processor.vecstore.search.assert_called_once_with(
|
||||||
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
|
||||||
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
|
)
|
||||||
]
|
|
||||||
assert processor.vecstore.search.call_count == 2
|
# Verify results are ChunkMatch objects
|
||||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
|
||||||
actual_call = processor.vecstore.search.call_args_list[i]
|
|
||||||
assert actual_call[0] == expected_args
|
|
||||||
assert actual_call[1] == expected_kwargs
|
|
||||||
|
|
||||||
# Verify results from all vectors are combined
|
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert "Document from first vector" in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert "Another doc from first vector" in result
|
assert "First document" in chunk_ids
|
||||||
assert "Document from second vector" in result
|
assert "Second document" in chunk_ids
|
||||||
|
assert "Third document" in chunk_ids
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_document_embeddings_with_limit(self, processor):
|
async def test_query_document_embeddings_with_limit(self, processor):
|
||||||
|
|
@ -141,7 +135,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=2
|
limit=2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[],
|
vector=[],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -211,7 +205,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -225,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
||||||
# Verify Unicode content is preserved
|
# Verify Unicode content is preserved in ChunkMatch objects
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert "Document with Unicode: éñ中文🚀" in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert "Regular ASCII document" in result
|
assert "Document with Unicode: éñ中文🚀" in chunk_ids
|
||||||
assert "Document with émojis: 😀🎉" in result
|
assert "Regular ASCII document" in chunk_ids
|
||||||
|
assert "Document with émojis: 😀🎉" in chunk_ids
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_document_embeddings_large_documents(self, processor):
|
async def test_query_document_embeddings_large_documents(self, processor):
|
||||||
|
|
@ -237,7 +232,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -251,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
||||||
# Verify large content is preserved
|
# Verify large content is preserved in ChunkMatch objects
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert large_doc in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert "Small document" in result
|
assert large_doc in chunk_ids
|
||||||
|
assert "Small document" in chunk_ids
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_document_embeddings_special_characters(self, processor):
|
async def test_query_document_embeddings_special_characters(self, processor):
|
||||||
|
|
@ -262,7 +258,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -276,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
||||||
# Verify special characters are preserved
|
# Verify special characters are preserved in ChunkMatch objects
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert "Document with \"quotes\" and 'apostrophes'" in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert "Document with\nnewlines\tand\ttabs" in result
|
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
|
||||||
assert "Document with special chars: @#$%^&*()" in result
|
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
|
||||||
|
assert "Document with special chars: @#$%^&*()" in chunk_ids
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_document_embeddings_zero_limit(self, processor):
|
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||||
|
|
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=0
|
limit=0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=-1
|
limit=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
|
||||||
[0.1, 0.2], # 2D vector
|
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
|
||||||
[0.7, 0.8, 0.9] # 3D vector
|
|
||||||
],
|
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results for each vector
|
# Mock search results
|
||||||
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}]
|
mock_results = [
|
||||||
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}]
|
{"entity": {"chunk_id": "Document 1"}},
|
||||||
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}]
|
{"entity": {"chunk_id": "Document 2"}},
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
]
|
||||||
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
||||||
# Verify all vectors were searched
|
# Verify search was called with the vector
|
||||||
assert processor.vecstore.search.call_count == 3
|
processor.vecstore.search.assert_called_once()
|
||||||
|
|
||||||
# Verify results from all dimensions
|
# Verify results are ChunkMatch objects
|
||||||
assert len(result) == 3
|
assert len(result) == 2
|
||||||
assert "Document from 2D vector" in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert "Document from 4D vector" in result
|
assert "Document 1" in chunk_ids
|
||||||
assert "Document from 3D vector" in result
|
assert "Document 2" in chunk_ids
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_document_embeddings_duplicate_documents(self, processor):
|
async def test_query_document_embeddings_multiple_results(self, processor):
|
||||||
"""Test querying document embeddings with duplicate documents in results"""
|
"""Test querying document embeddings with multiple results"""
|
||||||
query = DocumentEmbeddingsRequest(
|
query = DocumentEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results with duplicates across vectors
|
# Mock search results with multiple documents
|
||||||
mock_results_1 = [
|
mock_results = [
|
||||||
{"entity": {"chunk_id": "Document A"}},
|
{"entity": {"chunk_id": "Document A"}},
|
||||||
{"entity": {"chunk_id": "Document B"}},
|
{"entity": {"chunk_id": "Document B"}},
|
||||||
]
|
|
||||||
mock_results_2 = [
|
|
||||||
{"entity": {"chunk_id": "Document B"}}, # Duplicate
|
|
||||||
{"entity": {"chunk_id": "Document C"}},
|
{"entity": {"chunk_id": "Document C"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
result = await processor.query_document_embeddings(query)
|
result = await processor.query_document_embeddings(query)
|
||||||
|
|
||||||
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
|
# Verify results are ChunkMatch objects
|
||||||
# This preserves ranking and allows multiple occurrences
|
assert len(result) == 3
|
||||||
assert len(result) == 4
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert result.count("Document B") == 2 # Should appear twice
|
assert "Document A" in chunk_ids
|
||||||
assert "Document A" in result
|
assert "Document B" in chunk_ids
|
||||||
assert "Document C" in result
|
assert "Document C" in chunk_ids
|
||||||
|
|
||||||
def test_add_args_method(self):
|
def test_add_args_method(self):
|
||||||
"""Test that add_args properly configures argument parser"""
|
"""Test that add_args properly configures argument parser"""
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_single_vector(self, processor):
|
async def test_query_document_embeddings_single_vector(self, processor):
|
||||||
"""Test querying document embeddings with a single vector"""
|
"""Test querying document embeddings with a single vector"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 3
|
message.limit = 3
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_limit_handling(self, processor):
|
async def test_query_document_embeddings_limit_handling(self, processor):
|
||||||
"""Test that query respects the limit parameter"""
|
"""Test that query respects the limit parameter"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 2
|
message.limit = 2
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -208,7 +208,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_zero_limit(self, processor):
|
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||||
"""Test querying with zero limit returns empty results"""
|
"""Test querying with zero limit returns empty results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 0
|
message.limit = 0
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -226,7 +226,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_negative_limit(self, processor):
|
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||||
"""Test querying with negative limit returns empty results"""
|
"""Test querying with negative limit returns empty results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = -1
|
message.limit = -1
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -285,7 +285,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_empty_vectors_list(self, processor):
|
async def test_query_document_embeddings_empty_vectors_list(self, processor):
|
||||||
"""Test querying with empty vectors list"""
|
"""Test querying with empty vectors list"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = []
|
message.vector = []
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -304,7 +304,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_no_results(self, processor):
|
async def test_query_document_embeddings_no_results(self, processor):
|
||||||
"""Test querying when index returns no results"""
|
"""Test querying when index returns no results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -325,7 +325,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_unicode_content(self, processor):
|
async def test_query_document_embeddings_unicode_content(self, processor):
|
||||||
"""Test querying document embeddings with Unicode content results"""
|
"""Test querying document embeddings with Unicode content results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 2
|
message.limit = 2
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -351,7 +351,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_large_content(self, processor):
|
async def test_query_document_embeddings_large_content(self, processor):
|
||||||
"""Test querying document embeddings with large content results"""
|
"""Test querying document embeddings with large content results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 1
|
message.limit = 1
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -377,7 +377,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_mixed_content_types(self, processor):
|
async def test_query_document_embeddings_mixed_content_types(self, processor):
|
||||||
"""Test querying document embeddings with mixed content types"""
|
"""Test querying document embeddings with mixed content types"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -409,7 +409,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_exception_handling(self, processor):
|
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||||
"""Test that exceptions are properly raised"""
|
"""Test that exceptions are properly raised"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -425,7 +425,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
||||||
async def test_query_document_embeddings_index_access_failure(self, processor):
|
async def test_query_document_embeddings_index_access_failure(self, processor):
|
||||||
"""Test handling of index access failure"""
|
"""Test handling of index access failure"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
||||||
|
|
||||||
# Import the service under test
|
# Import the service under test
|
||||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||||
|
from trustgraph.schema import ChunkMatch
|
||||||
|
|
||||||
|
|
||||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
mock_message.vector = [0.1, 0.2, 0.3]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'test_user'
|
mock_message.user = 'test_user'
|
||||||
mock_message.collection = 'test_collection'
|
mock_message.collection = 'test_collection'
|
||||||
|
|
@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
with_payload=True
|
with_payload=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result contains expected documents
|
# Verify result contains expected ChunkMatch objects
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
# Results should be strings (document chunks)
|
# Results should be ChunkMatch objects
|
||||||
assert isinstance(result[0], str)
|
assert isinstance(result[0], ChunkMatch)
|
||||||
assert isinstance(result[1], str)
|
assert isinstance(result[1], ChunkMatch)
|
||||||
# Verify content
|
# Verify content
|
||||||
assert result[0] == 'first document chunk'
|
assert result[0].chunk_id == 'first document chunk'
|
||||||
assert result[1] == 'second document chunk'
|
assert result[1].chunk_id == 'second document chunk'
|
||||||
|
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
|
||||||
"""Test querying document embeddings with multiple vectors"""
|
"""Test querying document embeddings returns multiple results"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_base_init.return_value = None
|
mock_base_init.return_value = None
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
# Mock query responses for different vectors
|
# Mock query response with multiple results
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'chunk_id': 'document from vector 1'}
|
mock_point1.payload = {'chunk_id': 'document chunk 1'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'chunk_id': 'document from vector 2'}
|
mock_point2.payload = {'chunk_id': 'document chunk 2'}
|
||||||
mock_point3 = MagicMock()
|
mock_point3 = MagicMock()
|
||||||
mock_point3.payload = {'chunk_id': 'another document from vector 2'}
|
mock_point3.payload = {'chunk_id': 'document chunk 3'}
|
||||||
|
|
||||||
mock_response1 = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response1.points = [mock_point1]
|
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||||
mock_response2 = MagicMock()
|
mock_qdrant_instance.query_points.return_value = mock_response
|
||||||
mock_response2.points = [mock_point2, mock_point3]
|
|
||||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'taskgroup': AsyncMock(),
|
'taskgroup': AsyncMock(),
|
||||||
'id': 'test-processor'
|
'id': 'test-processor'
|
||||||
}
|
}
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Create mock message with multiple vectors
|
# Create mock message with single vector
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 3
|
mock_message.limit = 3
|
||||||
mock_message.user = 'multi_user'
|
mock_message.user = 'multi_user'
|
||||||
mock_message.collection = 'multi_collection'
|
mock_message.collection = 'multi_collection'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await processor.query_document_embeddings(mock_message)
|
result = await processor.query_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify query was called twice
|
# Verify query was called once
|
||||||
assert mock_qdrant_instance.query_points.call_count == 2
|
assert mock_qdrant_instance.query_points.call_count == 1
|
||||||
|
|
||||||
# Verify both collections were queried (both 2-dimensional vectors)
|
# Verify collection was queried correctly
|
||||||
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
|
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
|
||||||
calls = mock_qdrant_instance.query_points.call_args_list
|
calls = mock_qdrant_instance.query_points.call_args_list
|
||||||
assert calls[0][1]['collection_name'] == expected_collection
|
assert calls[0][1]['collection_name'] == expected_collection
|
||||||
assert calls[1][1]['collection_name'] == expected_collection
|
|
||||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
|
||||||
|
# Verify results are ChunkMatch objects
|
||||||
# Verify results from both vectors are combined
|
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert 'document from vector 1' in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert 'document from vector 2' in result
|
assert 'document chunk 1' in chunk_ids
|
||||||
assert 'another document from vector 2' in result
|
assert 'document chunk 2' in chunk_ids
|
||||||
|
assert 'document chunk 3' in chunk_ids
|
||||||
|
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message with limit
|
# Create mock message with limit
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
mock_message.vector = [0.1, 0.2, 0.3]
|
||||||
mock_message.limit = 3 # Should only return 3 results
|
mock_message.limit = 3 # Should only return 3 results
|
||||||
mock_message.user = 'limit_user'
|
mock_message.user = 'limit_user'
|
||||||
mock_message.collection = 'limit_collection'
|
mock_message.collection = 'limit_collection'
|
||||||
|
|
@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'empty_user'
|
mock_message.user = 'empty_user'
|
||||||
mock_message.collection = 'empty_collection'
|
mock_message.collection = 'empty_collection'
|
||||||
|
|
@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||||
"""Test querying document embeddings with different vector dimensions"""
|
"""Test querying document embeddings with a higher dimension vector"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_base_init.return_value = None
|
mock_base_init.return_value = None
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
# Mock query responses
|
# Mock query response
|
||||||
mock_point1 = MagicMock()
|
mock_point1 = MagicMock()
|
||||||
mock_point1.payload = {'chunk_id': 'document from 2D vector'}
|
mock_point1.payload = {'chunk_id': 'document from 5D vector'}
|
||||||
mock_point2 = MagicMock()
|
mock_point2 = MagicMock()
|
||||||
mock_point2.payload = {'chunk_id': 'document from 3D vector'}
|
mock_point2.payload = {'chunk_id': 'another 5D document'}
|
||||||
|
|
||||||
mock_response1 = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response1.points = [mock_point1]
|
mock_response.points = [mock_point1, mock_point2]
|
||||||
mock_response2 = MagicMock()
|
mock_qdrant_instance.query_points.return_value = mock_response
|
||||||
mock_response2.points = [mock_point2]
|
|
||||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'taskgroup': AsyncMock(),
|
'taskgroup': AsyncMock(),
|
||||||
'id': 'test-processor'
|
'id': 'test-processor'
|
||||||
}
|
}
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Create mock message with different dimension vectors
|
# Create mock message with 5D vector
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'dim_user'
|
mock_message.user = 'dim_user'
|
||||||
mock_message.collection = 'dim_collection'
|
mock_message.collection = 'dim_collection'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await processor.query_document_embeddings(mock_message)
|
result = await processor.query_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify query was called twice with different collections
|
# Verify query was called once with correct collection
|
||||||
assert mock_qdrant_instance.query_points.call_count == 2
|
assert mock_qdrant_instance.query_points.call_count == 1
|
||||||
calls = mock_qdrant_instance.query_points.call_args_list
|
calls = mock_qdrant_instance.query_points.call_args_list
|
||||||
|
|
||||||
# First call should use 2D collection
|
# Call should use 5D collection
|
||||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
|
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
|
||||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
|
||||||
# Second call should use 3D collection
|
# Verify results are ChunkMatch objects
|
||||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
|
|
||||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
|
||||||
|
|
||||||
# Verify results
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert 'document from 2D vector' in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert 'document from 3D vector' in result
|
assert 'document from 5D vector' in chunk_ids
|
||||||
|
assert 'another 5D document' in chunk_ids
|
||||||
|
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'utf8_user'
|
mock_message.user = 'utf8_user'
|
||||||
mock_message.collection = 'utf8_collection'
|
mock_message.collection = 'utf8_collection'
|
||||||
|
|
@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
||||||
# Verify UTF-8 content works correctly
|
# Verify UTF-8 content works correctly in ChunkMatch objects
|
||||||
assert 'Document with UTF-8: café, naïve, résumé' in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert 'Chinese text: 你好世界' in result
|
assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
|
||||||
|
assert 'Chinese text: 你好世界' in chunk_ids
|
||||||
|
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'error_user'
|
mock_message.user = 'error_user'
|
||||||
mock_message.collection = 'error_collection'
|
mock_message.collection = 'error_collection'
|
||||||
|
|
@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message with zero limit
|
# Create mock message with zero limit
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 0
|
mock_message.limit = 0
|
||||||
mock_message.user = 'zero_user'
|
mock_message.user = 'zero_user'
|
||||||
mock_message.collection = 'zero_collection'
|
mock_message.collection = 'zero_collection'
|
||||||
|
|
@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
mock_qdrant_instance.query_points.assert_called_once()
|
mock_qdrant_instance.query_points.assert_called_once()
|
||||||
call_args = mock_qdrant_instance.query_points.call_args
|
call_args = mock_qdrant_instance.query_points.call_args
|
||||||
assert call_args[1]['limit'] == 0
|
assert call_args[1]['limit'] == 0
|
||||||
|
|
||||||
# Result should contain all returned documents
|
# Result should contain all returned documents as ChunkMatch objects
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0] == 'document chunk'
|
assert isinstance(result[0], ChunkMatch)
|
||||||
|
assert result[0].chunk_id == 'document chunk'
|
||||||
|
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message with large limit
|
# Create mock message with large limit
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 1000 # Large limit
|
mock_message.limit = 1000 # Large limit
|
||||||
mock_message.user = 'large_user'
|
mock_message.user = 'large_user'
|
||||||
mock_message.collection = 'large_collection'
|
mock_message.collection = 'large_collection'
|
||||||
|
|
@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
mock_qdrant_instance.query_points.assert_called_once()
|
mock_qdrant_instance.query_points.assert_called_once()
|
||||||
call_args = mock_qdrant_instance.query_points.call_args
|
call_args = mock_qdrant_instance.query_points.call_args
|
||||||
assert call_args[1]['limit'] == 1000
|
assert call_args[1]['limit'] == 1000
|
||||||
|
|
||||||
# Result should contain all available documents
|
# Result should contain all available documents as ChunkMatch objects
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert 'document 1' in result
|
chunk_ids = [r.chunk_id for r in result]
|
||||||
assert 'document 2' in result
|
assert 'document 1' in chunk_ids
|
||||||
|
assert 'document 2' in chunk_ids
|
||||||
|
|
||||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -508,7 +504,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'payload_user'
|
mock_message.user = 'payload_user'
|
||||||
mock_message.collection = 'payload_collection'
|
mock_message.collection = 'payload_collection'
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
|
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
|
||||||
|
|
||||||
|
|
||||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
|
|
@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
return query
|
return query
|
||||||
|
|
@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify results are converted to Term objects
|
# Verify results are converted to EntityMatch objects
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert isinstance(result[0], Term)
|
assert isinstance(result[0], EntityMatch)
|
||||||
assert result[0].iri == "http://example.com/entity1"
|
assert result[0].entity.iri == "http://example.com/entity1"
|
||||||
assert result[0].type == IRI
|
assert result[0].entity.type == IRI
|
||||||
assert isinstance(result[1], Term)
|
assert isinstance(result[1], EntityMatch)
|
||||||
assert result[1].iri == "http://example.com/entity2"
|
assert result[1].entity.iri == "http://example.com/entity2"
|
||||||
assert result[1].type == IRI
|
assert result[1].entity.type == IRI
|
||||||
assert isinstance(result[2], Term)
|
assert isinstance(result[2], EntityMatch)
|
||||||
assert result[2].value == "literal entity"
|
assert result[2].entity.value == "literal entity"
|
||||||
assert result[2].type == LITERAL
|
assert result[2].entity.type == LITERAL
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
async def test_query_graph_embeddings_multiple_results(self, processor):
|
||||||
"""Test querying graph embeddings with multiple vectors"""
|
"""Test querying graph embeddings returns multiple results"""
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
limit=3
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results - different results for each vector
|
# Mock search results with multiple entities
|
||||||
mock_results_1 = [
|
mock_results = [
|
||||||
{"entity": {"entity": "http://example.com/entity1"}},
|
{"entity": {"entity": "http://example.com/entity1"}},
|
||||||
{"entity": {"entity": "http://example.com/entity2"}},
|
{"entity": {"entity": "http://example.com/entity2"}},
|
||||||
]
|
|
||||||
mock_results_2 = [
|
|
||||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
|
||||||
{"entity": {"entity": "http://example.com/entity3"}},
|
{"entity": {"entity": "http://example.com/entity3"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
result = await processor.query_graph_embeddings(query)
|
result = await processor.query_graph_embeddings(query)
|
||||||
|
|
||||||
# Verify search was called twice with correct parameters including user/collection
|
# Verify search was called once with the full vector
|
||||||
expected_calls = [
|
processor.vecstore.search.assert_called_once_with(
|
||||||
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
|
||||||
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
|
)
|
||||||
]
|
|
||||||
assert processor.vecstore.search.call_count == 2
|
# Verify results are EntityMatch objects
|
||||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
|
||||||
actual_call = processor.vecstore.search.call_args_list[i]
|
|
||||||
assert actual_call[0] == expected_args
|
|
||||||
assert actual_call[1] == expected_kwargs
|
|
||||||
|
|
||||||
# Verify results are deduplicated and limited
|
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||||
assert "http://example.com/entity1" in entity_values
|
assert "http://example.com/entity1" in entity_values
|
||||||
assert "http://example.com/entity2" in entity_values
|
assert "http://example.com/entity2" in entity_values
|
||||||
assert "http://example.com/entity3" in entity_values
|
assert "http://example.com/entity3" in entity_values
|
||||||
|
|
@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=2
|
limit=2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_deduplication(self, processor):
|
async def test_query_graph_embeddings_preserves_order(self, processor):
|
||||||
"""Test that duplicate entities are properly deduplicated"""
|
"""Test that query results preserve order from the vector store"""
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results with duplicates
|
|
||||||
mock_results_1 = [
|
|
||||||
{"entity": {"entity": "http://example.com/entity1"}},
|
|
||||||
{"entity": {"entity": "http://example.com/entity2"}},
|
|
||||||
]
|
|
||||||
mock_results_2 = [
|
|
||||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
|
||||||
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
|
|
||||||
{"entity": {"entity": "http://example.com/entity3"}}, # New
|
|
||||||
]
|
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
|
||||||
|
|
||||||
result = await processor.query_graph_embeddings(query)
|
|
||||||
|
|
||||||
# Verify duplicates are removed
|
|
||||||
assert len(result) == 3
|
|
||||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
|
||||||
assert len(set(entity_values)) == 3 # All unique
|
|
||||||
assert "http://example.com/entity1" in entity_values
|
|
||||||
assert "http://example.com/entity2" in entity_values
|
|
||||||
assert "http://example.com/entity3" in entity_values
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
# Mock search results in specific order
|
||||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
mock_results = [
|
||||||
"""Test that querying stops early when limit is reached"""
|
|
||||||
query = GraphEmbeddingsRequest(
|
|
||||||
user='test_user',
|
|
||||||
collection='test_collection',
|
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
|
||||||
limit=2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock search results - first vector returns enough results
|
|
||||||
mock_results_1 = [
|
|
||||||
{"entity": {"entity": "http://example.com/entity1"}},
|
{"entity": {"entity": "http://example.com/entity1"}},
|
||||||
{"entity": {"entity": "http://example.com/entity2"}},
|
{"entity": {"entity": "http://example.com/entity2"}},
|
||||||
{"entity": {"entity": "http://example.com/entity3"}},
|
{"entity": {"entity": "http://example.com/entity3"}},
|
||||||
]
|
]
|
||||||
processor.vecstore.search.return_value = mock_results_1
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
result = await processor.query_graph_embeddings(query)
|
result = await processor.query_graph_embeddings(query)
|
||||||
|
|
||||||
# Verify only first vector was searched (limit reached)
|
# Verify results are in the same order as returned by the store
|
||||||
processor.vecstore.search.assert_called_once_with(
|
assert len(result) == 3
|
||||||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
|
assert result[0].entity.iri == "http://example.com/entity1"
|
||||||
|
assert result[1].entity.iri == "http://example.com/entity2"
|
||||||
|
assert result[2].entity.iri == "http://example.com/entity3"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_graph_embeddings_results_limited(self, processor):
|
||||||
|
"""Test that results are properly limited when store returns more than requested"""
|
||||||
|
query = GraphEmbeddingsRequest(
|
||||||
|
user='test_user',
|
||||||
|
collection='test_collection',
|
||||||
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
limit=2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify results are limited
|
# Mock search results - returns more results than limit
|
||||||
|
mock_results = [
|
||||||
|
{"entity": {"entity": "http://example.com/entity1"}},
|
||||||
|
{"entity": {"entity": "http://example.com/entity2"}},
|
||||||
|
{"entity": {"entity": "http://example.com/entity3"}},
|
||||||
|
]
|
||||||
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
|
result = await processor.query_graph_embeddings(query)
|
||||||
|
|
||||||
|
# Verify search was called with the full vector
|
||||||
|
processor.vecstore.search.assert_called_once_with(
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify results are limited to requested amount
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[],
|
vector=[],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
|
|
||||||
# Verify all results are properly typed
|
# Verify all results are properly typed
|
||||||
assert len(result) == 4
|
assert len(result) == 4
|
||||||
|
|
||||||
# Check URI entities
|
# Check URI entities
|
||||||
uri_results = [r for r in result if r.type == IRI]
|
uri_results = [r for r in result if r.entity.type == IRI]
|
||||||
assert len(uri_results) == 2
|
assert len(uri_results) == 2
|
||||||
uri_values = [r.iri for r in uri_results]
|
uri_values = [r.entity.iri for r in uri_results]
|
||||||
assert "http://example.com/uri_entity" in uri_values
|
assert "http://example.com/uri_entity" in uri_values
|
||||||
assert "https://example.com/another_uri" in uri_values
|
assert "https://example.com/another_uri" in uri_values
|
||||||
|
|
||||||
# Check literal entities
|
# Check literal entities
|
||||||
literal_results = [r for r in result if not r.type == IRI]
|
literal_results = [r for r in result if not r.entity.type == IRI]
|
||||||
assert len(literal_results) == 2
|
assert len(literal_results) == 2
|
||||||
literal_values = [r.value for r in literal_results]
|
literal_values = [r.entity.value for r in literal_results]
|
||||||
assert "literal entity text" in literal_values
|
assert "literal entity text" in literal_values
|
||||||
assert "another literal" in literal_values
|
assert "another literal" in literal_values
|
||||||
|
|
||||||
|
|
@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
limit=0
|
limit=0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||||
assert len(result) == 0
|
assert len(result) == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
async def test_query_graph_embeddings_longer_vector(self, processor):
|
||||||
"""Test querying graph embeddings with different vector dimensions"""
|
"""Test querying graph embeddings with a longer vector"""
|
||||||
query = GraphEmbeddingsRequest(
|
query = GraphEmbeddingsRequest(
|
||||||
user='test_user',
|
user='test_user',
|
||||||
collection='test_collection',
|
collection='test_collection',
|
||||||
vectors=[
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||||
[0.1, 0.2], # 2D vector
|
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
|
||||||
[0.7, 0.8, 0.9] # 3D vector
|
|
||||||
],
|
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock search results for each vector
|
# Mock search results
|
||||||
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
|
mock_results = [
|
||||||
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
|
{"entity": {"entity": "http://example.com/entity1"}},
|
||||||
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
|
{"entity": {"entity": "http://example.com/entity2"}},
|
||||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
]
|
||||||
|
processor.vecstore.search.return_value = mock_results
|
||||||
|
|
||||||
result = await processor.query_graph_embeddings(query)
|
result = await processor.query_graph_embeddings(query)
|
||||||
|
|
||||||
# Verify all vectors were searched
|
# Verify search was called once with the full vector
|
||||||
assert processor.vecstore.search.call_count == 3
|
processor.vecstore.search.assert_called_once()
|
||||||
|
|
||||||
# Verify results from all dimensions
|
# Verify results
|
||||||
assert len(result) == 3
|
assert len(result) == 2
|
||||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||||
assert "entity_2d" in entity_values
|
assert "http://example.com/entity1" in entity_values
|
||||||
assert "entity_4d" in entity_values
|
assert "http://example.com/entity2" in entity_values
|
||||||
assert "entity_3d" in entity_values
|
|
||||||
|
|
@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
||||||
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
||||||
|
|
||||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||||
from trustgraph.schema import Term, IRI, LITERAL
|
from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
|
||||||
|
|
||||||
|
|
||||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
|
|
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
def mock_query_message(self):
|
def mock_query_message(self):
|
||||||
"""Create a mock query message for testing"""
|
"""Create a mock query message for testing"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6]
|
|
||||||
]
|
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||||
"""Test querying graph embeddings with a single vector"""
|
"""Test querying graph embeddings with a single vector"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 3
|
message.limit = 3
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
include_metadata=True
|
include_metadata=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify results
|
# Verify results use EntityMatch structure
|
||||||
assert len(entities) == 3
|
assert len(entities) == 3
|
||||||
assert entities[0].value == 'http://example.org/entity1'
|
assert entities[0].entity.iri == 'http://example.org/entity1'
|
||||||
assert entities[0].type == IRI
|
assert entities[0].entity.type == IRI
|
||||||
assert entities[1].value == 'entity2'
|
assert entities[1].entity.value == 'entity2'
|
||||||
assert entities[1].type == LITERAL
|
assert entities[1].entity.type == LITERAL
|
||||||
assert entities[2].value == 'http://example.org/entity3'
|
assert entities[2].entity.iri == 'http://example.org/entity3'
|
||||||
assert entities[2].type == IRI
|
assert entities[2].entity.type == IRI
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
|
||||||
"""Test querying graph embeddings with multiple vectors"""
|
"""Test basic graph embeddings query"""
|
||||||
# Mock index and query results
|
# Mock index and query results
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
|
|
||||||
# First query results
|
# Query results with distinct entities
|
||||||
mock_results1 = MagicMock()
|
mock_results = MagicMock()
|
||||||
mock_results1.matches = [
|
mock_results.matches = [
|
||||||
MagicMock(metadata={'entity': 'entity1'}),
|
MagicMock(metadata={'entity': 'entity1'}),
|
||||||
MagicMock(metadata={'entity': 'entity2'})
|
MagicMock(metadata={'entity': 'entity2'}),
|
||||||
]
|
|
||||||
|
|
||||||
# Second query results
|
|
||||||
mock_results2 = MagicMock()
|
|
||||||
mock_results2.matches = [
|
|
||||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
|
||||||
MagicMock(metadata={'entity': 'entity3'})
|
MagicMock(metadata={'entity': 'entity3'})
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
mock_index.query.return_value = mock_results
|
||||||
|
|
||||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||||
|
|
||||||
# Verify both queries were made
|
# Verify query was made once
|
||||||
assert mock_index.query.call_count == 2
|
assert mock_index.query.call_count == 1
|
||||||
|
|
||||||
# Verify deduplication occurred
|
# Verify results with EntityMatch structure
|
||||||
entity_values = [e.value for e in entities]
|
entity_values = [e.entity.value for e in entities]
|
||||||
assert len(entity_values) == 3
|
assert len(entity_values) == 3
|
||||||
assert 'entity1' in entity_values
|
assert 'entity1' in entity_values
|
||||||
assert 'entity2' in entity_values
|
assert 'entity2' in entity_values
|
||||||
|
|
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
async def test_query_graph_embeddings_limit_handling(self, processor):
|
async def test_query_graph_embeddings_limit_handling(self, processor):
|
||||||
"""Test that query respects the limit parameter"""
|
"""Test that query respects the limit parameter"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 2
|
message.limit = 2
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||||
"""Test querying with zero limit returns empty results"""
|
"""Test querying with zero limit returns empty results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 0
|
message.limit = 0
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
async def test_query_graph_embeddings_negative_limit(self, processor):
|
async def test_query_graph_embeddings_negative_limit(self, processor):
|
||||||
"""Test querying with negative limit returns empty results"""
|
"""Test querying with negative limit returns empty results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = -1
|
message.limit = -1
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
assert entities == []
|
assert entities == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
async def test_query_graph_embeddings_2d_vector(self, processor):
|
||||||
"""Test querying with vectors of different dimensions using same index"""
|
"""Test querying with a 2D vector"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [
|
message.vector = [0.1, 0.2] # 2D vector
|
||||||
[0.1, 0.2], # 2D vector
|
|
||||||
[0.3, 0.4, 0.5, 0.6] # 4D vector
|
|
||||||
]
|
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
||||||
# Mock single index that handles all dimensions
|
# Mock index
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
|
|
||||||
# Mock results for different vector queries
|
# Mock results for 2D vector query
|
||||||
mock_results_2d = MagicMock()
|
mock_results = MagicMock()
|
||||||
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||||
|
|
||||||
mock_results_4d = MagicMock()
|
mock_index.query.return_value = mock_results
|
||||||
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
|
|
||||||
|
|
||||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
|
||||||
|
|
||||||
entities = await processor.query_graph_embeddings(message)
|
entities = await processor.query_graph_embeddings(message)
|
||||||
|
|
||||||
# Verify different indexes used for different dimensions
|
# Verify correct index used for 2D vector
|
||||||
assert processor.pinecone.Index.call_count == 2
|
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||||
index_calls = processor.pinecone.Index.call_args_list
|
|
||||||
index_names = [call[0][0] for call in index_calls]
|
|
||||||
assert "t-test_user-test_collection-2" in index_names # 2D vector
|
|
||||||
assert "t-test_user-test_collection-4" in index_names # 4D vector
|
|
||||||
|
|
||||||
# Verify both queries were made
|
# Verify query was made
|
||||||
assert mock_index.query.call_count == 2
|
assert mock_index.query.call_count == 1
|
||||||
|
|
||||||
# Verify results from both dimensions
|
# Verify results with EntityMatch structure
|
||||||
entity_values = [e.value for e in entities]
|
entity_values = [e.entity.value for e in entities]
|
||||||
assert 'entity_2d' in entity_values
|
assert 'entity_2d' in entity_values
|
||||||
assert 'entity_4d' in entity_values
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
|
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
|
||||||
"""Test querying with empty vectors list"""
|
"""Test querying with empty vectors list"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = []
|
message.vector = []
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
async def test_query_graph_embeddings_no_results(self, processor):
|
async def test_query_graph_embeddings_no_results(self, processor):
|
||||||
"""Test querying when index returns no results"""
|
"""Test querying when index returns no results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
assert entities == []
|
assert entities == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
|
async def test_query_graph_embeddings_deduplication_in_results(self, processor):
|
||||||
"""Test that deduplication works correctly across multiple vector queries"""
|
"""Test that deduplication works correctly within query results"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6]
|
|
||||||
]
|
|
||||||
message.limit = 3
|
message.limit = 3
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
|
|
||||||
# Both queries return overlapping results
|
# Query returns results with some duplicates
|
||||||
mock_results1 = MagicMock()
|
mock_results = MagicMock()
|
||||||
mock_results1.matches = [
|
mock_results.matches = [
|
||||||
MagicMock(metadata={'entity': 'entity1'}),
|
MagicMock(metadata={'entity': 'entity1'}),
|
||||||
MagicMock(metadata={'entity': 'entity2'}),
|
MagicMock(metadata={'entity': 'entity2'}),
|
||||||
|
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
|
||||||
MagicMock(metadata={'entity': 'entity3'}),
|
MagicMock(metadata={'entity': 'entity3'}),
|
||||||
MagicMock(metadata={'entity': 'entity4'})
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_results2 = MagicMock()
|
|
||||||
mock_results2.matches = [
|
|
||||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||||
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
|
|
||||||
MagicMock(metadata={'entity': 'entity5'})
|
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
mock_index.query.return_value = mock_results
|
||||||
|
|
||||||
entities = await processor.query_graph_embeddings(message)
|
entities = await processor.query_graph_embeddings(message)
|
||||||
|
|
||||||
# Should get exactly 3 unique entities (respecting limit)
|
# Should get exactly 3 unique entities (respecting limit)
|
||||||
assert len(entities) == 3
|
assert len(entities) == 3
|
||||||
entity_values = [e.value for e in entities]
|
entity_values = [e.entity.value for e in entities]
|
||||||
assert len(set(entity_values)) == 3 # All unique
|
assert len(set(entity_values)) == 3 # All unique
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
async def test_query_graph_embeddings_respects_limit(self, processor):
|
||||||
"""Test that querying stops early when limit is reached"""
|
"""Test that query respects limit parameter"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6],
|
|
||||||
[0.7, 0.8, 0.9]
|
|
||||||
]
|
|
||||||
message.limit = 2
|
message.limit = 2
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
|
|
||||||
# First query returns enough results to meet limit
|
# Query returns more results than limit
|
||||||
mock_results1 = MagicMock()
|
mock_results = MagicMock()
|
||||||
mock_results1.matches = [
|
mock_results.matches = [
|
||||||
MagicMock(metadata={'entity': 'entity1'}),
|
MagicMock(metadata={'entity': 'entity1'}),
|
||||||
MagicMock(metadata={'entity': 'entity2'}),
|
MagicMock(metadata={'entity': 'entity2'}),
|
||||||
MagicMock(metadata={'entity': 'entity3'})
|
MagicMock(metadata={'entity': 'entity3'})
|
||||||
]
|
]
|
||||||
mock_index.query.return_value = mock_results1
|
mock_index.query.return_value = mock_results
|
||||||
|
|
||||||
entities = await processor.query_graph_embeddings(message)
|
entities = await processor.query_graph_embeddings(message)
|
||||||
|
|
||||||
# Should only make one query since limit was reached
|
# Should only return 2 entities (respecting limit)
|
||||||
mock_index.query.assert_called_once()
|
mock_index.query.assert_called_once()
|
||||||
assert len(entities) == 2
|
assert len(entities) == 2
|
||||||
|
|
||||||
|
|
@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||||
"""Test that exceptions are properly raised"""
|
"""Test that exceptions are properly raised"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.vectors = [[0.1, 0.2, 0.3]]
|
message.vector = [0.1, 0.2, 0.3]
|
||||||
message.limit = 5
|
message.limit = 5
|
||||||
message.user = 'test_user'
|
message.user = 'test_user'
|
||||||
message.collection = 'test_collection'
|
message.collection = 'test_collection'
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
||||||
|
|
||||||
# Import the service under test
|
# Import the service under test
|
||||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||||
from trustgraph.schema import IRI, LITERAL
|
from trustgraph.schema import IRI, LITERAL, EntityMatch
|
||||||
|
|
||||||
|
|
||||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
mock_message.vector = [0.1, 0.2, 0.3]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'test_user'
|
mock_message.user = 'test_user'
|
||||||
mock_message.collection = 'test_collection'
|
mock_message.collection = 'test_collection'
|
||||||
|
|
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
with_payload=True
|
with_payload=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result contains expected entities
|
# Verify result contains expected EntityMatch objects
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert all(hasattr(entity, 'value') for entity in result)
|
assert all(isinstance(entity, EntityMatch) for entity in result)
|
||||||
entity_values = [entity.value for entity in result]
|
entity_values = [entity.entity.value for entity in result]
|
||||||
assert 'entity1' in entity_values
|
assert 'entity1' in entity_values
|
||||||
assert 'entity2' in entity_values
|
assert 'entity2' in entity_values
|
||||||
|
|
||||||
|
|
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Create mock message with multiple vectors
|
# Create mock message with single vector
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 3
|
mock_message.limit = 3
|
||||||
mock_message.user = 'multi_user'
|
mock_message.user = 'multi_user'
|
||||||
mock_message.collection = 'multi_collection'
|
mock_message.collection = 'multi_collection'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await processor.query_graph_embeddings(mock_message)
|
result = await processor.query_graph_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify query was called twice
|
# Verify query was called once
|
||||||
assert mock_qdrant_instance.query_points.call_count == 2
|
assert mock_qdrant_instance.query_points.call_count == 1
|
||||||
|
|
||||||
# Verify both collections were queried (both 2-dimensional vectors)
|
# Verify collection was queried
|
||||||
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
|
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
|
||||||
calls = mock_qdrant_instance.query_points.call_args_list
|
calls = mock_qdrant_instance.query_points.call_args_list
|
||||||
assert calls[0][1]['collection_name'] == expected_collection
|
assert calls[0][1]['collection_name'] == expected_collection
|
||||||
assert calls[1][1]['collection_name'] == expected_collection
|
|
||||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
|
||||||
|
# Verify results with EntityMatch structure
|
||||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
entity_values = [entity.entity.value for entity in result]
|
||||||
entity_values = [entity.value for entity in result]
|
|
||||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||||
assert 'entity1' in entity_values
|
assert 'entity1' in entity_values
|
||||||
assert 'entity2' in entity_values
|
assert 'entity2' in entity_values
|
||||||
assert 'entity3' in entity_values
|
|
||||||
|
|
||||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message with limit
|
# Create mock message with limit
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
mock_message.vector = [0.1, 0.2, 0.3]
|
||||||
mock_message.limit = 3 # Should only return 3 results
|
mock_message.limit = 3 # Should only return 3 results
|
||||||
mock_message.user = 'limit_user'
|
mock_message.user = 'limit_user'
|
||||||
mock_message.collection = 'limit_collection'
|
mock_message.collection = 'limit_collection'
|
||||||
|
|
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'empty_user'
|
mock_message.user = 'empty_user'
|
||||||
mock_message.collection = 'empty_collection'
|
mock_message.collection = 'empty_collection'
|
||||||
|
|
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
processor = Processor(**config)
|
processor = Processor(**config)
|
||||||
|
|
||||||
# Create mock message with different dimension vectors
|
# Create mock message with single vector
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
mock_message.vector = [0.1, 0.2] # 2D vector
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'dim_user'
|
mock_message.user = 'dim_user'
|
||||||
mock_message.collection = 'dim_collection'
|
mock_message.collection = 'dim_collection'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = await processor.query_graph_embeddings(mock_message)
|
result = await processor.query_graph_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Verify query was called twice with different collections
|
# Verify query was called once
|
||||||
assert mock_qdrant_instance.query_points.call_count == 2
|
assert mock_qdrant_instance.query_points.call_count == 1
|
||||||
calls = mock_qdrant_instance.query_points.call_args_list
|
calls = mock_qdrant_instance.query_points.call_args_list
|
||||||
|
|
||||||
# First call should use 2D collection
|
# Call should use 2D collection
|
||||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
|
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
|
||||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||||
|
|
||||||
# Second call should use 3D collection
|
# Verify results with EntityMatch structure
|
||||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
|
entity_values = [entity.entity.value for entity in result]
|
||||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
|
||||||
|
|
||||||
# Verify results
|
|
||||||
entity_values = [entity.value for entity in result]
|
|
||||||
assert 'entity2d' in entity_values
|
assert 'entity2d' in entity_values
|
||||||
assert 'entity3d' in entity_values
|
|
||||||
|
|
||||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'uri_user'
|
mock_message.user = 'uri_user'
|
||||||
mock_message.collection = 'uri_collection'
|
mock_message.collection = 'uri_collection'
|
||||||
|
|
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
|
|
||||||
# Check URI entities
|
# Check URI entities
|
||||||
uri_entities = [entity for entity in result if entity.type == IRI]
|
uri_entities = [entity for entity in result if entity.entity.type == IRI]
|
||||||
assert len(uri_entities) == 2
|
assert len(uri_entities) == 2
|
||||||
uri_values = [entity.iri for entity in uri_entities]
|
uri_values = [entity.entity.iri for entity in uri_entities]
|
||||||
assert 'http://example.com/entity1' in uri_values
|
assert 'http://example.com/entity1' in uri_values
|
||||||
assert 'https://secure.example.com/entity2' in uri_values
|
assert 'https://secure.example.com/entity2' in uri_values
|
||||||
|
|
||||||
# Check regular entities
|
# Check regular entities
|
||||||
regular_entities = [entity for entity in result if entity.type == LITERAL]
|
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
|
||||||
assert len(regular_entities) == 1
|
assert len(regular_entities) == 1
|
||||||
assert regular_entities[0].value == 'regular entity'
|
assert regular_entities[0].entity.value == 'regular entity'
|
||||||
|
|
||||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||||
|
|
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 5
|
mock_message.limit = 5
|
||||||
mock_message.user = 'error_user'
|
mock_message.user = 'error_user'
|
||||||
mock_message.collection = 'error_collection'
|
mock_message.collection = 'error_collection'
|
||||||
|
|
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
# Create mock message with zero limit
|
# Create mock message with zero limit
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.vectors = [[0.1, 0.2]]
|
mock_message.vector = [0.1, 0.2]
|
||||||
mock_message.limit = 0
|
mock_message.limit = 0
|
||||||
mock_message.user = 'zero_user'
|
mock_message.user = 'zero_user'
|
||||||
mock_message.collection = 'zero_collection'
|
mock_message.collection = 'zero_collection'
|
||||||
|
|
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||||
# With zero limit, the logic still adds one entity before checking the limit
|
# With zero limit, the logic still adds one entity before checking the limit
|
||||||
# So it returns one result (current behavior, not ideal but actual)
|
# So it returns one result (current behavior, not ideal but actual)
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].value == 'entity1'
|
assert result[0].entity.value == 'entity1'
|
||||||
|
|
||||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||||
|
|
|
||||||
|
|
@ -175,9 +175,14 @@ class TestQuery:
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
|
|
||||||
# Mock document embeddings returns chunk_ids
|
# Mock document embeddings returns ChunkMatch objects
|
||||||
test_chunk_ids = ["doc/c1", "doc/c2"]
|
mock_match1 = MagicMock()
|
||||||
mock_doc_embeddings_client.query.return_value = test_chunk_ids
|
mock_match1.chunk_id = "doc/c1"
|
||||||
|
mock_match1.score = 0.95
|
||||||
|
mock_match2 = MagicMock()
|
||||||
|
mock_match2.chunk_id = "doc/c2"
|
||||||
|
mock_match2.score = 0.85
|
||||||
|
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -195,9 +200,9 @@ class TestQuery:
|
||||||
# Verify embeddings client was called (now expects list)
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify doc embeddings client was called correctly (with extracted vectors)
|
# Verify doc embeddings client was called correctly (with extracted vector)
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
test_vectors,
|
vector=test_vectors,
|
||||||
limit=15,
|
limit=15,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
|
|
@ -218,11 +223,16 @@ class TestQuery:
|
||||||
# Mock embeddings and document embeddings responses
|
# Mock embeddings and document embeddings responses
|
||||||
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
# New batch format: [[[vectors]]] - get_vector extracts [0]
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
test_chunk_ids = ["doc/c3", "doc/c4"]
|
mock_match1 = MagicMock()
|
||||||
|
mock_match1.chunk_id = "doc/c3"
|
||||||
|
mock_match1.score = 0.9
|
||||||
|
mock_match2 = MagicMock()
|
||||||
|
mock_match2.chunk_id = "doc/c4"
|
||||||
|
mock_match2.score = 0.8
|
||||||
expected_response = "This is the document RAG response"
|
expected_response = "This is the document RAG response"
|
||||||
|
|
||||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
mock_doc_embeddings_client.query.return_value = test_chunk_ids
|
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||||
mock_prompt_client.document_prompt.return_value = expected_response
|
mock_prompt_client.document_prompt.return_value = expected_response
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
|
|
@ -245,9 +255,9 @@ class TestQuery:
|
||||||
# Verify embeddings client was called (now expects list)
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with(["test query"])
|
mock_embeddings_client.embed.assert_called_once_with(["test query"])
|
||||||
|
|
||||||
# Verify doc embeddings client was called (with extracted vectors)
|
# Verify doc embeddings client was called (with extracted vector)
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
test_vectors,
|
vector=test_vectors,
|
||||||
limit=10,
|
limit=10,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
|
|
@ -275,7 +285,10 @@ class TestQuery:
|
||||||
|
|
||||||
# Mock responses (batch format)
|
# Mock responses (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["doc/c5"]
|
mock_match = MagicMock()
|
||||||
|
mock_match.chunk_id = "doc/c5"
|
||||||
|
mock_match.score = 0.9
|
||||||
|
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
|
|
@ -289,9 +302,9 @@ class TestQuery:
|
||||||
# Call DocumentRag.query with minimal parameters
|
# Call DocumentRag.query with minimal parameters
|
||||||
result = await document_rag.query("simple query")
|
result = await document_rag.query("simple query")
|
||||||
|
|
||||||
# Verify default parameters were used (vectors extracted from batch)
|
# Verify default parameters were used (vector extracted from batch)
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
[[0.1, 0.2]],
|
vector=[[0.1, 0.2]],
|
||||||
limit=20, # Default doc_limit
|
limit=20, # Default doc_limit
|
||||||
user="trustgraph", # Default user
|
user="trustgraph", # Default user
|
||||||
collection="default" # Default collection
|
collection="default" # Default collection
|
||||||
|
|
@ -316,7 +329,10 @@ class TestQuery:
|
||||||
|
|
||||||
# Mock responses (batch format)
|
# Mock responses (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
|
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["doc/c6"]
|
mock_match = MagicMock()
|
||||||
|
mock_match.chunk_id = "doc/c6"
|
||||||
|
mock_match.score = 0.88
|
||||||
|
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||||
|
|
||||||
# Initialize Query with verbose=True
|
# Initialize Query with verbose=True
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -347,7 +363,10 @@ class TestQuery:
|
||||||
|
|
||||||
# Mock responses (batch format)
|
# Mock responses (batch format)
|
||||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||||
mock_doc_embeddings_client.query.return_value = ["doc/c7"]
|
mock_match = MagicMock()
|
||||||
|
mock_match.chunk_id = "doc/c7"
|
||||||
|
mock_match.score = 0.92
|
||||||
|
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||||
|
|
||||||
# Initialize DocumentRag with verbose=True
|
# Initialize DocumentRag with verbose=True
|
||||||
|
|
@ -487,7 +506,13 @@ class TestQuery:
|
||||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||||
|
|
||||||
mock_embeddings_client.embed.return_value = [query_vectors]
|
mock_embeddings_client.embed.return_value = [query_vectors]
|
||||||
mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids
|
mock_matches = []
|
||||||
|
for chunk_id in retrieved_chunk_ids:
|
||||||
|
mock_match = MagicMock()
|
||||||
|
mock_match.chunk_id = chunk_id
|
||||||
|
mock_match.score = 0.9
|
||||||
|
mock_matches.append(mock_match)
|
||||||
|
mock_doc_embeddings_client.query.return_value = mock_matches
|
||||||
mock_prompt_client.document_prompt.return_value = final_response
|
mock_prompt_client.document_prompt.return_value = final_response
|
||||||
|
|
||||||
# Initialize DocumentRag
|
# Initialize DocumentRag
|
||||||
|
|
@ -511,7 +536,7 @@ class TestQuery:
|
||||||
mock_embeddings_client.embed.assert_called_once_with([query_text])
|
mock_embeddings_client.embed.assert_called_once_with([query_text])
|
||||||
|
|
||||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||||
query_vectors,
|
vector=query_vectors,
|
||||||
limit=25,
|
limit=25,
|
||||||
user="research_user",
|
user="research_user",
|
||||||
collection="ml_knowledge"
|
collection="ml_knowledge"
|
||||||
|
|
|
||||||
|
|
@ -193,12 +193,20 @@ class TestQuery:
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
|
|
||||||
# Mock entity objects that have string representation
|
# Mock EntityMatch objects with entity that has string representation
|
||||||
mock_entity1 = MagicMock()
|
mock_entity1 = MagicMock()
|
||||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||||
|
mock_match1 = MagicMock()
|
||||||
|
mock_match1.entity = mock_entity1
|
||||||
|
mock_match1.score = 0.95
|
||||||
|
|
||||||
mock_entity2 = MagicMock()
|
mock_entity2 = MagicMock()
|
||||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||||
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
|
mock_match2 = MagicMock()
|
||||||
|
mock_match2.entity = mock_entity2
|
||||||
|
mock_match2.score = 0.85
|
||||||
|
|
||||||
|
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -216,9 +224,9 @@ class TestQuery:
|
||||||
# Verify embeddings client was called (now expects list)
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify graph embeddings client was called correctly (with extracted vectors)
|
# Verify graph embeddings client was called correctly (with extracted vector)
|
||||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||||
vectors=test_vectors,
|
vector=test_vectors,
|
||||||
limit=25,
|
limit=25,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
# Create test document embeddings
|
# Create test document embeddings
|
||||||
chunk1 = ChunkEmbeddings(
|
chunk1 = ChunkEmbeddings(
|
||||||
chunk_id="This is the first document chunk",
|
chunk_id="This is the first document chunk",
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
chunk2 = ChunkEmbeddings(
|
chunk2 = ChunkEmbeddings(
|
||||||
chunk_id="This is the second document chunk",
|
chunk_id="This is the second document chunk",
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vector=[0.7, 0.8, 0.9]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk1, chunk2]
|
message.chunks = [chunk1, chunk2]
|
||||||
|
|
||||||
|
|
@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id="Test document content",
|
chunk_id="Test document content",
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify insert was called for each vector with user/collection parameters
|
# Verify insert was called once for the single chunk with its vector
|
||||||
expected_calls = [
|
processor.vecstore.insert.assert_called_once_with(
|
||||||
([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'),
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'
|
||||||
([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'),
|
)
|
||||||
]
|
|
||||||
|
|
||||||
assert processor.vecstore.insert.call_count == 2
|
|
||||||
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
|
||||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
|
||||||
assert actual_call[0][0] == expected_vec
|
|
||||||
assert actual_call[0][1] == expected_doc
|
|
||||||
assert actual_call[0][2] == expected_user
|
|
||||||
assert actual_call[0][3] == expected_collection
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message):
|
||||||
"""Test storing document embeddings for multiple chunks"""
|
"""Test storing document embeddings for multiple chunks"""
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Verify insert was called for each vector of each chunk with user/collection parameters
|
# Verify insert was called once per chunk with user/collection parameters
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
# Chunk 1 vectors
|
# Chunk 1 - single vector
|
||||||
([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'),
|
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
|
||||||
([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'),
|
# Chunk 2 - single vector
|
||||||
# Chunk 2 vectors
|
|
||||||
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
|
([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert processor.vecstore.insert.call_count == 3
|
assert processor.vecstore.insert.call_count == 2
|
||||||
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
||||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||||
assert actual_call[0][0] == expected_vec
|
assert actual_call[0][0] == expected_vec
|
||||||
|
|
@ -137,7 +127,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id="",
|
chunk_id="",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -156,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id=None,
|
chunk_id=None,
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -177,15 +167,15 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
valid_chunk = ChunkEmbeddings(
|
valid_chunk = ChunkEmbeddings(
|
||||||
chunk_id="Valid document content",
|
chunk_id="Valid document content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
empty_chunk = ChunkEmbeddings(
|
empty_chunk = ChunkEmbeddings(
|
||||||
chunk_id="",
|
chunk_id="",
|
||||||
vectors=[[0.4, 0.5, 0.6]]
|
vector=[0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
another_valid = ChunkEmbeddings(
|
another_valid = ChunkEmbeddings(
|
||||||
chunk_id="Another valid chunk",
|
chunk_id="Another valid chunk",
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vector=[0.7, 0.8, 0.9]
|
||||||
)
|
)
|
||||||
message.chunks = [valid_chunk, empty_chunk, another_valid]
|
message.chunks = [valid_chunk, empty_chunk, another_valid]
|
||||||
|
|
||||||
|
|
@ -229,7 +219,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id="Document with no vectors",
|
chunk_id="Document with no vectors",
|
||||||
vectors=[]
|
vector=[]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -245,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
# Each chunk has a single vector of different dimensions
|
||||||
chunk_id="Document with mixed dimensions",
|
chunk1 = ChunkEmbeddings(
|
||||||
vectors=[
|
chunk_id="chunk/doc/2d",
|
||||||
[0.1, 0.2], # 2D vector
|
vector=[0.1, 0.2] # 2D vector
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
|
||||||
[0.7, 0.8, 0.9] # 3D vector
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
chunk2 = ChunkEmbeddings(
|
||||||
|
chunk_id="chunk/doc/4d",
|
||||||
|
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||||
|
)
|
||||||
|
chunk3 = ChunkEmbeddings(
|
||||||
|
chunk_id="chunk/doc/3d",
|
||||||
|
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||||
|
)
|
||||||
|
message.chunks = [chunk1, chunk2, chunk3]
|
||||||
|
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify all vectors were inserted regardless of dimension with user/collection parameters
|
# Verify all vectors were inserted regardless of dimension with user/collection parameters
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'),
|
([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'),
|
||||||
([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'),
|
([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'),
|
||||||
([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'),
|
([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert processor.vecstore.insert.call_count == 3
|
assert processor.vecstore.insert.call_count == 3
|
||||||
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls):
|
||||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||||
|
|
@ -283,7 +278,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id="chunk/doc/unicode-éñ中文🚀",
|
chunk_id="chunk/doc/unicode-éñ中文🚀",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -306,7 +301,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
long_chunk_id = "chunk/doc/" + "a" * 200
|
long_chunk_id = "chunk/doc/" + "a" * 200
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id=long_chunk_id,
|
chunk_id=long_chunk_id,
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -327,7 +322,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id=" \n\t ",
|
chunk_id=" \n\t ",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -358,7 +353,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id="Test content",
|
chunk_id="Test content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -379,7 +374,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message1.metadata.collection = 'collection1'
|
message1.metadata.collection = 'collection1'
|
||||||
chunk1 = ChunkEmbeddings(
|
chunk1 = ChunkEmbeddings(
|
||||||
chunk_id="User1 content",
|
chunk_id="User1 content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message1.chunks = [chunk1]
|
message1.chunks = [chunk1]
|
||||||
|
|
||||||
|
|
@ -390,7 +385,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
message2.metadata.collection = 'collection2'
|
message2.metadata.collection = 'collection2'
|
||||||
chunk2 = ChunkEmbeddings(
|
chunk2 = ChunkEmbeddings(
|
||||||
chunk_id="User2 content",
|
chunk_id="User2 content",
|
||||||
vectors=[[0.4, 0.5, 0.6]]
|
vector=[0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
message2.chunks = [chunk2]
|
message2.chunks = [chunk2]
|
||||||
|
|
||||||
|
|
@ -421,7 +416,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk_id="Special chars test",
|
chunk_id="Special chars test",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
# Create test document embeddings
|
# Create test document embeddings
|
||||||
chunk1 = ChunkEmbeddings(
|
chunk1 = ChunkEmbeddings(
|
||||||
chunk=b"This is the first document chunk",
|
chunk=b"This is the first document chunk",
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
chunk2 = ChunkEmbeddings(
|
chunk2 = ChunkEmbeddings(
|
||||||
chunk=b"This is the second document chunk",
|
chunk=b"This is the second document chunk",
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vector=[0.7, 0.8, 0.9]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk1, chunk2]
|
message.chunks = [chunk1, chunk2]
|
||||||
|
|
||||||
|
|
@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Test document content",
|
chunk=b"Test document content",
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Test document content",
|
chunk=b"Test document content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"",
|
chunk=b"",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=None,
|
chunk=None,
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"", # Empty bytes
|
chunk=b"", # Empty bytes
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
# Each chunk has a single vector of different dimensions
|
||||||
chunk=b"Document with mixed dimensions",
|
chunk1 = ChunkEmbeddings(
|
||||||
vectors=[
|
chunk=b"Document chunk 1",
|
||||||
[0.1, 0.2], # 2D vector
|
vector=[0.1, 0.2] # 2D vector
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
|
||||||
[0.7, 0.8, 0.9] # 3D vector
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
chunk2 = ChunkEmbeddings(
|
||||||
|
chunk=b"Document chunk 2",
|
||||||
mock_index_2d = MagicMock()
|
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||||
mock_index_4d = MagicMock()
|
)
|
||||||
mock_index_3d = MagicMock()
|
chunk3 = ChunkEmbeddings(
|
||||||
|
chunk=b"Document chunk 3",
|
||||||
|
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||||
|
)
|
||||||
|
message.chunks = [chunk1, chunk2, chunk3]
|
||||||
|
|
||||||
|
mock_index = MagicMock()
|
||||||
|
|
||||||
def mock_index_side_effect(name):
|
def mock_index_side_effect(name):
|
||||||
# All dimensions now use the same index name pattern
|
# All dimensions now use the same index name pattern
|
||||||
# Different dimensions will be handled within the same index
|
|
||||||
if "test_user" in name and "test_collection" in name:
|
if "test_user" in name and "test_collection" in name:
|
||||||
return mock_index_2d # Just return one mock for all
|
return mock_index
|
||||||
return MagicMock()
|
return MagicMock()
|
||||||
|
|
||||||
processor.pinecone.Index.side_effect = mock_index_side_effect
|
processor.pinecone.Index.side_effect = mock_index_side_effect
|
||||||
processor.pinecone.has_index.return_value = True
|
processor.pinecone.has_index.return_value = True
|
||||||
|
|
||||||
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
|
||||||
await processor.store_document_embeddings(message)
|
await processor.store_document_embeddings(message)
|
||||||
|
|
||||||
# Verify all vectors are now stored in the same index
|
# Verify all vectors are now stored in the same index
|
||||||
# (Pinecone can handle mixed dimensions in the same index)
|
# (Each chunk has a single vector, called once per chunk)
|
||||||
assert processor.pinecone.Index.call_count == 3 # Called once per vector
|
assert processor.pinecone.Index.call_count == 3 # Called once per chunk
|
||||||
mock_index_2d.upsert.call_count == 3 # All upserts go to same index
|
assert mock_index.upsert.call_count == 3 # All upserts go to same index
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
async def test_store_document_embeddings_empty_chunks_list(self, processor):
|
||||||
|
|
@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Document with no vectors",
|
chunk=b"Document with no vectors",
|
||||||
vectors=[]
|
vector=[]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Test document content",
|
chunk=b"Test document content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=b"Test document content",
|
chunk=b"Test document content",
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
chunk="Document with Unicode: éñ中文🚀".encode('utf-8'),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
||||||
large_content = "A" * 10000 # 10KB of content
|
large_content = "A" * 10000 # 10KB of content
|
||||||
chunk = ChunkEmbeddings(
|
chunk = ChunkEmbeddings(
|
||||||
chunk=large_content.encode('utf-8'),
|
chunk=large_content.encode('utf-8'),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.chunks = [chunk]
|
message.chunks = [chunk]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
|
mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes
|
||||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
|
|
@ -143,11 +143,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk1 = MagicMock()
|
mock_chunk1 = MagicMock()
|
||||||
mock_chunk1.chunk_id = 'doc/c1'
|
mock_chunk1.chunk_id = 'doc/c1'
|
||||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
mock_chunk1.vector = [0.1, 0.2]
|
||||||
|
|
||||||
mock_chunk2 = MagicMock()
|
mock_chunk2 = MagicMock()
|
||||||
mock_chunk2.chunk_id = 'doc/c2'
|
mock_chunk2.chunk_id = 'doc/c2'
|
||||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
mock_chunk2.vector = [0.3, 0.4]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||||
|
|
||||||
|
|
@ -175,8 +175,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||||
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
|
async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
|
||||||
"""Test storing document embeddings with multiple vectors per chunk"""
|
"""Test storing document embeddings with multiple chunks"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_instance.collection_exists.return_value = True
|
mock_qdrant_instance.collection_exists.return_value = True
|
||||||
|
|
@ -196,41 +196,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Add collection to known_collections (simulates config push)
|
# Add collection to known_collections (simulates config push)
|
||||||
processor.known_collections[('vector_user', 'vector_collection')] = {}
|
processor.known_collections[('vector_user', 'vector_collection')] = {}
|
||||||
|
|
||||||
# Create mock message with chunk having multiple vectors
|
# Create mock message with multiple chunks, each having a single vector
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'vector_user'
|
mock_message.metadata.user = 'vector_user'
|
||||||
mock_message.metadata.collection = 'vector_collection'
|
mock_message.metadata.collection = 'vector_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk1 = MagicMock()
|
||||||
mock_chunk.chunk_id = 'doc/multi-vector'
|
mock_chunk1.chunk_id = 'doc/c1'
|
||||||
mock_chunk.vectors = [
|
mock_chunk1.vector = [0.1, 0.2, 0.3]
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6],
|
|
||||||
[0.7, 0.8, 0.9]
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_chunk2 = MagicMock()
|
||||||
|
mock_chunk2.chunk_id = 'doc/c2'
|
||||||
|
mock_chunk2.vector = [0.4, 0.5, 0.6]
|
||||||
|
|
||||||
|
mock_chunk3 = MagicMock()
|
||||||
|
mock_chunk3.chunk_id = 'doc/c3'
|
||||||
|
mock_chunk3.vector = [0.7, 0.8, 0.9]
|
||||||
|
|
||||||
|
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Should be called 3 times (once per vector)
|
# Should be called 3 times (once per chunk)
|
||||||
assert mock_qdrant_instance.upsert.call_count == 3
|
assert mock_qdrant_instance.upsert.call_count == 3
|
||||||
|
|
||||||
# Verify all vectors were processed
|
# Verify all vectors were processed
|
||||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||||
|
|
||||||
expected_vectors = [
|
expected_data = [
|
||||||
[0.1, 0.2, 0.3],
|
([0.1, 0.2, 0.3], 'doc/c1'),
|
||||||
[0.4, 0.5, 0.6],
|
([0.4, 0.5, 0.6], 'doc/c2'),
|
||||||
[0.7, 0.8, 0.9]
|
([0.7, 0.8, 0.9], 'doc/c3')
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, call in enumerate(upsert_calls):
|
for i, call in enumerate(upsert_calls):
|
||||||
point = call[1]['points'][0]
|
point = call[1]['points'][0]
|
||||||
assert point.vector == expected_vectors[i]
|
assert point.vector == expected_data[i][0]
|
||||||
assert point.payload['chunk_id'] == 'doc/multi-vector'
|
assert point.payload['chunk_id'] == expected_data[i][1]
|
||||||
|
|
||||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||||
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
|
async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client):
|
||||||
|
|
@ -256,7 +260,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk_empty = MagicMock()
|
mock_chunk_empty = MagicMock()
|
||||||
mock_chunk_empty.chunk_id = "" # Empty chunk_id
|
mock_chunk_empty.chunk_id = "" # Empty chunk_id
|
||||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
mock_chunk_empty.vector = [0.1, 0.2]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk_empty]
|
mock_message.chunks = [mock_chunk_empty]
|
||||||
|
|
||||||
|
|
@ -299,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk_id = 'doc/test-chunk'
|
mock_chunk.chunk_id = 'doc/test-chunk'
|
||||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
|
|
@ -351,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk_id = 'doc/test-chunk'
|
mock_chunk.chunk_id = 'doc/test-chunk'
|
||||||
mock_chunk.vectors = [[0.1, 0.2]]
|
mock_chunk.vector = [0.1, 0.2]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
|
|
@ -389,7 +393,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk1 = MagicMock()
|
mock_chunk1 = MagicMock()
|
||||||
mock_chunk1.chunk_id = 'doc/c1'
|
mock_chunk1.chunk_id = 'doc/c1'
|
||||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
mock_chunk1.vector = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
mock_message1.chunks = [mock_chunk1]
|
mock_message1.chunks = [mock_chunk1]
|
||||||
|
|
||||||
|
|
@ -407,7 +411,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk2 = MagicMock()
|
mock_chunk2 = MagicMock()
|
||||||
mock_chunk2.chunk_id = 'doc/c2'
|
mock_chunk2.chunk_id = 'doc/c2'
|
||||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3)
|
||||||
|
|
||||||
mock_message2.chunks = [mock_chunk2]
|
mock_message2.chunks = [mock_chunk2]
|
||||||
|
|
||||||
|
|
@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Add collection to known_collections (simulates config push)
|
# Add collection to known_collections (simulates config push)
|
||||||
processor.known_collections[('dim_user', 'dim_collection')] = {}
|
processor.known_collections[('dim_user', 'dim_collection')] = {}
|
||||||
|
|
||||||
# Create mock message with different dimension vectors
|
# Create mock message with chunks of different dimensions
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'dim_user'
|
mock_message.metadata.user = 'dim_user'
|
||||||
mock_message.metadata.collection = 'dim_collection'
|
mock_message.metadata.collection = 'dim_collection'
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk1 = MagicMock()
|
||||||
mock_chunk.chunk_id = 'doc/dim-test'
|
mock_chunk1.chunk_id = 'doc/c1'
|
||||||
mock_chunk.vectors = [
|
mock_chunk1.vector = [0.1, 0.2] # 2 dimensions
|
||||||
[0.1, 0.2], # 2 dimensions
|
|
||||||
[0.3, 0.4, 0.5] # 3 dimensions
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_chunk2 = MagicMock()
|
||||||
|
mock_chunk2.chunk_id = 'doc/c2'
|
||||||
|
mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions
|
||||||
|
|
||||||
|
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_document_embeddings(mock_message)
|
await processor.store_document_embeddings(mock_message)
|
||||||
|
|
@ -526,7 +531,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
mock_chunk = MagicMock()
|
mock_chunk = MagicMock()
|
||||||
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
|
mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3'
|
||||||
mock_chunk.vectors = [[0.1, 0.2]]
|
mock_chunk.vector = [0.1, 0.2]
|
||||||
|
|
||||||
mock_message.chunks = [mock_chunk]
|
mock_message.chunks = [mock_chunk]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
# Create test entities with embeddings
|
# Create test entities with embeddings
|
||||||
entity1 = EntityEmbeddings(
|
entity1 = EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
entity2 = EntityEmbeddings(
|
entity2 = EntityEmbeddings(
|
||||||
entity=Term(type=LITERAL, value='literal entity'),
|
entity=Term(type=LITERAL, value='literal entity'),
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vector=[0.7, 0.8, 0.9]
|
||||||
)
|
)
|
||||||
message.entities = [entity1, entity2]
|
message.entities = [entity1, entity2]
|
||||||
|
|
||||||
|
|
@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
await processor.store_graph_embeddings(message)
|
await processor.store_graph_embeddings(message)
|
||||||
|
|
||||||
# Verify insert was called for each vector with user/collection parameters
|
# Verify insert was called once with the full vector
|
||||||
expected_calls = [
|
processor.vecstore.insert.assert_called_once()
|
||||||
([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'),
|
actual_call = processor.vecstore.insert.call_args_list[0]
|
||||||
([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'),
|
assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
]
|
assert actual_call[0][1] == 'http://example.com/entity'
|
||||||
|
assert actual_call[0][2] == 'test_user'
|
||||||
assert processor.vecstore.insert.call_count == 2
|
assert actual_call[0][3] == 'test_collection'
|
||||||
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
|
|
||||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
|
||||||
assert actual_call[0][0] == expected_vec
|
|
||||||
assert actual_call[0][1] == expected_entity
|
|
||||||
assert actual_call[0][2] == expected_user
|
|
||||||
assert actual_call[0][3] == expected_collection
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message):
|
||||||
"""Test storing graph embeddings for multiple entities"""
|
"""Test storing graph embeddings for multiple entities"""
|
||||||
await processor.store_graph_embeddings(mock_message)
|
await processor.store_graph_embeddings(mock_message)
|
||||||
|
|
||||||
# Verify insert was called for each vector of each entity with user/collection parameters
|
# Verify insert was called once per entity with user/collection parameters
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
# Entity 1 vectors
|
# Entity 1 - single vector
|
||||||
([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
||||||
([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'),
|
# Entity 2 - single vector
|
||||||
# Entity 2 vectors
|
|
||||||
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
|
([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert processor.vecstore.insert.call_count == 3
|
assert processor.vecstore.insert.call_count == 2
|
||||||
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
|
for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls):
|
||||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||||
assert actual_call[0][0] == expected_vec
|
assert actual_call[0][0] == expected_vec
|
||||||
|
|
@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Term(type=LITERAL, value=''),
|
entity=Term(type=LITERAL, value=''),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Term(type=LITERAL, value=None),
|
entity=Term(type=LITERAL, value=None),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -175,17 +168,17 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
valid_entity = EntityEmbeddings(
|
valid_entity = EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri='http://example.com/valid'),
|
entity=Term(type=IRI, iri='http://example.com/valid'),
|
||||||
vectors=[[0.1, 0.2, 0.3]],
|
vector=[0.1, 0.2, 0.3],
|
||||||
chunk_id=''
|
chunk_id=''
|
||||||
)
|
)
|
||||||
empty_entity = EntityEmbeddings(
|
empty_entity = EntityEmbeddings(
|
||||||
entity=Term(type=LITERAL, value=''),
|
entity=Term(type=LITERAL, value=''),
|
||||||
vectors=[[0.4, 0.5, 0.6]],
|
vector=[0.4, 0.5, 0.6],
|
||||||
chunk_id=''
|
chunk_id=''
|
||||||
)
|
)
|
||||||
none_entity = EntityEmbeddings(
|
none_entity = EntityEmbeddings(
|
||||||
entity=Term(type=LITERAL, value=None),
|
entity=Term(type=LITERAL, value=None),
|
||||||
vectors=[[0.7, 0.8, 0.9]],
|
vector=[0.7, 0.8, 0.9],
|
||||||
chunk_id=''
|
chunk_id=''
|
||||||
)
|
)
|
||||||
message.entities = [valid_entity, empty_entity, none_entity]
|
message.entities = [valid_entity, empty_entity, none_entity]
|
||||||
|
|
@ -222,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||||
vectors=[]
|
vector=[]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -238,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
# Each entity has a single vector of different dimensions
|
||||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
entity1 = EntityEmbeddings(
|
||||||
vectors=[
|
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
||||||
[0.1, 0.2], # 2D vector
|
vector=[0.1, 0.2] # 2D vector
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
|
||||||
[0.7, 0.8, 0.9] # 3D vector
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
entity2 = EntityEmbeddings(
|
||||||
|
entity=Term(type=IRI, iri='http://example.com/entity2'),
|
||||||
|
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||||
|
)
|
||||||
|
entity3 = EntityEmbeddings(
|
||||||
|
entity=Term(type=IRI, iri='http://example.com/entity3'),
|
||||||
|
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||||
|
)
|
||||||
|
message.entities = [entity1, entity2, entity3]
|
||||||
|
|
||||||
await processor.store_graph_embeddings(message)
|
await processor.store_graph_embeddings(message)
|
||||||
|
|
||||||
# Verify all vectors were inserted regardless of dimension
|
# Verify all vectors were inserted regardless of dimension
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
([0.1, 0.2], 'http://example.com/entity'),
|
([0.1, 0.2], 'http://example.com/entity1'),
|
||||||
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'),
|
([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'),
|
||||||
([0.7, 0.8, 0.9], 'http://example.com/entity'),
|
([0.7, 0.8, 0.9], 'http://example.com/entity3'),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert processor.vecstore.insert.call_count == 3
|
assert processor.vecstore.insert.call_count == 3
|
||||||
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
for i, (expected_vec, expected_entity) in enumerate(expected_calls):
|
||||||
actual_call = processor.vecstore.insert.call_args_list[i]
|
actual_call = processor.vecstore.insert.call_args_list[i]
|
||||||
|
|
@ -274,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
uri_entity = EntityEmbeddings(
|
uri_entity = EntityEmbeddings(
|
||||||
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
|
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
literal_entity = EntityEmbeddings(
|
literal_entity = EntityEmbeddings(
|
||||||
entity=Term(type=LITERAL, value='literal entity text'),
|
entity=Term(type=LITERAL, value='literal entity text'),
|
||||||
vectors=[[0.4, 0.5, 0.6]]
|
vector=[0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
message.entities = [uri_entity, literal_entity]
|
message.entities = [uri_entity, literal_entity]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
# Create test entity embeddings
|
# Create test entity embeddings (each entity has a single vector)
|
||||||
entity1 = EntityEmbeddings(
|
entity1 = EntityEmbeddings(
|
||||||
entity=Value(value="http://example.org/entity1", is_uri=True),
|
entity=Value(value="http://example.org/entity1", is_uri=True),
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
entity2 = EntityEmbeddings(
|
entity2 = EntityEmbeddings(
|
||||||
entity=Value(value="entity2", is_uri=False),
|
entity=Value(value="http://example.org/entity2", is_uri=True),
|
||||||
vectors=[[0.7, 0.8, 0.9]]
|
vector=[0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
message.entities = [entity1, entity2]
|
entity3 = EntityEmbeddings(
|
||||||
|
entity=Value(value="entity3", is_uri=False),
|
||||||
|
vector=[0.7, 0.8, 0.9]
|
||||||
|
)
|
||||||
|
message.entities = [entity1, entity2, entity3]
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value="http://example.org/entity1", is_uri=True),
|
entity=Value(value="http://example.org/entity1", is_uri=True),
|
||||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
# Mock index operations
|
# Mock index operations
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
processor.pinecone.has_index.return_value = True
|
processor.pinecone.has_index.return_value = True
|
||||||
|
|
||||||
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
|
with patch('uuid.uuid4', side_effect=['id1']):
|
||||||
await processor.store_graph_embeddings(message)
|
await processor.store_graph_embeddings(message)
|
||||||
|
|
||||||
# Verify index name and operations (with dimension suffix)
|
# Verify index name and operations (with dimension suffix)
|
||||||
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
|
||||||
processor.pinecone.Index.assert_called_with(expected_index_name)
|
processor.pinecone.Index.assert_called_with(expected_index_name)
|
||||||
|
|
||||||
# Verify upsert was called for each vector
|
# Verify upsert was called for the single vector
|
||||||
assert mock_index.upsert.call_count == 2
|
assert mock_index.upsert.call_count == 1
|
||||||
|
|
||||||
# Check first vector upsert
|
# Check first vector upsert
|
||||||
first_call = mock_index.upsert.call_args_list[0]
|
first_call = mock_index.upsert.call_args_list[0]
|
||||||
|
|
@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value="test_entity", is_uri=False),
|
entity=Value(value="test_entity", is_uri=False),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value="", is_uri=False),
|
entity=Value(value="", is_uri=False),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value=None, is_uri=False),
|
entity=Value(value=None, is_uri=False),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
|
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
|
||||||
"""Test storing graph embeddings with different vector dimensions to same index"""
|
"""Test storing graph embeddings with different vector dimensions"""
|
||||||
message = MagicMock()
|
message = MagicMock()
|
||||||
message.metadata = MagicMock()
|
message.metadata = MagicMock()
|
||||||
message.metadata.user = 'test_user'
|
message.metadata.user = 'test_user'
|
||||||
message.metadata.collection = 'test_collection'
|
message.metadata.collection = 'test_collection'
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
# Each entity has a single vector of different dimensions
|
||||||
entity=Value(value="test_entity", is_uri=False),
|
entity1 = EntityEmbeddings(
|
||||||
vectors=[
|
entity=Value(value="entity1", is_uri=False),
|
||||||
[0.1, 0.2], # 2D vector
|
vector=[0.1, 0.2] # 2D vector
|
||||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
|
||||||
[0.7, 0.8, 0.9] # 3D vector
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
entity2 = EntityEmbeddings(
|
||||||
|
entity=Value(value="entity2", is_uri=False),
|
||||||
|
vector=[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||||
|
)
|
||||||
|
entity3 = EntityEmbeddings(
|
||||||
|
entity=Value(value="entity3", is_uri=False),
|
||||||
|
vector=[0.7, 0.8, 0.9] # 3D vector
|
||||||
|
)
|
||||||
|
message.entities = [entity1, entity2, entity3]
|
||||||
|
|
||||||
# All vectors now use the same index (no dimension in name)
|
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
processor.pinecone.Index.return_value = mock_index
|
processor.pinecone.Index.return_value = mock_index
|
||||||
processor.pinecone.has_index.return_value = True
|
processor.pinecone.has_index.return_value = True
|
||||||
|
|
@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value="test_entity", is_uri=False),
|
entity=Value(value="test_entity", is_uri=False),
|
||||||
vectors=[]
|
vector=[]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value="test_entity", is_uri=False),
|
entity=Value(value="test_entity", is_uri=False),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
||||||
|
|
||||||
entity = EntityEmbeddings(
|
entity = EntityEmbeddings(
|
||||||
entity=Value(value="test_entity", is_uri=False),
|
entity=Value(value="test_entity", is_uri=False),
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
message.entities = [entity]
|
message.entities = [entity]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_entity = MagicMock()
|
mock_entity = MagicMock()
|
||||||
mock_entity.entity.type = IRI
|
mock_entity.entity.type = IRI
|
||||||
mock_entity.entity.iri = 'test_entity'
|
mock_entity.entity.iri = 'test_entity'
|
||||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions
|
||||||
|
|
||||||
mock_message.entities = [mock_entity]
|
mock_message.entities = [mock_entity]
|
||||||
|
|
||||||
|
|
@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_entity1 = MagicMock()
|
mock_entity1 = MagicMock()
|
||||||
mock_entity1.entity.type = IRI
|
mock_entity1.entity.type = IRI
|
||||||
mock_entity1.entity.iri = 'entity_one'
|
mock_entity1.entity.iri = 'entity_one'
|
||||||
mock_entity1.vectors = [[0.1, 0.2]]
|
mock_entity1.vector = [0.1, 0.2]
|
||||||
|
|
||||||
mock_entity2 = MagicMock()
|
mock_entity2 = MagicMock()
|
||||||
mock_entity2.entity.type = IRI
|
mock_entity2.entity.type = IRI
|
||||||
mock_entity2.entity.iri = 'entity_two'
|
mock_entity2.entity.iri = 'entity_two'
|
||||||
mock_entity2.vectors = [[0.3, 0.4]]
|
mock_entity2.vector = [0.3, 0.4]
|
||||||
|
|
||||||
mock_message.entities = [mock_entity1, mock_entity2]
|
mock_message.entities = [mock_entity1, mock_entity2]
|
||||||
|
|
||||||
|
|
@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||||
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
|
async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client):
|
||||||
"""Test storing graph embeddings with multiple vectors per entity"""
|
"""Test storing graph embeddings with three entities"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_qdrant_instance = MagicMock()
|
mock_qdrant_instance = MagicMock()
|
||||||
mock_qdrant_instance.collection_exists.return_value = True
|
mock_qdrant_instance.collection_exists.return_value = True
|
||||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'store_uri': 'http://localhost:6333',
|
'store_uri': 'http://localhost:6333',
|
||||||
'api_key': 'test-api-key',
|
'api_key': 'test-api-key',
|
||||||
|
|
@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
# Add collection to known_collections (simulates config push)
|
# Add collection to known_collections (simulates config push)
|
||||||
processor.known_collections[('vector_user', 'vector_collection')] = {}
|
processor.known_collections[('vector_user', 'vector_collection')] = {}
|
||||||
|
|
||||||
# Create mock message with entity having multiple vectors
|
# Create mock message with three entities
|
||||||
mock_message = MagicMock()
|
mock_message = MagicMock()
|
||||||
mock_message.metadata.user = 'vector_user'
|
mock_message.metadata.user = 'vector_user'
|
||||||
mock_message.metadata.collection = 'vector_collection'
|
mock_message.metadata.collection = 'vector_collection'
|
||||||
|
|
||||||
mock_entity = MagicMock()
|
mock_entity1 = MagicMock()
|
||||||
mock_entity.entity.type = IRI
|
mock_entity1.entity.type = IRI
|
||||||
mock_entity.entity.iri = 'multi_vector_entity'
|
mock_entity1.entity.iri = 'entity_one'
|
||||||
mock_entity.vectors = [
|
mock_entity1.vector = [0.1, 0.2, 0.3]
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
[0.4, 0.5, 0.6],
|
mock_entity2 = MagicMock()
|
||||||
[0.7, 0.8, 0.9]
|
mock_entity2.entity.type = IRI
|
||||||
]
|
mock_entity2.entity.iri = 'entity_two'
|
||||||
|
mock_entity2.vector = [0.4, 0.5, 0.6]
|
||||||
mock_message.entities = [mock_entity]
|
|
||||||
|
mock_entity3 = MagicMock()
|
||||||
|
mock_entity3.entity.type = IRI
|
||||||
|
mock_entity3.entity.iri = 'entity_three'
|
||||||
|
mock_entity3.vector = [0.7, 0.8, 0.9]
|
||||||
|
|
||||||
|
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await processor.store_graph_embeddings(mock_message)
|
await processor.store_graph_embeddings(mock_message)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# Should be called 3 times (once per vector)
|
# Should be called 3 times (once per entity)
|
||||||
assert mock_qdrant_instance.upsert.call_count == 3
|
assert mock_qdrant_instance.upsert.call_count == 3
|
||||||
|
|
||||||
# Verify all vectors were processed
|
# Verify all entities were processed
|
||||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||||
|
|
||||||
expected_vectors = [
|
expected_data = [
|
||||||
[0.1, 0.2, 0.3],
|
([0.1, 0.2, 0.3], 'entity_one'),
|
||||||
[0.4, 0.5, 0.6],
|
([0.4, 0.5, 0.6], 'entity_two'),
|
||||||
[0.7, 0.8, 0.9]
|
([0.7, 0.8, 0.9], 'entity_three')
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, call in enumerate(upsert_calls):
|
for i, call in enumerate(upsert_calls):
|
||||||
point = call[1]['points'][0]
|
point = call[1]['points'][0]
|
||||||
assert point.vector == expected_vectors[i]
|
assert point.vector == expected_data[i][0]
|
||||||
assert point.payload['entity'] == 'multi_vector_entity'
|
assert point.payload['entity'] == expected_data[i][1]
|
||||||
|
|
||||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||||
async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
|
async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
|
||||||
|
|
@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
mock_entity_empty = MagicMock()
|
mock_entity_empty = MagicMock()
|
||||||
mock_entity_empty.entity.type = LITERAL
|
mock_entity_empty.entity.type = LITERAL
|
||||||
mock_entity_empty.entity.value = "" # Empty string
|
mock_entity_empty.entity.value = "" # Empty string
|
||||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
mock_entity_empty.vector = [0.1, 0.2]
|
||||||
|
|
||||||
mock_entity_none = MagicMock()
|
mock_entity_none = MagicMock()
|
||||||
mock_entity_none.entity = None # None entity
|
mock_entity_none.entity = None # None entity
|
||||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
mock_entity_none.vector = [0.3, 0.4]
|
||||||
|
|
||||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
index_name='customer_id',
|
index_name='customer_id',
|
||||||
index_value=['CUST001'],
|
index_value=['CUST001'],
|
||||||
text='CUST001',
|
text='CUST001',
|
||||||
vectors=[[0.1, 0.2, 0.3]]
|
vector=[0.1, 0.2, 0.3]
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_msg = RowEmbeddings(
|
embeddings_msg = RowEmbeddings(
|
||||||
|
|
@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
|
||||||
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
|
async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client):
|
||||||
"""Test processing embeddings with multiple vectors"""
|
"""Test processing embeddings with a single vector"""
|
||||||
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||||
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||||
|
|
||||||
|
|
@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
metadata.collection = 'test_collection'
|
metadata.collection = 'test_collection'
|
||||||
metadata.id = 'doc-123'
|
metadata.id = 'doc-123'
|
||||||
|
|
||||||
# Embedding with multiple vectors
|
# Embedding with a single 6D vector
|
||||||
embedding = RowIndexEmbedding(
|
embedding = RowIndexEmbedding(
|
||||||
index_name='name',
|
index_name='name',
|
||||||
index_value=['John Doe'],
|
index_value=['John Doe'],
|
||||||
text='John Doe',
|
text='John Doe',
|
||||||
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_msg = RowEmbeddings(
|
embeddings_msg = RowEmbeddings(
|
||||||
|
|
@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||||
|
|
||||||
# Should be called 3 times (once per vector)
|
# Should be called once for the single embedding
|
||||||
assert mock_qdrant_instance.upsert.call_count == 3
|
assert mock_qdrant_instance.upsert.call_count == 1
|
||||||
|
|
||||||
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
|
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
|
||||||
|
|
@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
index_name='id',
|
index_name='id',
|
||||||
index_value=['123'],
|
index_value=['123'],
|
||||||
text='123',
|
text='123',
|
||||||
vectors=[] # Empty vectors
|
vector=[] # Empty vector
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_msg = RowEmbeddings(
|
embeddings_msg = RowEmbeddings(
|
||||||
|
|
@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||||
index_name='id',
|
index_name='id',
|
||||||
index_value=['123'],
|
index_value=['123'],
|
||||||
text='123',
|
text='123',
|
||||||
vectors=[[0.1, 0.2]]
|
vector=[0.1, 0.2]
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_msg = RowEmbeddings(
|
embeddings_msg = RowEmbeddings(
|
||||||
|
|
|
||||||
|
|
@ -612,12 +612,12 @@ class AsyncFlowInstance:
|
||||||
print(f"{entity['name']}: {entity['score']}")
|
print(f"{entity['name']}: {entity['score']}")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -810,12 +810,12 @@ class AsyncFlowInstance:
|
||||||
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
|
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
|
||||||
|
|
||||||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||||
"""Query graph embeddings for semantic search"""
|
"""Query graph embeddings for semantic search"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
|
||||||
limit: int = 10, **kwargs
|
limit: int = 10, **kwargs
|
||||||
):
|
):
|
||||||
"""Query row embeddings for semantic search on structured data"""
|
"""Query row embeddings for semantic search on structured data"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -602,13 +602,13 @@ class FlowInstance:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
# Query graph embeddings for semantic search
|
# Query graph embeddings for semantic search
|
||||||
input = {
|
input = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -648,13 +648,13 @@ class FlowInstance:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
# Query document embeddings for semantic search
|
# Query document embeddings for semantic search
|
||||||
input = {
|
input = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -1362,13 +1362,13 @@ class FlowInstance:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
# Query row embeddings for semantic search
|
# Query row embeddings for semantic search
|
||||||
input = {
|
input = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -649,12 +649,12 @@ class SocketFlowInstance:
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -698,12 +698,12 @@ class SocketFlowInstance:
|
||||||
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
|
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -936,12 +936,12 @@ class SocketFlowInstance:
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DocumentEmbeddingsClient(RequestResponse):
|
class DocumentEmbeddingsClient(RequestResponse):
|
||||||
async def query(self, vectors, limit=20, user="trustgraph",
|
async def query(self, vector, limit=20, user="trustgraph",
|
||||||
collection="default", timeout=30):
|
collection="default", timeout=30):
|
||||||
|
|
||||||
resp = await self.request(
|
resp = await self.request(
|
||||||
DocumentEmbeddingsRequest(
|
DocumentEmbeddingsRequest(
|
||||||
vectors = vectors,
|
vector = vector,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
user = user,
|
user = user,
|
||||||
collection = collection
|
collection = collection
|
||||||
|
|
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
|
||||||
if resp.error:
|
if resp.error:
|
||||||
raise RuntimeError(resp.error.message)
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
return resp.chunk_ids
|
# Return ChunkMatch objects with chunk_id and score
|
||||||
|
return resp.chunks
|
||||||
|
|
||||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||||
docs = await self.query_document_embeddings(request)
|
docs = await self.query_document_embeddings(request)
|
||||||
|
|
||||||
logger.debug("Sending document embeddings query response...")
|
logger.debug("Sending document embeddings query response...")
|
||||||
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
|
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
||||||
logger.debug("Document embeddings query request completed")
|
logger.debug("Document embeddings query request completed")
|
||||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||||
type = "document-embeddings-query-error",
|
type = "document-embeddings-query-error",
|
||||||
message = str(e),
|
message = str(e),
|
||||||
),
|
),
|
||||||
chunk_ids=[],
|
chunks=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,12 @@ def to_value(x):
|
||||||
return Literal(x.value or x.iri)
|
return Literal(x.value or x.iri)
|
||||||
|
|
||||||
class GraphEmbeddingsClient(RequestResponse):
|
class GraphEmbeddingsClient(RequestResponse):
|
||||||
async def query(self, vectors, limit=20, user="trustgraph",
|
async def query(self, vector, limit=20, user="trustgraph",
|
||||||
collection="default", timeout=30):
|
collection="default", timeout=30):
|
||||||
|
|
||||||
resp = await self.request(
|
resp = await self.request(
|
||||||
GraphEmbeddingsRequest(
|
GraphEmbeddingsRequest(
|
||||||
vectors = vectors,
|
vector = vector,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
user = user,
|
user = user,
|
||||||
collection = collection
|
collection = collection
|
||||||
|
|
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
|
||||||
if resp.error:
|
if resp.error:
|
||||||
raise RuntimeError(resp.error.message)
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
return [
|
# Return EntityMatch objects with entity and score
|
||||||
to_value(v)
|
return resp.entities
|
||||||
for v in resp.entities
|
|
||||||
]
|
|
||||||
|
|
||||||
class GraphEmbeddingsClientSpec(RequestResponseSpec):
|
class GraphEmbeddingsClientSpec(RequestResponseSpec):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||||
|
|
||||||
class RowEmbeddingsQueryClient(RequestResponse):
|
class RowEmbeddingsQueryClient(RequestResponse):
|
||||||
async def row_embeddings_query(
|
async def row_embeddings_query(
|
||||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
self, vector, schema_name, user="trustgraph", collection="default",
|
||||||
index_name=None, limit=10, timeout=600
|
index_name=None, limit=10, timeout=600
|
||||||
):
|
):
|
||||||
request = RowEmbeddingsRequest(
|
request = RowEmbeddingsRequest(
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
schema_name=schema_name,
|
schema_name=schema_name,
|
||||||
user=user,
|
user=user,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self, vectors, user="trustgraph", collection="default",
|
self, vector, user="trustgraph", collection="default",
|
||||||
limit=10, timeout=300
|
limit=10, timeout=300
|
||||||
):
|
):
|
||||||
return self.call(
|
return self.call(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, limit=limit, timeout=timeout
|
vector=vector, limit=limit, timeout=timeout
|
||||||
).chunks
|
).chunks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self, vectors, user="trustgraph", collection="default",
|
self, vector, user="trustgraph", collection="default",
|
||||||
limit=10, timeout=300
|
limit=10, timeout=300
|
||||||
):
|
):
|
||||||
return self.call(
|
return self.call(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, limit=limit, timeout=timeout
|
vector=vector, limit=limit, timeout=timeout
|
||||||
).entities
|
).entities
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
self, vector, schema_name, user="trustgraph", collection="default",
|
||||||
index_name=None, limit=10, timeout=300
|
index_name=None, limit=10, timeout=300
|
||||||
):
|
):
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, schema_name=schema_name,
|
vector=vector, schema_name=schema_name,
|
||||||
limit=limit, timeout=timeout
|
limit=limit, timeout=timeout
|
||||||
)
|
)
|
||||||
if index_name:
|
if index_name:
|
||||||
|
|
|
||||||
|
|
@ -10,18 +10,18 @@ from .primitives import ValueTranslator
|
||||||
|
|
||||||
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
||||||
return DocumentEmbeddingsRequest(
|
return DocumentEmbeddingsRequest(
|
||||||
vectors=data["vectors"],
|
vector=data["vector"],
|
||||||
limit=int(data.get("limit", 10)),
|
limit=int(data.get("limit", 10)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default")
|
collection=data.get("collection", "default")
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"vectors": obj.vectors,
|
"vector": obj.vector,
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection
|
"collection": obj.collection
|
||||||
|
|
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
||||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if obj.chunk_ids is not None:
|
if obj.chunks is not None:
|
||||||
result["chunk_ids"] = list(obj.chunk_ids)
|
result["chunks"] = [
|
||||||
|
{
|
||||||
|
"chunk_id": chunk.chunk_id,
|
||||||
|
"score": chunk.score
|
||||||
|
}
|
||||||
|
for chunk in obj.chunks
|
||||||
|
]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
"""Returns (response_dict, is_final)"""
|
"""Returns (response_dict, is_final)"""
|
||||||
return self.from_pulsar(obj), True
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
|
|
||||||
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
"""Translator for GraphEmbeddingsRequest schema objects"""
|
"""Translator for GraphEmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
||||||
return GraphEmbeddingsRequest(
|
return GraphEmbeddingsRequest(
|
||||||
vectors=data["vectors"],
|
vector=data["vector"],
|
||||||
limit=int(data.get("limit", 10)),
|
limit=int(data.get("limit", 10)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default")
|
collection=data.get("collection", "default")
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"vectors": obj.vectors,
|
"vector": obj.vector,
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection
|
"collection": obj.collection
|
||||||
|
|
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
"""Translator for GraphEmbeddingsResponse schema objects"""
|
"""Translator for GraphEmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.value_translator = ValueTranslator()
|
self.value_translator = ValueTranslator()
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
||||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if obj.entities is not None:
|
if obj.entities is not None:
|
||||||
result["entities"] = [
|
result["entities"] = [
|
||||||
self.value_translator.from_pulsar(entity)
|
{
|
||||||
for entity in obj.entities
|
"entity": self.value_translator.from_pulsar(match.entity),
|
||||||
|
"score": match.score
|
||||||
|
}
|
||||||
|
for match in obj.entities
|
||||||
]
|
]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
"""Returns (response_dict, is_final)"""
|
"""Returns (response_dict, is_final)"""
|
||||||
return self.from_pulsar(obj), True
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
||||||
return RowEmbeddingsRequest(
|
return RowEmbeddingsRequest(
|
||||||
vectors=data["vectors"],
|
vector=data["vector"],
|
||||||
limit=int(data.get("limit", 10)),
|
limit=int(data.get("limit", 10)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default"),
|
collection=data.get("collection", "default"),
|
||||||
|
|
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
result = {
|
result = {
|
||||||
"vectors": obj.vectors,
|
"vector": obj.vector,
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection,
|
"collection": obj.collection,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from ..core.topic import topic
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityEmbeddings:
|
class EntityEmbeddings:
|
||||||
entity: Term | None = None
|
entity: Term | None = None
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
# Provenance: which chunk this embedding was derived from
|
# Provenance: which chunk this embedding was derived from
|
||||||
chunk_id: str = ""
|
chunk_id: str = ""
|
||||||
|
|
||||||
|
|
@ -28,7 +28,7 @@ class GraphEmbeddings:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChunkEmbeddings:
|
class ChunkEmbeddings:
|
||||||
chunk_id: str = ""
|
chunk_id: str = ""
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
|
|
||||||
# This is a 'batching' mechanism for the above data
|
# This is a 'batching' mechanism for the above data
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -44,7 +44,7 @@ class DocumentEmbeddings:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ObjectEmbeddings:
|
class ObjectEmbeddings:
|
||||||
metadata: Metadata | None = None
|
metadata: Metadata | None = None
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
name: str = ""
|
name: str = ""
|
||||||
key_name: str = ""
|
key_name: str = ""
|
||||||
id: str = ""
|
id: str = ""
|
||||||
|
|
@ -56,7 +56,7 @@ class ObjectEmbeddings:
|
||||||
@dataclass
|
@dataclass
|
||||||
class StructuredObjectEmbedding:
|
class StructuredObjectEmbedding:
|
||||||
metadata: Metadata | None = None
|
metadata: Metadata | None = None
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
schema_name: str = ""
|
schema_name: str = ""
|
||||||
object_id: str = "" # Primary key value
|
object_id: str = "" # Primary key value
|
||||||
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
|
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
|
||||||
|
|
@ -72,7 +72,7 @@ class RowIndexEmbedding:
|
||||||
index_name: str = "" # The indexed field name(s)
|
index_name: str = "" # The indexed field name(s)
|
||||||
index_value: list[str] = field(default_factory=list) # The field value(s)
|
index_value: list[str] = field(default_factory=list) # The field value(s)
|
||||||
text: str = "" # Text that was embedded
|
text: str = "" # Text that was embedded
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RowEmbeddings:
|
class RowEmbeddings:
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class EmbeddingsRequest:
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingsResponse:
|
class EmbeddingsResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
vectors: list[list[list[float]]] = field(default_factory=list)
|
vectors: list[list[float]] = field(default_factory=list)
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,15 +9,21 @@ from ..core.topic import topic
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphEmbeddingsRequest:
|
class GraphEmbeddingsRequest:
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
limit: int = 0
|
limit: int = 0
|
||||||
user: str = ""
|
user: str = ""
|
||||||
collection: str = ""
|
collection: str = ""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EntityMatch:
|
||||||
|
"""A matching entity from a semantic search with similarity score"""
|
||||||
|
entity: Term | None = None
|
||||||
|
score: float = 0.0
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphEmbeddingsResponse:
|
class GraphEmbeddingsResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
entities: list[Term] = field(default_factory=list)
|
entities: list[EntityMatch] = field(default_factory=list)
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
@ -44,15 +50,21 @@ class TriplesQueryResponse:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DocumentEmbeddingsRequest:
|
class DocumentEmbeddingsRequest:
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
limit: int = 0
|
limit: int = 0
|
||||||
user: str = ""
|
user: str = ""
|
||||||
collection: str = ""
|
collection: str = ""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkMatch:
|
||||||
|
"""A matching chunk from a semantic search with similarity score"""
|
||||||
|
chunk_id: str = ""
|
||||||
|
score: float = 0.0
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DocumentEmbeddingsResponse:
|
class DocumentEmbeddingsResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
chunk_ids: list[str] = field(default_factory=list)
|
chunks: list[ChunkMatch] = field(default_factory=list)
|
||||||
|
|
||||||
document_embeddings_request_queue = topic(
|
document_embeddings_request_queue = topic(
|
||||||
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
|
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
|
||||||
|
|
@ -76,7 +88,7 @@ class RowIndexMatch:
|
||||||
@dataclass
|
@dataclass
|
||||||
class RowEmbeddingsRequest:
|
class RowEmbeddingsRequest:
|
||||||
"""Request for row embeddings semantic search"""
|
"""Request for row embeddings semantic search"""
|
||||||
vectors: list[list[float]] = field(default_factory=list) # Query vectors
|
vector: list[float] = field(default_factory=list) # Query vector
|
||||||
limit: int = 10 # Max results to return
|
limit: int = 10 # Max results to return
|
||||||
user: str = "" # User/keyspace
|
user: str = "" # User/keyspace
|
||||||
collection: str = "" # Collection name
|
collection: str = "" # Collection name
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl:
|
||||||
|
|
||||||
query_text = arguments.get("query")
|
query_text = arguments.get("query")
|
||||||
all_vectors = await embeddings_client.embed([query_text])
|
all_vectors = await embeddings_client.embed([query_text])
|
||||||
vectors = all_vectors[0] if all_vectors else []
|
vector = all_vectors[0] if all_vectors else []
|
||||||
|
|
||||||
# Now query row embeddings
|
# Now query row embeddings
|
||||||
client = self.context("row-embeddings-query-request")
|
client = self.context("row-embeddings-query-request")
|
||||||
|
|
@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl:
|
||||||
user = getattr(client, '_current_user', self.user or "trustgraph")
|
user = getattr(client, '_current_user', self.user or "trustgraph")
|
||||||
|
|
||||||
matches = await client.row_embeddings_query(
|
matches = await client.row_embeddings_query(
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
schema_name=self.schema_name,
|
schema_name=self.schema_name,
|
||||||
user=user,
|
user=user,
|
||||||
collection=self.collection or "default",
|
collection=self.collection or "default",
|
||||||
|
|
|
||||||
|
|
@ -66,13 +66,13 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# vectors[0] is the vector set for the first (only) text
|
# vectors[0] is the vector for the first (only) text
|
||||||
vectors = resp.vectors[0] if resp.vectors else []
|
vector = resp.vectors[0] if resp.vectors else []
|
||||||
|
|
||||||
embeds = [
|
embeds = [
|
||||||
ChunkEmbeddings(
|
ChunkEmbeddings(
|
||||||
chunk_id=v.document_id,
|
chunk_id=v.document_id,
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,11 +59,8 @@ class Processor(EmbeddingsService):
|
||||||
# FastEmbed processes the full batch efficiently
|
# FastEmbed processes the full batch efficiently
|
||||||
vecs = list(self.embeddings.embed(texts))
|
vecs = list(self.embeddings.embed(texts))
|
||||||
|
|
||||||
# Return list of vector sets, one per input text
|
# Return list of vectors, one per input text
|
||||||
return [
|
return [v.tolist() for v in vecs]
|
||||||
[v.tolist()]
|
|
||||||
for v in vecs
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -72,10 +72,10 @@ class Processor(FlowProcessor):
|
||||||
entities = [
|
entities = [
|
||||||
EntityEmbeddings(
|
EntityEmbeddings(
|
||||||
entity=entity.entity,
|
entity=entity.entity,
|
||||||
vectors=vectors, # Vector set for this entity
|
vector=vector,
|
||||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||||
)
|
)
|
||||||
for entity, vectors in zip(v.entities, all_vectors)
|
for entity, vector in zip(v.entities, all_vectors)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Send in batches to avoid oversized messages
|
# Send in batches to avoid oversized messages
|
||||||
|
|
|
||||||
|
|
@ -43,11 +43,8 @@ class Processor(EmbeddingsService):
|
||||||
input = texts
|
input = texts
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return list of vector sets, one per input text
|
# Return list of vectors, one per input text
|
||||||
return [
|
return list(embeds.embeddings)
|
||||||
[embedding]
|
|
||||||
for embedding in embeds.embeddings
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
||||||
|
|
||||||
# Pair results with metadata
|
# Pair results with metadata
|
||||||
for text, (index_name, index_value), vectors in zip(
|
for text, (index_name, index_value), vector in zip(
|
||||||
texts, metadata, all_vectors
|
texts, metadata, all_vectors
|
||||||
):
|
):
|
||||||
embeddings_list.append(
|
embeddings_list.append(
|
||||||
|
|
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
index_value=index_value,
|
index_value=index_value,
|
||||||
text=text,
|
text=text,
|
||||||
vectors=vectors # Vector set for this text
|
vector=vector
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ of chunk_ids
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .... direct.milvus_doc_embeddings import DocVectors
|
from .... direct.milvus_doc_embeddings import DocVectors
|
||||||
from .... schema import DocumentEmbeddingsResponse
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
chunk_ids = []
|
resp = self.vecstore.search(
|
||||||
|
vec,
|
||||||
|
msg.user,
|
||||||
|
msg.collection,
|
||||||
|
limit=msg.limit
|
||||||
|
)
|
||||||
|
|
||||||
for vec in msg.vectors:
|
chunks = []
|
||||||
|
for r in resp:
|
||||||
|
chunk_id = r["entity"]["chunk_id"]
|
||||||
|
# Milvus returns distance, convert to similarity score
|
||||||
|
distance = r.get("distance", 0.0)
|
||||||
|
score = 1.0 - distance if distance else 0.0
|
||||||
|
chunks.append(ChunkMatch(
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
resp = self.vecstore.search(
|
return chunks
|
||||||
vec,
|
|
||||||
msg.user,
|
|
||||||
msg.collection,
|
|
||||||
limit=msg.limit
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in resp:
|
|
||||||
chunk_id = r["entity"]["chunk_id"]
|
|
||||||
chunk_ids.append(chunk_id)
|
|
||||||
|
|
||||||
return chunk_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import os
|
||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||||
|
|
||||||
|
from .... schema import ChunkMatch
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
|
|
@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
chunk_ids = []
|
dim = len(vec)
|
||||||
|
|
||||||
for vec in msg.vectors:
|
# Use dimension suffix in index name
|
||||||
|
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
|
||||||
|
|
||||||
dim = len(vec)
|
# Check if index exists - return empty if not
|
||||||
|
if not self.pinecone.has_index(index_name):
|
||||||
|
logger.info(f"Index {index_name} does not exist")
|
||||||
|
return []
|
||||||
|
|
||||||
# Use dimension suffix in index name
|
index = self.pinecone.Index(index_name)
|
||||||
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
|
|
||||||
|
|
||||||
# Check if index exists - skip if not
|
results = index.query(
|
||||||
if not self.pinecone.has_index(index_name):
|
vector=vec,
|
||||||
logger.info(f"Index {index_name} does not exist, skipping this vector")
|
top_k=msg.limit,
|
||||||
continue
|
include_values=False,
|
||||||
|
include_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
chunks = []
|
||||||
|
for r in results.matches:
|
||||||
|
chunk_id = r.metadata["chunk_id"]
|
||||||
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
|
chunks.append(ChunkMatch(
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
results = index.query(
|
return chunks
|
||||||
vector=vec,
|
|
||||||
top_k=msg.limit,
|
|
||||||
include_values=False,
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in results.matches:
|
|
||||||
chunk_id = r.metadata["chunk_id"]
|
|
||||||
chunk_ids.append(chunk_id)
|
|
||||||
|
|
||||||
return chunk_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
|
|
||||||
from .... schema import DocumentEmbeddingsResponse
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
chunk_ids = []
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
# Use dimension suffix in collection name
|
||||||
|
dim = len(vec)
|
||||||
|
collection = f"d_{msg.user}_{msg.collection}_{dim}"
|
||||||
|
|
||||||
# Use dimension suffix in collection name
|
# Check if collection exists - return empty if not
|
||||||
dim = len(vec)
|
if not self.collection_exists(collection):
|
||||||
collection = f"d_{msg.user}_{msg.collection}_{dim}"
|
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||||
|
return []
|
||||||
|
|
||||||
# Check if collection exists - return empty if not
|
search_result = self.qdrant.query_points(
|
||||||
if not self.collection_exists(collection):
|
collection_name=collection,
|
||||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
query=vec,
|
||||||
continue
|
limit=msg.limit,
|
||||||
|
with_payload=True,
|
||||||
|
).points
|
||||||
|
|
||||||
search_result = self.qdrant.query_points(
|
chunks = []
|
||||||
collection_name=collection,
|
for r in search_result:
|
||||||
query=vec,
|
chunk_id = r.payload["chunk_id"]
|
||||||
limit=msg.limit,
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
with_payload=True,
|
chunks.append(ChunkMatch(
|
||||||
).points
|
chunk_id=chunk_id,
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
for r in search_result:
|
return chunks
|
||||||
chunk_id = r.payload["chunk_id"]
|
|
||||||
chunk_ids.append(chunk_id)
|
|
||||||
|
|
||||||
return chunk_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ entities
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||||
from .... schema import GraphEmbeddingsResponse
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
entity_set = set()
|
vec = msg.vector
|
||||||
entities = []
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
resp = self.vecstore.search(
|
||||||
|
vec,
|
||||||
|
msg.user,
|
||||||
|
msg.collection,
|
||||||
|
limit=msg.limit * 2
|
||||||
|
)
|
||||||
|
|
||||||
resp = self.vecstore.search(
|
entity_set = set()
|
||||||
vec,
|
entities = []
|
||||||
msg.user,
|
|
||||||
msg.collection,
|
|
||||||
limit=msg.limit * 2
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in resp:
|
for r in resp:
|
||||||
ent = r["entity"]["entity"]
|
ent = r["entity"]["entity"]
|
||||||
|
# Milvus returns distance, convert to similarity score
|
||||||
# De-dupe entities
|
distance = r.get("distance", 0.0)
|
||||||
if ent not in entity_set:
|
score = 1.0 - distance if distance else 0.0
|
||||||
entity_set.add(ent)
|
|
||||||
entities.append(ent)
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# De-dupe entities, keep highest score
|
||||||
if len(entity_set) >= msg.limit: break
|
if ent not in entity_set:
|
||||||
|
entity_set.add(ent)
|
||||||
|
entities.append(EntityMatch(
|
||||||
|
entity=self.create_value(ent),
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# Keep adding entities until limit
|
||||||
if len(entity_set) >= msg.limit: break
|
if len(entities) >= msg.limit:
|
||||||
|
break
|
||||||
ents2 = []
|
|
||||||
|
|
||||||
for ent in entities:
|
|
||||||
ents2.append(self.create_value(ent))
|
|
||||||
|
|
||||||
entities = ents2
|
|
||||||
|
|
||||||
logger.debug("Send response...")
|
logger.debug("Send response...")
|
||||||
return entities
|
return entities
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import os
|
||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||||
|
|
||||||
from .... schema import GraphEmbeddingsResponse
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
dim = len(vec)
|
||||||
|
|
||||||
|
# Use dimension suffix in index name
|
||||||
|
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
|
||||||
|
|
||||||
|
# Check if index exists - return empty if not
|
||||||
|
if not self.pinecone.has_index(index_name):
|
||||||
|
logger.info(f"Index {index_name} does not exist")
|
||||||
|
return []
|
||||||
|
|
||||||
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
|
# of getting (limit) unique entities
|
||||||
|
results = index.query(
|
||||||
|
vector=vec,
|
||||||
|
top_k=msg.limit * 2,
|
||||||
|
include_values=False,
|
||||||
|
include_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
entity_set = set()
|
entity_set = set()
|
||||||
entities = []
|
entities = []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
for r in results.matches:
|
||||||
|
ent = r.metadata["entity"]
|
||||||
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
|
|
||||||
dim = len(vec)
|
# De-dupe entities, keep highest score
|
||||||
|
if ent not in entity_set:
|
||||||
# Use dimension suffix in index name
|
entity_set.add(ent)
|
||||||
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
|
entities.append(EntityMatch(
|
||||||
|
entity=self.create_value(ent),
|
||||||
# Check if index exists - skip if not
|
score=score,
|
||||||
if not self.pinecone.has_index(index_name):
|
))
|
||||||
logger.info(f"Index {index_name} does not exist, skipping this vector")
|
|
||||||
continue
|
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
|
||||||
# of getting (limit) entities
|
|
||||||
results = index.query(
|
|
||||||
vector=vec,
|
|
||||||
top_k=msg.limit * 2,
|
|
||||||
include_values=False,
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in results.matches:
|
|
||||||
|
|
||||||
ent = r.metadata["entity"]
|
|
||||||
|
|
||||||
# De-dupe entities
|
|
||||||
if ent not in entity_set:
|
|
||||||
entity_set.add(ent)
|
|
||||||
entities.append(ent)
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
|
||||||
if len(entity_set) >= msg.limit: break
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# Keep adding entities until limit
|
||||||
if len(entity_set) >= msg.limit: break
|
if len(entities) >= msg.limit:
|
||||||
|
break
|
||||||
ents2 = []
|
|
||||||
|
|
||||||
for ent in entities:
|
|
||||||
ents2.append(self.create_value(ent))
|
|
||||||
|
|
||||||
entities = ents2
|
|
||||||
|
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
|
|
||||||
from .... schema import GraphEmbeddingsResponse
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use dimension suffix in collection name
|
||||||
|
dim = len(vec)
|
||||||
|
collection = f"t_{msg.user}_{msg.collection}_{dim}"
|
||||||
|
|
||||||
|
# Check if collection exists - return empty if not
|
||||||
|
if not self.collection_exists(collection):
|
||||||
|
logger.info(f"Collection {collection} does not exist")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
|
# of getting (limit) unique entities
|
||||||
|
search_result = self.qdrant.query_points(
|
||||||
|
collection_name=collection,
|
||||||
|
query=vec,
|
||||||
|
limit=msg.limit * 2,
|
||||||
|
with_payload=True,
|
||||||
|
).points
|
||||||
|
|
||||||
entity_set = set()
|
entity_set = set()
|
||||||
entities = []
|
entities = []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
for r in search_result:
|
||||||
|
ent = r.payload["entity"]
|
||||||
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
|
|
||||||
# Use dimension suffix in collection name
|
# De-dupe entities, keep highest score
|
||||||
dim = len(vec)
|
if ent not in entity_set:
|
||||||
collection = f"t_{msg.user}_{msg.collection}_{dim}"
|
entity_set.add(ent)
|
||||||
|
entities.append(EntityMatch(
|
||||||
# Check if collection exists - return empty if not
|
entity=self.create_value(ent),
|
||||||
if not self.collection_exists(collection):
|
score=score,
|
||||||
logger.info(f"Collection {collection} does not exist, skipping this vector")
|
))
|
||||||
continue
|
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
|
||||||
# of getting (limit) entities
|
|
||||||
search_result = self.qdrant.query_points(
|
|
||||||
collection_name=collection,
|
|
||||||
query=vec,
|
|
||||||
limit=msg.limit * 2,
|
|
||||||
with_payload=True,
|
|
||||||
).points
|
|
||||||
|
|
||||||
for r in search_result:
|
|
||||||
ent = r.payload["entity"]
|
|
||||||
|
|
||||||
# De-dupe entities
|
|
||||||
if ent not in entity_set:
|
|
||||||
entity_set.add(ent)
|
|
||||||
entities.append(ent)
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
|
||||||
if len(entity_set) >= msg.limit: break
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# Keep adding entities until limit
|
||||||
if len(entity_set) >= msg.limit: break
|
if len(entities) >= msg.limit:
|
||||||
|
break
|
||||||
ents2 = []
|
|
||||||
|
|
||||||
for ent in entities:
|
|
||||||
ents2.append(self.create_value(ent))
|
|
||||||
|
|
||||||
entities = ents2
|
|
||||||
|
|
||||||
logger.debug("Send response...")
|
logger.debug("Send response...")
|
||||||
return entities
|
return entities
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,9 @@ class Processor(FlowProcessor):
|
||||||
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
|
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
|
||||||
"""Execute row embeddings query"""
|
"""Execute row embeddings query"""
|
||||||
|
|
||||||
matches = []
|
vec = request.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Find the collection for this user/collection/schema
|
# Find the collection for this user/collection/schema
|
||||||
qdrant_collection = self.find_collection(
|
qdrant_collection = self.find_collection(
|
||||||
|
|
@ -105,47 +107,47 @@ class Processor(FlowProcessor):
|
||||||
f"No Qdrant collection found for "
|
f"No Qdrant collection found for "
|
||||||
f"{request.user}/{request.collection}/{request.schema_name}"
|
f"{request.user}/{request.collection}/{request.schema_name}"
|
||||||
)
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build optional filter for index_name
|
||||||
|
query_filter = None
|
||||||
|
if request.index_name:
|
||||||
|
query_filter = Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="index_name",
|
||||||
|
match=MatchValue(value=request.index_name)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query Qdrant
|
||||||
|
search_result = self.qdrant.query_points(
|
||||||
|
collection_name=qdrant_collection,
|
||||||
|
query=vec,
|
||||||
|
limit=request.limit,
|
||||||
|
with_payload=True,
|
||||||
|
query_filter=query_filter,
|
||||||
|
).points
|
||||||
|
|
||||||
|
# Convert to RowIndexMatch objects
|
||||||
|
matches = []
|
||||||
|
for point in search_result:
|
||||||
|
payload = point.payload or {}
|
||||||
|
match = RowIndexMatch(
|
||||||
|
index_name=payload.get("index_name", ""),
|
||||||
|
index_value=payload.get("index_value", []),
|
||||||
|
text=payload.get("text", ""),
|
||||||
|
score=point.score if hasattr(point, 'score') else 0.0
|
||||||
|
)
|
||||||
|
matches.append(match)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
for vec in request.vectors:
|
except Exception as e:
|
||||||
try:
|
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
|
||||||
# Build optional filter for index_name
|
raise
|
||||||
query_filter = None
|
|
||||||
if request.index_name:
|
|
||||||
query_filter = Filter(
|
|
||||||
must=[
|
|
||||||
FieldCondition(
|
|
||||||
key="index_name",
|
|
||||||
match=MatchValue(value=request.index_name)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query Qdrant
|
|
||||||
search_result = self.qdrant.query_points(
|
|
||||||
collection_name=qdrant_collection,
|
|
||||||
query=vec,
|
|
||||||
limit=request.limit,
|
|
||||||
with_payload=True,
|
|
||||||
query_filter=query_filter,
|
|
||||||
).points
|
|
||||||
|
|
||||||
# Convert to RowIndexMatch objects
|
|
||||||
for point in search_result:
|
|
||||||
payload = point.payload or {}
|
|
||||||
match = RowIndexMatch(
|
|
||||||
index_name=payload.get("index_name", ""),
|
|
||||||
index_value=payload.get("index_value", []),
|
|
||||||
text=payload.get("text", ""),
|
|
||||||
score=point.score if hasattr(point, 'score') else 0.0
|
|
||||||
)
|
|
||||||
matches.append(match)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
async def on_message(self, msg, consumer, flow):
|
async def on_message(self, msg, consumer, flow):
|
||||||
"""Handle incoming query request"""
|
"""Handle incoming query request"""
|
||||||
|
|
|
||||||
|
|
@ -37,26 +37,26 @@ class Query:
|
||||||
vectors = await self.get_vector(query)
|
vectors = await self.get_vector(query)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Getting chunk_ids from embeddings store...")
|
logger.debug("Getting chunks from embeddings store...")
|
||||||
|
|
||||||
# Get chunk_ids from embeddings store
|
# Get chunk matches from embeddings store
|
||||||
chunk_ids = await self.rag.doc_embeddings_client.query(
|
chunk_matches = await self.rag.doc_embeddings_client.query(
|
||||||
vectors, limit=self.doc_limit,
|
vector=vectors, limit=self.doc_limit,
|
||||||
user=self.user, collection=self.collection,
|
user=self.user, collection=self.collection,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")
|
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
|
||||||
|
|
||||||
# Fetch chunk content from Garage
|
# Fetch chunk content from Garage
|
||||||
docs = []
|
docs = []
|
||||||
for chunk_id in chunk_ids:
|
for match in chunk_matches:
|
||||||
if chunk_id:
|
if match.chunk_id:
|
||||||
try:
|
try:
|
||||||
content = await self.rag.fetch_chunk(chunk_id, self.user)
|
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
|
||||||
docs.append(content)
|
docs.append(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
|
logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Documents fetched:")
|
logger.debug("Documents fetched:")
|
||||||
|
|
|
||||||
|
|
@ -87,14 +87,14 @@ class Query:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Getting entities...")
|
logger.debug("Getting entities...")
|
||||||
|
|
||||||
entities = await self.rag.graph_embeddings_client.query(
|
entity_matches = await self.rag.graph_embeddings_client.query(
|
||||||
vectors=vectors, limit=self.entity_limit,
|
vector=vectors, limit=self.entity_limit,
|
||||||
user=self.user, collection=self.collection,
|
user=self.user, collection=self.collection,
|
||||||
)
|
)
|
||||||
|
|
||||||
entities = [
|
entities = [
|
||||||
str(e)
|
str(e.entity)
|
||||||
for e in entities
|
for e in entity_matches
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if chunk_id == "":
|
if chunk_id == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
vec = emb.vector
|
||||||
|
if vec:
|
||||||
self.vecstore.insert(
|
self.vecstore.insert(
|
||||||
vec, chunk_id,
|
vec, chunk_id,
|
||||||
message.metadata.user,
|
message.metadata.user,
|
||||||
|
|
|
||||||
|
|
@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if chunk_id == "":
|
if chunk_id == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
vec = emb.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create index name with dimension suffix for lazy creation
|
# Create index name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
index_name = (
|
index_name = (
|
||||||
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||||
if not self.pinecone.has_index(index_name):
|
if not self.pinecone.has_index(index_name):
|
||||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||||
self.create_index(index_name, dim)
|
self.create_index(index_name, dim)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
# Generate unique ID for each vector
|
# Generate unique ID for each vector
|
||||||
vector_id = str(uuid.uuid4())
|
vector_id = str(uuid.uuid4())
|
||||||
|
|
||||||
records = [
|
records = [
|
||||||
{
|
{
|
||||||
"id": vector_id,
|
"id": vector_id,
|
||||||
"values": vec,
|
"values": vec,
|
||||||
"metadata": { "chunk_id": chunk_id },
|
"metadata": { "chunk_id": chunk_id },
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
index.upsert(
|
index.upsert(
|
||||||
vectors = records,
|
vectors = records,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if chunk_id == "":
|
if chunk_id == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
vec = emb.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create collection name with dimension suffix for lazy creation
|
# Create collection name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = (
|
collection = (
|
||||||
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||||
if not self.qdrant.collection_exists(collection):
|
if not self.qdrant.collection_exists(collection):
|
||||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||||
self.qdrant.create_collection(
|
self.qdrant.create_collection(
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim,
|
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.qdrant.upsert(
|
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
points=[
|
vectors_config=VectorParams(
|
||||||
PointStruct(
|
size=dim,
|
||||||
id=str(uuid.uuid4()),
|
distance=Distance.COSINE
|
||||||
vector=vec,
|
)
|
||||||
payload={
|
|
||||||
"chunk_id": chunk_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.qdrant.upsert(
|
||||||
|
collection_name=collection,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
vector=vec,
|
||||||
|
payload={
|
||||||
|
"chunk_id": chunk_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
entity_value = get_term_value(entity.entity)
|
entity_value = get_term_value(entity.entity)
|
||||||
|
|
||||||
if entity_value != "" and entity_value is not None:
|
if entity_value != "" and entity_value is not None:
|
||||||
for vec in entity.vectors:
|
vec = entity.vector
|
||||||
|
if vec:
|
||||||
self.vecstore.insert(
|
self.vecstore.insert(
|
||||||
vec, entity_value,
|
vec, entity_value,
|
||||||
message.metadata.user,
|
message.metadata.user,
|
||||||
|
|
|
||||||
|
|
@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
if entity_value == "" or entity_value is None:
|
if entity_value == "" or entity_value is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in entity.vectors:
|
vec = entity.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create index name with dimension suffix for lazy creation
|
# Create index name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
index_name = (
|
index_name = (
|
||||||
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||||
if not self.pinecone.has_index(index_name):
|
if not self.pinecone.has_index(index_name):
|
||||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||||
self.create_index(index_name, dim)
|
self.create_index(index_name, dim)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
# Generate unique ID for each vector
|
# Generate unique ID for each vector
|
||||||
vector_id = str(uuid.uuid4())
|
vector_id = str(uuid.uuid4())
|
||||||
|
|
||||||
metadata = {"entity": entity_value}
|
metadata = {"entity": entity_value}
|
||||||
if entity.chunk_id:
|
if entity.chunk_id:
|
||||||
metadata["chunk_id"] = entity.chunk_id
|
metadata["chunk_id"] = entity.chunk_id
|
||||||
|
|
||||||
records = [
|
records = [
|
||||||
{
|
{
|
||||||
"id": vector_id,
|
"id": vector_id,
|
||||||
"values": vec,
|
"values": vec,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
index.upsert(
|
index.upsert(
|
||||||
vectors = records,
|
vectors = records,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
if entity_value == "" or entity_value is None:
|
if entity_value == "" or entity_value is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in entity.vectors:
|
vec = entity.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create collection name with dimension suffix for lazy creation
|
# Create collection name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = (
|
collection = (
|
||||||
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||||
if not self.qdrant.collection_exists(collection):
|
if not self.qdrant.collection_exists(collection):
|
||||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||||
self.qdrant.create_collection(
|
self.qdrant.create_collection(
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim,
|
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"entity": entity_value,
|
|
||||||
}
|
|
||||||
if entity.chunk_id:
|
|
||||||
payload["chunk_id"] = entity.chunk_id
|
|
||||||
|
|
||||||
self.qdrant.upsert(
|
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
points=[
|
vectors_config=VectorParams(
|
||||||
PointStruct(
|
size=dim,
|
||||||
id=str(uuid.uuid4()),
|
distance=Distance.COSINE
|
||||||
vector=vec,
|
)
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"entity": entity_value,
|
||||||
|
}
|
||||||
|
if entity.chunk_id:
|
||||||
|
payload["chunk_id"] = entity.chunk_id
|
||||||
|
|
||||||
|
self.qdrant.upsert(
|
||||||
|
collection_name=collection,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
vector=vec,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
qdrant_collection = None
|
qdrant_collection = None
|
||||||
|
|
||||||
for row_emb in embeddings.embeddings:
|
for row_emb in embeddings.embeddings:
|
||||||
if not row_emb.vectors:
|
vector = row_emb.vector
|
||||||
|
if not vector:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No vectors for index {row_emb.index_name} - skipping"
|
f"No vector for index {row_emb.index_name} - skipping"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Use first vector (there may be multiple from different models)
|
dimension = len(vector)
|
||||||
for vector in row_emb.vectors:
|
|
||||||
dimension = len(vector)
|
|
||||||
|
|
||||||
# Create/get collection name (lazily on first vector)
|
# Create/get collection name (lazily on first vector)
|
||||||
if qdrant_collection is None:
|
if qdrant_collection is None:
|
||||||
qdrant_collection = self.get_collection_name(
|
qdrant_collection = self.get_collection_name(
|
||||||
user, collection, schema_name, dimension
|
user, collection, schema_name, dimension
|
||||||
)
|
|
||||||
self.ensure_collection(qdrant_collection, dimension)
|
|
||||||
|
|
||||||
# Write to Qdrant
|
|
||||||
self.qdrant.upsert(
|
|
||||||
collection_name=qdrant_collection,
|
|
||||||
points=[
|
|
||||||
PointStruct(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
vector=vector,
|
|
||||||
payload={
|
|
||||||
"index_name": row_emb.index_name,
|
|
||||||
"index_value": row_emb.index_value,
|
|
||||||
"text": row_emb.text
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
embeddings_written += 1
|
self.ensure_collection(qdrant_collection, dimension)
|
||||||
|
|
||||||
|
# Write to Qdrant
|
||||||
|
self.qdrant.upsert(
|
||||||
|
collection_name=qdrant_collection,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
vector=vector,
|
||||||
|
payload={
|
||||||
|
"index_name": row_emb.index_name,
|
||||||
|
"index_value": row_emb.index_value,
|
||||||
|
"text": row_emb.text
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
embeddings_written += 1
|
||||||
|
|
||||||
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue