diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index c35dde4b..c7d6369a 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services import pytest from unittest.mock import MagicMock -from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error +from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error from trustgraph.messaging.translators.embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator @@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract: """Test that DocumentEmbeddingsRequest has expected fields""" # Create a request request = DocumentEmbeddingsRequest( - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3], limit=10, user="test_user", collection="test_collection" ) # Verify all expected fields exist - assert hasattr(request, 'vectors') + assert hasattr(request, 'vector') assert hasattr(request, 'limit') assert hasattr(request, 'user') assert hasattr(request, 'collection') # Verify field values - assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + assert request.vector == [0.1, 0.2, 0.3] assert request.limit == 10 assert request.user == "test_user" assert request.collection == "test_collection" @@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract: translator = DocumentEmbeddingsRequestTranslator() data = { - "vectors": [[0.1, 0.2], [0.3, 0.4]], + "vector": [0.1, 0.2, 0.3, 0.4], "limit": 5, "user": "custom_user", "collection": "custom_collection" @@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract: result = translator.to_pulsar(data) assert isinstance(result, DocumentEmbeddingsRequest) - assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] + assert result.vector == [0.1, 0.2, 0.3, 0.4] assert result.limit == 5 assert result.user == "custom_user" assert result.collection == "custom_collection" @@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract: translator = DocumentEmbeddingsRequestTranslator() data = { - "vectors": [[0.1, 0.2]] + "vector": [0.1, 0.2] # No limit, user, or collection provided } result = translator.to_pulsar(data) assert isinstance(result, DocumentEmbeddingsRequest) - assert result.vectors == [[0.1, 0.2]] + assert result.vector == [0.1, 0.2] assert result.limit == 10 # Default assert result.user == "trustgraph" # Default assert result.collection == "default" # Default @@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract: translator = DocumentEmbeddingsRequestTranslator() request = DocumentEmbeddingsRequest( - vectors=[[0.5, 0.6]], + vector=[0.5, 0.6], limit=20, user="test_user", collection="test_collection" @@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract: result = translator.from_pulsar(request) assert isinstance(result, dict) - assert result["vectors"] == [[0.5, 0.6]] + assert result["vector"] == [0.5, 0.6] assert result["limit"] == 20 assert result["user"] == "test_user" assert result["collection"] == "test_collection" @@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract: def test_response_schema_fields(self): """Test that DocumentEmbeddingsResponse has expected fields""" - # Create a response with chunk_ids + # Create a response with chunks response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["chunk1", "chunk2", "chunk3"] + chunks=[ + ChunkMatch(chunk_id="chunk1", score=0.9), + ChunkMatch(chunk_id="chunk2", score=0.8), + ChunkMatch(chunk_id="chunk3", score=0.7) + ] ) # Verify all expected fields exist assert hasattr(response, 'error') - assert hasattr(response, 'chunk_ids') + assert hasattr(response, 'chunks') # Verify field values assert response.error is None - assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"] + assert len(response.chunks) == 3 + assert response.chunks[0].chunk_id == "chunk1" + assert response.chunks[0].score == 0.9 def test_response_schema_with_error(self): """Test response schema with error""" @@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract: response = DocumentEmbeddingsResponse( error=error, - chunk_ids=[] + chunks=[] ) assert response.error == error - assert response.chunk_ids == [] + assert response.chunks == [] - def test_response_translator_from_pulsar_with_chunk_ids(self): - """Test response translator converts Pulsar schema with chunk_ids to dict""" + def test_response_translator_from_pulsar_with_chunks(self): + """Test response translator converts Pulsar schema with chunks to dict""" translator = DocumentEmbeddingsResponseTranslator() response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"] + chunks=[ + ChunkMatch(chunk_id="doc1/c1", score=0.95), + ChunkMatch(chunk_id="doc2/c2", score=0.85), + ChunkMatch(chunk_id="doc3/c3", score=0.75) + ] ) result = translator.from_pulsar(response) assert isinstance(result, dict) - assert "chunk_ids" in result - assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"] + assert "chunks" in result + assert len(result["chunks"]) == 3 + assert result["chunks"][0]["chunk_id"] == "doc1/c1" + assert result["chunks"][0]["score"] == 0.95 - def test_response_translator_from_pulsar_with_empty_chunk_ids(self): - """Test response translator handles empty chunk_ids list""" + def test_response_translator_from_pulsar_with_empty_chunks(self): + """Test response translator handles empty chunks list""" translator = DocumentEmbeddingsResponseTranslator() response = DocumentEmbeddingsResponse( error=None, - chunk_ids=[] + chunks=[] ) result = translator.from_pulsar(response) assert isinstance(result, dict) - assert "chunk_ids" in result - assert result["chunk_ids"] == [] + assert "chunks" in result + assert result["chunks"] == [] - def test_response_translator_from_pulsar_with_none_chunk_ids(self): - """Test response translator handles None chunk_ids""" + def test_response_translator_from_pulsar_with_none_chunks(self): + """Test response translator handles None chunks""" translator = DocumentEmbeddingsResponseTranslator() response = MagicMock() - response.chunk_ids = None + response.chunks = None result = translator.from_pulsar(response) assert isinstance(result, dict) - assert "chunk_ids" not in result or result.get("chunk_ids") is None + assert "chunks" not in result or result.get("chunks") is None def test_response_translator_from_response_with_completion(self): """Test response translator with completion flag""" @@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract: response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["chunk1", "chunk2"] + chunks=[ + ChunkMatch(chunk_id="chunk1", score=0.9), + ChunkMatch(chunk_id="chunk2", score=0.8) + ] ) result, is_final = translator.from_response_with_completion(response) assert isinstance(result, dict) - assert "chunk_ids" in result - assert result["chunk_ids"] == ["chunk1", "chunk2"] + assert "chunks" in result + assert len(result["chunks"]) == 2 + assert result["chunks"][0]["chunk_id"] == "chunk1" assert is_final is True # Document embeddings responses are always final def test_response_translator_to_pulsar_not_implemented(self): @@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract: translator = DocumentEmbeddingsResponseTranslator() with pytest.raises(NotImplementedError): - translator.to_pulsar({"chunk_ids": ["test"]}) + translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]}) class TestDocumentEmbeddingsMessageCompatibility: @@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility: """Test complete request-response flow maintains data integrity""" # Create request request_data = { - "vectors": [[0.1, 0.2, 0.3]], + "vector": [0.1, 0.2, 0.3], "limit": 5, "user": "test_user", "collection": "test_collection" @@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility: # Simulate service processing and creating response response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["doc1/c1", "doc2/c2"] + chunks=[ + ChunkMatch(chunk_id="doc1/c1", score=0.95), + ChunkMatch(chunk_id="doc2/c2", score=0.85) + ] ) # Convert response back to dict @@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility: # Verify data integrity assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(response_data, dict) - assert "chunk_ids" in response_data - assert len(response_data["chunk_ids"]) == 2 + assert "chunks" in response_data + assert len(response_data["chunks"]) == 2 def test_error_response_flow(self): """Test error response flow""" @@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility: response = DocumentEmbeddingsResponse( error=error, - chunk_ids=[] + chunks=[] ) # Convert response to dict @@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility: # Verify error handling assert isinstance(response_data, dict) - # The translator doesn't include error in the dict, only chunk_ids - assert "chunk_ids" in response_data - assert response_data["chunk_ids"] == [] + # The translator doesn't include error in the dict, only chunks + assert "chunks" in response_data + assert response_data["chunks"] == [] diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index 71ccd787..97197f13 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -285,11 +285,11 @@ class TestStructuredEmbeddingsContracts: collection="test_collection", metadata=[] ) - + # Act embedding = StructuredObjectEmbedding( metadata=metadata, - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3], schema_name="customer_records", object_id="customer_123", field_embeddings={ @@ -301,7 +301,7 @@ class TestStructuredEmbeddingsContracts: # Assert assert embedding.schema_name == "customer_records" assert embedding.object_id == "customer_123" - assert len(embedding.vectors) == 2 + assert len(embedding.vector) == 3 assert len(embedding.field_embeddings) == 2 assert "name" in embedding.field_embeddings diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index 99e25ed5..e9df05cf 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -9,6 +9,7 @@ Following the TEST_STRATEGY.md approach for integration testing. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import ChunkMatch # Sample chunk content for testing - maps chunk_id to content @@ -39,10 +40,14 @@ class TestDocumentRagIntegration: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns chunk IDs""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - # Now returns chunk_ids instead of actual content - client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] + # Returns ChunkMatch objects with chunk_id and score + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90), + ChunkMatch(chunk_id="doc/c3", score=0.85) + ] return client @pytest.fixture @@ -97,7 +102,7 @@ class TestDocumentRagIntegration: mock_embeddings_client.embed.assert_called_once_with([query]) mock_doc_embeddings_client.query.assert_called_once_with( - [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], + vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]], limit=doc_limit, user=user, collection=collection @@ -298,7 +303,7 @@ class TestDocumentRagIntegration: assert "DocumentRag initialized" in log_messages assert "Constructing prompt..." in log_messages assert "Computing embeddings..." in log_messages - assert "chunk_ids" in log_messages.lower() + assert "chunks" in log_messages.lower() assert "Invoking LLM..." in log_messages assert "Query processing complete" in log_messages @@ -307,9 +312,9 @@ class TestDocumentRagIntegration: async def test_document_rag_performance_with_large_document_set(self, document_rag, mock_doc_embeddings_client): """Test DocumentRAG performance with large document retrieval""" - # Arrange - Mock large chunk_id set (100 chunks) - large_chunk_ids = [f"doc/c{i}" for i in range(100)] - mock_doc_embeddings_client.query.return_value = large_chunk_ids + # Arrange - Mock large chunk match set (100 chunks) + large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)] + mock_doc_embeddings_client.query.return_value = large_chunk_matches # Act import time diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index db79ebac..dad30a8f 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -8,6 +8,7 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import ChunkMatch from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_callback_invoked, @@ -36,10 +37,14 @@ class TestDocumentRagStreaming: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns chunk IDs""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - # Now returns chunk_ids instead of actual content - client.query.return_value = ["doc/c1", "doc/c2", "doc/c3"] + # Returns ChunkMatch objects with chunk_id and score + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90), + ChunkMatch(chunk_id="doc/c3", score=0.85) + ] return client @pytest.fixture diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 94e8cf08..25a572c0 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.schema import EntityMatch, Term, IRI @pytest.mark.integration @@ -35,9 +36,9 @@ class TestGraphRagIntegration: """Mock graph embeddings client that returns realistic entities""" client = AsyncMock() client.query.return_value = [ - "http://trustgraph.ai/e/machine-learning", - "http://trustgraph.ai/e/artificial-intelligence", - "http://trustgraph.ai/e/neural-networks" + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90), + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85) ] return client @@ -130,7 +131,7 @@ class TestGraphRagIntegration: # 2. Should query graph embeddings to find relevant entities mock_graph_embeddings_client.query.assert_called_once() call_args = mock_graph_embeddings_client.query.call_args - assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] + assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]] assert call_args.kwargs['limit'] == entity_limit assert call_args.kwargs['user'] == user assert call_args.kwargs['collection'] == collection diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index f4d8ce8b..99880510 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -8,6 +8,7 @@ response delivery through the complete pipeline. import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.schema import EntityMatch, Term, IRI from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_rag_streaming_chunks, @@ -33,7 +34,7 @@ class TestGraphRagStreaming: """Mock graph embeddings client""" client = AsyncMock() client.query.return_value = [ - "http://trustgraph.ai/e/machine-learning", + EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95), ] return client diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 2baa1d4d..c390c139 100644 --- a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -411,7 +411,7 @@ class TestKnowledgeGraphPipelineIntegration: entities=[ EntityEmbeddings( entity=Term(type=IRI, iri="http://example.org/entity"), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) ] ) diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 19f2cf35..4fa93afd 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, call from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI class TestGraphRagStreamingProtocol: @@ -25,7 +26,10 @@ class TestGraphRagStreamingProtocol: def mock_graph_embeddings_client(self): """Mock graph embeddings client""" client = AsyncMock() - client.query.return_value = ["entity1", "entity2"] + client.query.return_value = [ + EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95), + EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90) + ] return client @pytest.fixture @@ -202,9 +206,12 @@ class TestDocumentRagStreamingProtocol: @pytest.fixture def mock_doc_embeddings_client(self): - """Mock document embeddings client that returns chunk IDs""" + """Mock document embeddings client that returns chunk matches""" client = AsyncMock() - client.query.return_value = ["doc/c1", "doc/c2"] + client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.95), + ChunkMatch(chunk_id="doc/c2", score=0.90) + ] return client @pytest.fixture diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py index 81d4a98e..705f2bd1 100644 --- a/tests/unit/test_base/test_document_embeddings_client.py +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -22,28 +22,28 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["chunk1", "chunk2", "chunk3"] - + mock_response.chunks = ["chunk1", "chunk2", "chunk3"] + # Mock the request method client.request = AsyncMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - + + vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # Act result = await client.query( - vectors=vectors, + vector=vector, limit=10, user="test_user", collection="test_collection", timeout=30 ) - + # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.request.assert_called_once() call_args = client.request.call_args[0][0] assert isinstance(call_args, DocumentEmbeddingsRequest) - assert call_args.vectors == vectors + assert call_args.vector == vector assert call_args.limit == 10 assert call_args.user == "test_user" assert call_args.collection == "test_collection" @@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): # Act & Assert with pytest.raises(RuntimeError, match="Database connection failed"): await client.query( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -75,13 +75,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = [] - + mock_response.chunks = [] + client.request = AsyncMock(return_value=mock_response) - + # Act - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) - + result = await client.query(vector=[0.1, 0.2, 0.3]) + # Assert assert result == [] @@ -93,12 +93,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["test_chunk"] - + mock_response.chunks = ["test_chunk"] + client.request = AsyncMock(return_value=mock_response) - + # Act - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) + result = await client.query(vector=[0.1, 0.2, 0.3]) # Assert client.request.assert_called_once() @@ -115,16 +115,16 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["chunk1"] - + mock_response.chunks = ["chunk1"] + client.request = AsyncMock(return_value=mock_response) - + # Act await client.query( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], timeout=60 ) - + # Assert assert client.request.call_args[1]["timeout"] == 60 @@ -136,14 +136,14 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): client = DocumentEmbeddingsClient() mock_response = MagicMock(spec=DocumentEmbeddingsResponse) mock_response.error = None - mock_response.chunk_ids = ["test_chunk"] - + mock_response.chunks = ["test_chunk"] + client.request = AsyncMock(return_value=mock_response) - + # Act with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger: - result = await client.query(vectors=[[0.1, 0.2, 0.3]]) - + result = await client.query(vector=[0.1, 0.2, 0.3]) + # Assert mock_logger.debug.assert_called_once() assert "Document embeddings response" in str(mock_logger.debug.call_args) diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py index 5873d81c..ce758f66 100644 --- a/tests/unit/test_clients/test_sync_document_embeddings_client.py +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["chunk1", "chunk2", "chunk3"] client.call = MagicMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - + + vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + # Act result = client.request( - vectors=vectors, + vector=vector, user="test_user", collection="test_collection", limit=10, timeout=300 ) - + # Assert assert result == ["chunk1", "chunk2", "chunk3"] client.call.assert_called_once_with( user="test_user", collection="test_collection", - vectors=vectors, + vector=vector, limit=10, timeout=300 ) @@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["test_chunk"] client.call = MagicMock(return_value=mock_response) - - vectors = [[0.1, 0.2, 0.3]] - + + vector = [0.1, 0.2, 0.3] + # Act - result = client.request(vectors=vectors) - + result = client.request(vector=vector) + # Assert assert result == ["test_chunk"] client.call.assert_called_once_with( user="trustgraph", collection="default", - vectors=vectors, + vector=vector, limit=10, timeout=300 ) @@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = [] client.call = MagicMock(return_value=mock_response) - + # Act - result = client.request(vectors=[[0.1, 0.2, 0.3]]) - + result = client.request(vector=[0.1, 0.2, 0.3]) + # Assert assert result == [] @@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = None client.call = MagicMock(return_value=mock_response) - + # Act - result = client.request(vectors=[[0.1, 0.2, 0.3]]) - + result = client.request(vector=[0.1, 0.2, 0.3]) + # Assert assert result is None @@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient: mock_response = MagicMock() mock_response.chunks = ["chunk1"] client.call = MagicMock(return_value=mock_response) - + # Act client.request( - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], timeout=600 ) - + # Assert assert client.call.call_args[1]["timeout"] == 600 \ No newline at end of file diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index 96c9c427..8c37ca32 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -98,7 +98,7 @@ def sample_graph_embeddings(): entities=[ EntityEmbeddings( entity=Term(type=IRI, iri="http://example.org/john"), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) ] ) diff --git a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py index ca43bf83..f4e456cb 100644 --- a/tests/unit/test_embeddings/test_fastembed_dynamic_model.py +++ b/tests/unit/test_embeddings/test_fastembed_dynamic_model.py @@ -108,7 +108,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase): # Assert mock_fastembed_instance.embed.assert_called_once_with(["test text"]) assert processor.cached_model_name == "test-model" # Still using default - assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @patch('trustgraph.embeddings.fastembed.processor.TextEmbedding') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') diff --git a/tests/unit/test_embeddings/test_ollama_dynamic_model.py b/tests/unit/test_embeddings/test_ollama_dynamic_model.py index 80e1de4e..d52a58c6 100644 --- a/tests/unit/test_embeddings/test_ollama_dynamic_model.py +++ b/tests/unit/test_embeddings/test_ollama_dynamic_model.py @@ -60,7 +60,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): model="test-model", input=["test text"] ) - assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -86,7 +86,7 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase): model="custom-model", input=["test text"] ) - assert result == [[[0.1, 0.2, 0.3, 0.4, 0.5]]] + assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] @patch('trustgraph.embeddings.ollama.processor.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index b4c954d8..1cddce97 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.doc_embeddings.milvus.service import Processor -from trustgraph.schema import DocumentEmbeddingsRequest +from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch class TestMilvusDocEmbeddingsQueryProcessor: @@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 ) return query @@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5 ) - # Verify results are document chunks + # Verify results are ChunkMatch objects assert len(result) == 3 - assert result[0] == "First document chunk" - assert result[1] == "Second document chunk" - assert result[2] == "Third document chunk" + assert isinstance(result[0], ChunkMatch) + assert result[0].chunk_id == "First document chunk" + assert result[1].chunk_id == "Second document chunk" + assert result[2].chunk_id == "Third document chunk" @pytest.mark.asyncio - async def test_query_document_embeddings_multiple_vectors(self, processor): - """Test querying document embeddings with multiple vectors""" + async def test_query_document_embeddings_longer_vector(self, processor): + """Test querying document embeddings with a longer vector""" query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=3 ) - - # Mock search results - different results for each vector - mock_results_1 = [ - {"entity": {"chunk_id": "Document from first vector"}}, - {"entity": {"chunk_id": "Another doc from first vector"}}, + + # Mock search results + mock_results = [ + {"entity": {"chunk_id": "First document"}}, + {"entity": {"chunk_id": "Second document"}}, + {"entity": {"chunk_id": "Third document"}}, ] - mock_results_2 = [ - {"entity": {"chunk_id": "Document from second vector"}}, - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Verify search was called twice with correct parameters including user/collection - expected_calls = [ - (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}), - (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}), - ] - assert processor.vecstore.search.call_count == 2 - for i, (expected_args, expected_kwargs) in enumerate(expected_calls): - actual_call = processor.vecstore.search.call_args_list[i] - assert actual_call[0] == expected_args - assert actual_call[1] == expected_kwargs - - # Verify results from all vectors are combined + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3 + ) + + # Verify results are ChunkMatch objects assert len(result) == 3 - assert "Document from first vector" in result - assert "Another doc from first vector" in result - assert "Document from second vector" in result + chunk_ids = [r.chunk_id for r in result] + assert "First document" in chunk_ids + assert "Second document" in chunk_ids + assert "Third document" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_with_limit(self, processor): @@ -141,7 +135,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=2 ) @@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[], + vector=[], limit=5 ) @@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -211,7 +205,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -225,11 +219,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify Unicode content is preserved + # Verify Unicode content is preserved in ChunkMatch objects assert len(result) == 3 - assert "Document with Unicode: éñ中文🚀" in result - assert "Regular ASCII document" in result - assert "Document with émojis: 😀🎉" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document with Unicode: éñ中文🚀" in chunk_ids + assert "Regular ASCII document" in chunk_ids + assert "Document with émojis: 😀🎉" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_large_documents(self, processor): @@ -237,7 +232,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -251,10 +246,11 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify large content is preserved + # Verify large content is preserved in ChunkMatch objects assert len(result) == 2 - assert large_doc in result - assert "Small document" in result + chunk_ids = [r.chunk_id for r in result] + assert large_doc in chunk_ids + assert "Small document" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_special_characters(self, processor): @@ -262,7 +258,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -276,11 +272,12 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify special characters are preserved + # Verify special characters are preserved in ChunkMatch objects assert len(result) == 3 - assert "Document with \"quotes\" and 'apostrophes'" in result - assert "Document with\nnewlines\tand\ttabs" in result - assert "Document with special chars: @#$%^&*()" in result + chunk_ids = [r.chunk_id for r in result] + assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids + assert "Document with\nnewlines\tand\ttabs" in chunk_ids + assert "Document with special chars: @#$%^&*()" in chunk_ids @pytest.mark.asyncio async def test_query_document_embeddings_zero_limit(self, processor): @@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=0 ) @@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=-1 ) @@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor: query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ], + vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector limit=5 ) - - # Mock search results for each vector - mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}] - mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}] - mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] - + + # Mock search results + mock_results = [ + {"entity": {"chunk_id": "Document 1"}}, + {"entity": {"chunk_id": "Document 2"}}, + ] + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Verify all vectors were searched - assert processor.vecstore.search.call_count == 3 - - # Verify results from all dimensions - assert len(result) == 3 - assert "Document from 2D vector" in result - assert "Document from 4D vector" in result - assert "Document from 3D vector" in result + + # Verify search was called with the vector + processor.vecstore.search.assert_called_once() + + # Verify results are ChunkMatch objects + assert len(result) == 2 + chunk_ids = [r.chunk_id for r in result] + assert "Document 1" in chunk_ids + assert "Document 2" in chunk_ids @pytest.mark.asyncio - async def test_query_document_embeddings_duplicate_documents(self, processor): - """Test querying document embeddings with duplicate documents in results""" + async def test_query_document_embeddings_multiple_results(self, processor): + """Test querying document embeddings with multiple results""" query = DocumentEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) - - # Mock search results with duplicates across vectors - mock_results_1 = [ + + # Mock search results with multiple documents + mock_results = [ {"entity": {"chunk_id": "Document A"}}, {"entity": {"chunk_id": "Document B"}}, - ] - mock_results_2 = [ - {"entity": {"chunk_id": "Document B"}}, # Duplicate {"entity": {"chunk_id": "Document C"}}, ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_document_embeddings(query) - - # Note: Unlike graph embeddings, doc embeddings don't deduplicate - # This preserves ranking and allows multiple occurrences - assert len(result) == 4 - assert result.count("Document B") == 2 # Should appear twice - assert "Document A" in result - assert "Document C" in result + + # Verify results are ChunkMatch objects + assert len(result) == 3 + chunk_ids = [r.chunk_id for r in result] + assert "Document A" in chunk_ids + assert "Document B" in chunk_ids + assert "Document C" in chunk_ids def test_add_args_method(self): """Test that add_args properly configures argument parser""" diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index 4b067743..04a93c17 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -103,7 +103,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_single_vector(self, processor): """Test querying document embeddings with a single vector""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' @@ -179,7 +179,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_limit_handling(self, processor): """Test that query respects the limit parameter""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -208,7 +208,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_zero_limit(self, processor): """Test querying with zero limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 0 message.user = 'test_user' message.collection = 'test_collection' @@ -226,7 +226,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_negative_limit(self, processor): """Test querying with negative limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = -1 message.user = 'test_user' message.collection = 'test_collection' @@ -285,7 +285,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_empty_vectors_list(self, processor): """Test querying with empty vectors list""" message = MagicMock() - message.vectors = [] + message.vector = [] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -304,7 +304,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_no_results(self, processor): """Test querying when index returns no results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -325,7 +325,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_unicode_content(self, processor): """Test querying document embeddings with Unicode content results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -351,7 +351,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_large_content(self, processor): """Test querying document embeddings with large content results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 1 message.user = 'test_user' message.collection = 'test_collection' @@ -377,7 +377,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_mixed_content_types(self, processor): """Test querying document embeddings with mixed content types""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -409,7 +409,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_exception_handling(self, processor): """Test that exceptions are properly raised""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -425,7 +425,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: async def test_query_document_embeddings_index_access_failure(self, processor): """Test handling of index access failure""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index 204c4cc3..1d2f0e6d 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.doc_embeddings.qdrant.service import Processor +from trustgraph.schema import ChunkMatch class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): @@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): with_payload=True ) - # Verify result contains expected documents + # Verify result contains expected ChunkMatch objects assert len(result) == 2 - # Results should be strings (document chunks) - assert isinstance(result[0], str) - assert isinstance(result[1], str) + # Results should be ChunkMatch objects + assert isinstance(result[0], ChunkMatch) + assert isinstance(result[1], ChunkMatch) # Verify content - assert result[0] == 'first document chunk' - assert result[1] == 'second document chunk' + assert result[0].chunk_id == 'first document chunk' + assert result[1].chunk_id == 'second document chunk' @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') - async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client): - """Test querying document embeddings with multiple vectors""" + async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client): + """Test querying document embeddings returns multiple results""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - - # Mock query responses for different vectors + + # Mock query response with multiple results mock_point1 = MagicMock() - mock_point1.payload = {'chunk_id': 'document from vector 1'} + mock_point1.payload = {'chunk_id': 'document chunk 1'} mock_point2 = MagicMock() - mock_point2.payload = {'chunk_id': 'document from vector 2'} + mock_point2.payload = {'chunk_id': 'document chunk 2'} mock_point3 = MagicMock() - mock_point3.payload = {'chunk_id': 'another document from vector 2'} - - mock_response1 = MagicMock() - mock_response1.points = [mock_point1] - mock_response2 = MagicMock() - mock_response2.points = [mock_point2, mock_point3] - mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2] - + mock_point3.payload = {'chunk_id': 'document chunk 3'} + + mock_response = MagicMock() + mock_response.points = [mock_point1, mock_point2, mock_point3] + mock_qdrant_instance.query_points.return_value = mock_response + config = { 'taskgroup': AsyncMock(), 'id': 'test-processor' } processor = Processor(**config) - - # Create mock message with multiple vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' - + # Act result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called twice - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 - # Verify both collections were queried (both 2-dimensional vectors) + # Verify collection was queried correctly expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection - assert calls[1][1]['collection_name'] == expected_collection assert calls[0][1]['query'] == [0.1, 0.2] - assert calls[1][1]['query'] == [0.3, 0.4] - - # Verify results from both vectors are combined + + # Verify results are ChunkMatch objects assert len(result) == 3 - assert 'document from vector 1' in result - assert 'document from vector 2' in result - assert 'another document from vector 2' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document chunk 1' in chunk_ids + assert 'document chunk 2' in chunk_ids + assert 'document chunk 3' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 3 # Should only return 3 results mock_message.user = 'limit_user' mock_message.collection = 'limit_collection' @@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'empty_user' mock_message.collection = 'empty_collection' @@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client): - """Test querying document embeddings with different vector dimensions""" + """Test querying document embeddings with a higher dimension vector""" # Arrange mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance - - # Mock query responses + + # Mock query response mock_point1 = MagicMock() - mock_point1.payload = {'chunk_id': 'document from 2D vector'} + mock_point1.payload = {'chunk_id': 'document from 5D vector'} mock_point2 = MagicMock() - mock_point2.payload = {'chunk_id': 'document from 3D vector'} - - mock_response1 = MagicMock() - mock_response1.points = [mock_point1] - mock_response2 = MagicMock() - mock_response2.points = [mock_point2] - mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2] - + mock_point2.payload = {'chunk_id': 'another 5D document'} + + mock_response = MagicMock() + mock_response.points = [mock_point1, mock_point2] + mock_qdrant_instance.query_points.return_value = mock_response + config = { 'taskgroup': AsyncMock(), 'id': 'test-processor' } processor = Processor(**config) - - # Create mock message with different dimension vectors + + # Create mock message with 5D vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' - + # Act result = await processor.query_document_embeddings(mock_message) # Assert - # Verify query was called twice with different collections - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once with correct collection + assert mock_qdrant_instance.query_points.call_count == 1 calls = mock_qdrant_instance.query_points.call_args_list - # First call should use 2D collection - assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions - assert calls[0][1]['query'] == [0.1, 0.2] + # Call should use 5D collection + assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions + assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5] - # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions - assert calls[1][1]['query'] == [0.3, 0.4, 0.5] - - # Verify results + # Verify results are ChunkMatch objects assert len(result) == 2 - assert 'document from 2D vector' in result - assert 'document from 3D vector' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document from 5D vector' in chunk_ids + assert 'another 5D document' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'utf8_user' mock_message.collection = 'utf8_collection' @@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert len(result) == 2 - - # Verify UTF-8 content works correctly - assert 'Document with UTF-8: café, naïve, résumé' in result - assert 'Chinese text: 你好世界' in result + + # Verify UTF-8 content works correctly in ChunkMatch objects + chunk_ids = [r.chunk_id for r in result] + assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids + assert 'Chinese text: 你好世界' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'error_user' mock_message.collection = 'error_collection' @@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with zero limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 0 mock_message.user = 'zero_user' mock_message.collection = 'zero_collection' @@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance.query_points.assert_called_once() call_args = mock_qdrant_instance.query_points.call_args assert call_args[1]['limit'] == 0 - - # Result should contain all returned documents + + # Result should contain all returned documents as ChunkMatch objects assert len(result) == 1 - assert result[0] == 'document chunk' + assert isinstance(result[0], ChunkMatch) + assert result[0].chunk_id == 'document chunk' @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with large limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 1000 # Large limit mock_message.user = 'large_user' mock_message.collection = 'large_collection' @@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): mock_qdrant_instance.query_points.assert_called_once() call_args = mock_qdrant_instance.query_points.call_args assert call_args[1]['limit'] == 1000 - - # Result should contain all available documents + + # Result should contain all available documents as ChunkMatch objects assert len(result) == 2 - assert 'document 1' in result - assert 'document 2' in result + chunk_ids = [r.chunk_id for r in result] + assert 'document 1' in chunk_ids + assert 'document 2' in chunk_ids @patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__') @@ -508,7 +504,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'payload_user' mock_message.collection = 'payload_collection' diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index 21b6e1bf..f2b8be7e 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -6,7 +6,7 @@ import pytest from unittest.mock import MagicMock, patch from trustgraph.query.graph_embeddings.milvus.service import Processor -from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL +from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch class TestMilvusGraphEmbeddingsQueryProcessor: @@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=10 ) return query @@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 ) - # Verify results are converted to Term objects + # Verify results are converted to EntityMatch objects assert len(result) == 3 - assert isinstance(result[0], Term) - assert result[0].iri == "http://example.com/entity1" - assert result[0].type == IRI - assert isinstance(result[1], Term) - assert result[1].iri == "http://example.com/entity2" - assert result[1].type == IRI - assert isinstance(result[2], Term) - assert result[2].value == "literal entity" - assert result[2].type == LITERAL + assert isinstance(result[0], EntityMatch) + assert result[0].entity.iri == "http://example.com/entity1" + assert result[0].entity.type == IRI + assert isinstance(result[1], EntityMatch) + assert result[1].entity.iri == "http://example.com/entity2" + assert result[1].entity.type == IRI + assert isinstance(result[2], EntityMatch) + assert result[2].entity.value == "literal entity" + assert result[2].entity.type == LITERAL @pytest.mark.asyncio - async def test_query_graph_embeddings_multiple_vectors(self, processor): - """Test querying graph embeddings with multiple vectors""" + async def test_query_graph_embeddings_multiple_results(self, processor): + """Test querying graph embeddings returns multiple results""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], - limit=3 + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + limit=5 ) - - # Mock search results - different results for each vector - mock_results_1 = [ + + # Mock search results with multiple entities + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, - ] - mock_results_2 = [ - {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate {"entity": {"entity": "http://example.com/entity3"}}, ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify search was called twice with correct parameters including user/collection - expected_calls = [ - (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}), - (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}), - ] - assert processor.vecstore.search.call_count == 2 - for i, (expected_args, expected_kwargs) in enumerate(expected_calls): - actual_call = processor.vecstore.search.call_args_list[i] - assert actual_call[0] == expected_args - assert actual_call[1] == expected_kwargs - - # Verify results are deduplicated and limited + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10 + ) + + # Verify results are EntityMatch objects assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] assert "http://example.com/entity1" in entity_values assert "http://example.com/entity2" in entity_values assert "http://example.com/entity3" in entity_values @@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=2 ) @@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 2 @pytest.mark.asyncio - async def test_query_graph_embeddings_deduplication(self, processor): - """Test that duplicate entities are properly deduplicated""" + async def test_query_graph_embeddings_preserves_order(self, processor): + """Test that query results preserve order from the vector store""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], limit=5 ) - - # Mock search results with duplicates - mock_results_1 = [ - {"entity": {"entity": "http://example.com/entity1"}}, - {"entity": {"entity": "http://example.com/entity2"}}, - ] - mock_results_2 = [ - {"entity": {"entity": "http://example.com/entity2"}}, # Duplicate - {"entity": {"entity": "http://example.com/entity1"}}, # Duplicate - {"entity": {"entity": "http://example.com/entity3"}}, # New - ] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] - - result = await processor.query_graph_embeddings(query) - - # Verify duplicates are removed - assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] - assert len(set(entity_values)) == 3 # All unique - assert "http://example.com/entity1" in entity_values - assert "http://example.com/entity2" in entity_values - assert "http://example.com/entity3" in entity_values - @pytest.mark.asyncio - async def test_query_graph_embeddings_early_termination_on_limit(self, processor): - """Test that querying stops early when limit is reached""" - query = GraphEmbeddingsRequest( - user='test_user', - collection='test_collection', - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], - limit=2 - ) - - # Mock search results - first vector returns enough results - mock_results_1 = [ + # Mock search results in specific order + mock_results = [ {"entity": {"entity": "http://example.com/entity1"}}, {"entity": {"entity": "http://example.com/entity2"}}, {"entity": {"entity": "http://example.com/entity3"}}, ] - processor.vecstore.search.return_value = mock_results_1 - + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify only first vector was searched (limit reached) - processor.vecstore.search.assert_called_once_with( - [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4 + + # Verify results are in the same order as returned by the store + assert len(result) == 3 + assert result[0].entity.iri == "http://example.com/entity1" + assert result[1].entity.iri == "http://example.com/entity2" + assert result[2].entity.iri == "http://example.com/entity3" + + @pytest.mark.asyncio + async def test_query_graph_embeddings_results_limited(self, processor): + """Test that results are properly limited when store returns more than requested""" + query = GraphEmbeddingsRequest( + user='test_user', + collection='test_collection', + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + limit=2 ) - - # Verify results are limited + + # Mock search results - returns more results than limit + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + {"entity": {"entity": "http://example.com/entity3"}}, + ] + processor.vecstore.search.return_value = mock_results + + result = await processor.query_graph_embeddings(query) + + # Verify search was called with the full vector + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4 + ) + + # Verify results are limited to requested amount assert len(result) == 2 @pytest.mark.asyncio @@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[], + vector=[], limit=5 ) @@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify all results are properly typed assert len(result) == 4 - + # Check URI entities - uri_results = [r for r in result if r.type == IRI] + uri_results = [r for r in result if r.entity.type == IRI] assert len(uri_results) == 2 - uri_values = [r.iri for r in uri_results] + uri_values = [r.entity.iri for r in uri_results] assert "http://example.com/uri_entity" in uri_values assert "https://example.com/another_uri" in uri_values - + # Check literal entities - literal_results = [r for r in result if not r.type == IRI] + literal_results = [r for r in result if not r.entity.type == IRI] assert len(literal_results) == 2 - literal_values = [r.value for r in literal_results] + literal_values = [r.entity.value for r in literal_results] assert "literal entity text" in literal_values assert "another literal" in literal_values @@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=5 ) @@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], limit=0 ) @@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor: assert len(result) == 0 @pytest.mark.asyncio - async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying graph embeddings with different vector dimensions""" + async def test_query_graph_embeddings_longer_vector(self, processor): + """Test querying graph embeddings with a longer vector""" query = GraphEmbeddingsRequest( user='test_user', collection='test_collection', - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ], + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], limit=5 ) - - # Mock search results for each vector - mock_results_1 = [{"entity": {"entity": "entity_2d"}}] - mock_results_2 = [{"entity": {"entity": "entity_4d"}}] - mock_results_3 = [{"entity": {"entity": "entity_3d"}}] - processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] - + + # Mock search results + mock_results = [ + {"entity": {"entity": "http://example.com/entity1"}}, + {"entity": {"entity": "http://example.com/entity2"}}, + ] + processor.vecstore.search.return_value = mock_results + result = await processor.query_graph_embeddings(query) - - # Verify all vectors were searched - assert processor.vecstore.search.call_count == 3 - - # Verify results from all dimensions - assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] - assert "entity_2d" in entity_values - assert "entity_4d" in entity_values - assert "entity_3d" in entity_values \ No newline at end of file + + # Verify search was called once with the full vector + processor.vecstore.search.assert_called_once() + + # Verify results + assert len(result) == 2 + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] + assert "http://example.com/entity1" in entity_values + assert "http://example.com/entity2" in entity_values \ No newline at end of file diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 1b243113..2c1a673a 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True) from trustgraph.query.graph_embeddings.pinecone.service import Processor -from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.schema import Term, IRI, LITERAL, EntityMatch class TestPineconeGraphEmbeddingsQueryProcessor: @@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: def mock_query_message(self): """Create a mock query message for testing""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_single_vector(self, processor): """Test querying graph embeddings with a single vector""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' @@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor: include_metadata=True ) - # Verify results + # Verify results use EntityMatch structure assert len(entities) == 3 - assert entities[0].value == 'http://example.org/entity1' - assert entities[0].type == IRI - assert entities[1].value == 'entity2' - assert entities[1].type == LITERAL - assert entities[2].value == 'http://example.org/entity3' - assert entities[2].type == IRI + assert entities[0].entity.iri == 'http://example.org/entity1' + assert entities[0].entity.type == IRI + assert entities[1].entity.value == 'entity2' + assert entities[1].entity.type == LITERAL + assert entities[2].entity.iri == 'http://example.org/entity3' + assert entities[2].entity.type == IRI @pytest.mark.asyncio - async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message): - """Test querying graph embeddings with multiple vectors""" + async def test_query_graph_embeddings_basic(self, processor, mock_query_message): + """Test basic graph embeddings query""" # Mock index and query results mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # First query results - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query results with distinct entities + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), - MagicMock(metadata={'entity': 'entity2'}) - ] - - # Second query results - mock_results2 = MagicMock() - mock_results2.matches = [ - MagicMock(metadata={'entity': 'entity2'}), # Duplicate + MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity3'}) ] - - mock_index.query.side_effect = [mock_results1, mock_results2] - + + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(mock_query_message) - - # Verify both queries were made - assert mock_index.query.call_count == 2 - - # Verify deduplication occurred - entity_values = [e.value for e in entities] + + # Verify query was made once + assert mock_index.query.call_count == 1 + + # Verify results with EntityMatch structure + entity_values = [e.entity.value for e in entities] assert len(entity_values) == 3 assert 'entity1' in entity_values assert 'entity2' in entity_values @@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_limit_handling(self, processor): """Test that query respects the limit parameter""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' @@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_zero_limit(self, processor): """Test querying with zero limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 0 message.user = 'test_user' message.collection = 'test_collection' @@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_negative_limit(self, processor): """Test querying with negative limit returns empty results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = -1 message.user = 'test_user' message.collection = 'test_collection' @@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor: assert entities == [] @pytest.mark.asyncio - async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying with vectors of different dimensions using same index""" + async def test_query_graph_embeddings_2d_vector(self, processor): + """Test querying with a 2D vector""" message = MagicMock() - message.vectors = [ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6] # 4D vector - ] + message.vector = [0.1, 0.2] # 2D vector message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' - # Mock single index that handles all dimensions + # Mock index mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - # Mock results for different vector queries - mock_results_2d = MagicMock() - mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] + # Mock results for 2D vector query + mock_results = MagicMock() + mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})] - mock_results_4d = MagicMock() - mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})] - - mock_index.query.side_effect = [mock_results_2d, mock_results_4d] + mock_index.query.return_value = mock_results entities = await processor.query_graph_embeddings(message) - # Verify different indexes used for different dimensions - assert processor.pinecone.Index.call_count == 2 - index_calls = processor.pinecone.Index.call_args_list - index_names = [call[0][0] for call in index_calls] - assert "t-test_user-test_collection-2" in index_names # 2D vector - assert "t-test_user-test_collection-4" in index_names # 4D vector + # Verify correct index used for 2D vector + processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2") - # Verify both queries were made - assert mock_index.query.call_count == 2 + # Verify query was made + assert mock_index.query.call_count == 1 - # Verify results from both dimensions - entity_values = [e.value for e in entities] + # Verify results with EntityMatch structure + entity_values = [e.entity.value for e in entities] assert 'entity_2d' in entity_values - assert 'entity_4d' in entity_values @pytest.mark.asyncio async def test_query_graph_embeddings_empty_vectors_list(self, processor): """Test querying with empty vectors list""" message = MagicMock() - message.vectors = [] + message.vector = [] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_no_results(self, processor): """Test querying when index returns no results""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' @@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor: assert entities == [] @pytest.mark.asyncio - async def test_query_graph_embeddings_deduplication_across_vectors(self, processor): - """Test that deduplication works correctly across multiple vector queries""" + async def test_query_graph_embeddings_deduplication_in_results(self, processor): + """Test that deduplication works correctly within query results""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 3 message.user = 'test_user' message.collection = 'test_collection' - + mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # Both queries return overlapping results - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query returns results with some duplicates + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity2'}), + MagicMock(metadata={'entity': 'entity1'}), # Duplicate MagicMock(metadata={'entity': 'entity3'}), - MagicMock(metadata={'entity': 'entity4'}) - ] - - mock_results2 = MagicMock() - mock_results2.matches = [ MagicMock(metadata={'entity': 'entity2'}), # Duplicate - MagicMock(metadata={'entity': 'entity3'}), # Duplicate - MagicMock(metadata={'entity': 'entity5'}) ] - - mock_index.query.side_effect = [mock_results1, mock_results2] - + + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(message) - + # Should get exactly 3 unique entities (respecting limit) assert len(entities) == 3 - entity_values = [e.value for e in entities] + entity_values = [e.entity.value for e in entities] assert len(set(entity_values)) == 3 # All unique @pytest.mark.asyncio - async def test_query_graph_embeddings_early_termination_on_limit(self, processor): - """Test that querying stops early when limit is reached""" + async def test_query_graph_embeddings_respects_limit(self, processor): + """Test that query respects limit parameter""" message = MagicMock() - message.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] + message.vector = [0.1, 0.2, 0.3] message.limit = 2 message.user = 'test_user' message.collection = 'test_collection' - + mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index - - # First query returns enough results to meet limit - mock_results1 = MagicMock() - mock_results1.matches = [ + + # Query returns more results than limit + mock_results = MagicMock() + mock_results.matches = [ MagicMock(metadata={'entity': 'entity1'}), MagicMock(metadata={'entity': 'entity2'}), MagicMock(metadata={'entity': 'entity3'}) ] - mock_index.query.return_value = mock_results1 - + mock_index.query.return_value = mock_results + entities = await processor.query_graph_embeddings(message) - - # Should only make one query since limit was reached + + # Should only return 2 entities (respecting limit) mock_index.query.assert_called_once() assert len(entities) == 2 @@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: async def test_query_graph_embeddings_exception_handling(self, processor): """Test that exceptions are properly raised""" message = MagicMock() - message.vectors = [[0.1, 0.2, 0.3]] + message.vector = [0.1, 0.2, 0.3] message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 1760c4c1..9362a8dd 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.graph_embeddings.qdrant.service import Processor -from trustgraph.schema import IRI, LITERAL +from trustgraph.schema import IRI, LITERAL, EntityMatch class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): @@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): with_payload=True ) - # Verify result contains expected entities + # Verify result contains expected EntityMatch objects assert len(result) == 2 - assert all(hasattr(entity, 'value') for entity in result) - entity_values = [entity.value for entity in result] + assert all(isinstance(entity, EntityMatch) for entity in result) + entity_values = [entity.entity.value for entity in result] assert 'entity1' in entity_values assert 'entity2' in entity_values @@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): } processor = Processor(**config) - - # Create mock message with multiple vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' - + # Act result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called twice - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 - # Verify both collections were queried (both 2-dimensional vectors) + # Verify collection was queried expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection - assert calls[1][1]['collection_name'] == expected_collection assert calls[0][1]['query'] == [0.1, 0.2] - assert calls[1][1]['query'] == [0.3, 0.4] - - # Verify deduplication - entity2 appears in both results but should only appear once - entity_values = [entity.value for entity in result] + + # Verify results with EntityMatch structure + entity_values = [entity.entity.value for entity in result] assert len(set(entity_values)) == len(entity_values) # All unique assert 'entity1' in entity_values assert 'entity2' in entity_values - assert 'entity3' in entity_values @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 3 # Should only return 3 results mock_message.user = 'limit_user' mock_message.collection = 'limit_collection' @@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'empty_user' mock_message.collection = 'empty_collection' @@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): } processor = Processor(**config) - - # Create mock message with different dimension vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2] # 2D vector mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' - + # Act result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called twice with different collections - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 calls = mock_qdrant_instance.query_points.call_args_list - # First call should use 2D collection + # Call should use 2D collection assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['query'] == [0.1, 0.2] - # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions - assert calls[1][1]['query'] == [0.3, 0.4, 0.5] - - # Verify results - entity_values = [entity.value for entity in result] + # Verify results with EntityMatch structure + entity_values = [entity.entity.value for entity in result] assert 'entity2d' in entity_values - assert 'entity3d' in entity_values @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'uri_user' mock_message.collection = 'uri_collection' @@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert len(result) == 3 - + # Check URI entities - uri_entities = [entity for entity in result if entity.type == IRI] + uri_entities = [entity for entity in result if entity.entity.type == IRI] assert len(uri_entities) == 2 - uri_values = [entity.iri for entity in uri_entities] + uri_values = [entity.entity.iri for entity in uri_entities] assert 'http://example.com/entity1' in uri_values assert 'https://secure.example.com/entity2' in uri_values # Check regular entities - regular_entities = [entity for entity in result if entity.type == LITERAL] + regular_entities = [entity for entity in result if entity.entity.type == LITERAL] assert len(regular_entities) == 1 - assert regular_entities[0].value == 'regular entity' + assert regular_entities[0].entity.value == 'regular entity' @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'error_user' mock_message.collection = 'error_collection' @@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with zero limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 0 mock_message.user = 'zero_user' mock_message.collection = 'zero_collection' @@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # With zero limit, the logic still adds one entity before checking the limit # So it returns one result (current behavior, not ideal but actual) assert len(result) == 1 - assert result[0].value == 'entity1' + assert result[0].entity.value == 'entity1' @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 51e41524..92a62222 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -175,9 +175,14 @@ class TestQuery: test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = [test_vectors] - # Mock document embeddings returns chunk_ids - test_chunk_ids = ["doc/c1", "doc/c2"] - mock_doc_embeddings_client.query.return_value = test_chunk_ids + # Mock document embeddings returns ChunkMatch objects + mock_match1 = MagicMock() + mock_match1.chunk_id = "doc/c1" + mock_match1.score = 0.95 + mock_match2 = MagicMock() + mock_match2.chunk_id = "doc/c2" + mock_match2.score = 0.85 + mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] # Initialize Query query = Query( @@ -195,9 +200,9 @@ class TestQuery: # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with([test_query]) - # Verify doc embeddings client was called correctly (with extracted vectors) + # Verify doc embeddings client was called correctly (with extracted vector) mock_doc_embeddings_client.query.assert_called_once_with( - test_vectors, + vector=test_vectors, limit=15, user="test_user", collection="test_collection" @@ -218,11 +223,16 @@ class TestQuery: # Mock embeddings and document embeddings responses # New batch format: [[[vectors]]] - get_vector extracts [0] test_vectors = [[0.1, 0.2, 0.3]] - test_chunk_ids = ["doc/c3", "doc/c4"] + mock_match1 = MagicMock() + mock_match1.chunk_id = "doc/c3" + mock_match1.score = 0.9 + mock_match2 = MagicMock() + mock_match2.chunk_id = "doc/c4" + mock_match2.score = 0.8 expected_response = "This is the document RAG response" mock_embeddings_client.embed.return_value = [test_vectors] - mock_doc_embeddings_client.query.return_value = test_chunk_ids + mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] mock_prompt_client.document_prompt.return_value = expected_response # Initialize DocumentRag @@ -245,9 +255,9 @@ class TestQuery: # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with(["test query"]) - # Verify doc embeddings client was called (with extracted vectors) + # Verify doc embeddings client was called (with extracted vector) mock_doc_embeddings_client.query.assert_called_once_with( - test_vectors, + vector=test_vectors, limit=10, user="test_user", collection="test_collection" @@ -275,7 +285,10 @@ class TestQuery: # Mock responses (batch format) mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] - mock_doc_embeddings_client.query.return_value = ["doc/c5"] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c5" + mock_match.score = 0.9 + mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Default response" # Initialize DocumentRag @@ -289,9 +302,9 @@ class TestQuery: # Call DocumentRag.query with minimal parameters result = await document_rag.query("simple query") - # Verify default parameters were used (vectors extracted from batch) + # Verify default parameters were used (vector extracted from batch) mock_doc_embeddings_client.query.assert_called_once_with( - [[0.1, 0.2]], + vector=[[0.1, 0.2]], limit=20, # Default doc_limit user="trustgraph", # Default user collection="default" # Default collection @@ -316,7 +329,10 @@ class TestQuery: # Mock responses (batch format) mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]] - mock_doc_embeddings_client.query.return_value = ["doc/c6"] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c6" + mock_match.score = 0.88 + mock_doc_embeddings_client.query.return_value = [mock_match] # Initialize Query with verbose=True query = Query( @@ -347,7 +363,10 @@ class TestQuery: # Mock responses (batch format) mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] - mock_doc_embeddings_client.query.return_value = ["doc/c7"] + mock_match = MagicMock() + mock_match.chunk_id = "doc/c7" + mock_match.score = 0.92 + mock_doc_embeddings_client.query.return_value = [mock_match] mock_prompt_client.document_prompt.return_value = "Verbose RAG response" # Initialize DocumentRag with verbose=True @@ -487,7 +506,13 @@ class TestQuery: final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." mock_embeddings_client.embed.return_value = [query_vectors] - mock_doc_embeddings_client.query.return_value = retrieved_chunk_ids + mock_matches = [] + for chunk_id in retrieved_chunk_ids: + mock_match = MagicMock() + mock_match.chunk_id = chunk_id + mock_match.score = 0.9 + mock_matches.append(mock_match) + mock_doc_embeddings_client.query.return_value = mock_matches mock_prompt_client.document_prompt.return_value = final_response # Initialize DocumentRag @@ -511,7 +536,7 @@ class TestQuery: mock_embeddings_client.embed.assert_called_once_with([query_text]) mock_doc_embeddings_client.query.assert_called_once_with( - query_vectors, + vector=query_vectors, limit=25, user="research_user", collection="ml_knowledge" diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 15b0c82d..e763d089 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -193,12 +193,20 @@ class TestQuery: test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = [test_vectors] - # Mock entity objects that have string representation + # Mock EntityMatch objects with entity that has string representation mock_entity1 = MagicMock() mock_entity1.__str__ = MagicMock(return_value="entity1") + mock_match1 = MagicMock() + mock_match1.entity = mock_entity1 + mock_match1.score = 0.95 + mock_entity2 = MagicMock() mock_entity2.__str__ = MagicMock(return_value="entity2") - mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2] + mock_match2 = MagicMock() + mock_match2.entity = mock_entity2 + mock_match2.score = 0.85 + + mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2] # Initialize Query query = Query( @@ -216,9 +224,9 @@ class TestQuery: # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with([test_query]) - # Verify graph embeddings client was called correctly (with extracted vectors) + # Verify graph embeddings client was called correctly (with extracted vector) mock_graph_embeddings_client.query.assert_called_once_with( - vectors=test_vectors, + vector=test_vectors, limit=25, user="test_user", collection="test_collection" diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index ef66b741..f9d60541 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -23,11 +23,11 @@ class TestMilvusDocEmbeddingsStorageProcessor: # Create test document embeddings chunk1 = ChunkEmbeddings( chunk_id="This is the first document chunk", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) chunk2 = ChunkEmbeddings( chunk_id="This is the second document chunk", - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9] ) message.chunks = [chunk1, chunk2] @@ -82,44 +82,34 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + chunk = ChunkEmbeddings( chunk_id="Test document content", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.chunks = [chunk] - + await processor.store_document_embeddings(message) - - # Verify insert was called for each vector with user/collection parameters - expected_calls = [ - ([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'), - ] - - assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): - actual_call = processor.vecstore.insert.call_args_list[i] - assert actual_call[0][0] == expected_vec - assert actual_call[0][1] == expected_doc - assert actual_call[0][2] == expected_user - assert actual_call[0][3] == expected_collection + + # Verify insert was called once for the single chunk with its vector + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection' + ) @pytest.mark.asyncio async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): """Test storing document embeddings for multiple chunks""" await processor.store_document_embeddings(mock_message) - - # Verify insert was called for each vector of each chunk with user/collection parameters + + # Verify insert was called once per chunk with user/collection parameters expected_calls = [ - # Chunk 1 vectors - ([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), - # Chunk 2 vectors + # Chunk 1 - single vector + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), + # Chunk 2 - single vector ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), ] - - assert processor.vecstore.insert.call_count == 3 + + assert processor.vecstore.insert.call_count == 2 for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec @@ -137,7 +127,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id="", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -156,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id=None, - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -177,15 +167,15 @@ class TestMilvusDocEmbeddingsStorageProcessor: valid_chunk = ChunkEmbeddings( chunk_id="Valid document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) empty_chunk = ChunkEmbeddings( chunk_id="", - vectors=[[0.4, 0.5, 0.6]] + vector=[0.4, 0.5, 0.6] ) another_valid = ChunkEmbeddings( chunk_id="Another valid chunk", - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9] ) message.chunks = [valid_chunk, empty_chunk, another_valid] @@ -229,7 +219,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id="Document with no vectors", - vectors=[] + vector=[] ) message.chunks = [chunk] @@ -245,26 +235,31 @@ class TestMilvusDocEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - - chunk = ChunkEmbeddings( - chunk_id="Document with mixed dimensions", - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + + # Each chunk has a single vector of different dimensions + chunk1 = ChunkEmbeddings( + chunk_id="chunk/doc/2d", + vector=[0.1, 0.2] # 2D vector ) - message.chunks = [chunk] - + chunk2 = ChunkEmbeddings( + chunk_id="chunk/doc/4d", + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + chunk3 = ChunkEmbeddings( + chunk_id="chunk/doc/3d", + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.chunks = [chunk1, chunk2, chunk3] + await processor.store_document_embeddings(message) - + # Verify all vectors were inserted regardless of dimension with user/collection parameters expected_calls = [ - ([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'), - ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'), - ([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'), + ([0.1, 0.2], "chunk/doc/2d", 'test_user', 'test_collection'), + ([0.3, 0.4, 0.5, 0.6], "chunk/doc/4d", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "chunk/doc/3d", 'test_user', 'test_collection'), ] - + assert processor.vecstore.insert.call_count == 3 for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] @@ -283,7 +278,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id="chunk/doc/unicode-éñ中文🚀", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -306,7 +301,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: long_chunk_id = "chunk/doc/" + "a" * 200 chunk = ChunkEmbeddings( chunk_id=long_chunk_id, - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -327,7 +322,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id=" \n\t ", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -358,7 +353,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id="Test content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -379,7 +374,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message1.metadata.collection = 'collection1' chunk1 = ChunkEmbeddings( chunk_id="User1 content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message1.chunks = [chunk1] @@ -390,7 +385,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: message2.metadata.collection = 'collection2' chunk2 = ChunkEmbeddings( chunk_id="User2 content", - vectors=[[0.4, 0.5, 0.6]] + vector=[0.4, 0.5, 0.6] ) message2.chunks = [chunk2] @@ -421,7 +416,7 @@ class TestMilvusDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk_id="Special chars test", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py index fc7c0a79..fec4f87e 100644 --- a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -27,11 +27,11 @@ class TestPineconeDocEmbeddingsStorageProcessor: # Create test document embeddings chunk1 = ChunkEmbeddings( chunk=b"This is the first document chunk", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) chunk2 = ChunkEmbeddings( chunk=b"This is the second document chunk", - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9] ) message.chunks = [chunk1, chunk2] @@ -125,7 +125,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.chunks = [chunk] @@ -190,7 +190,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -222,7 +222,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -244,7 +244,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=None, - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -266,7 +266,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"", # Empty bytes - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -286,37 +286,39 @@ class TestPineconeDocEmbeddingsStorageProcessor: message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - chunk = ChunkEmbeddings( - chunk=b"Document with mixed dimensions", - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + # Each chunk has a single vector of different dimensions + chunk1 = ChunkEmbeddings( + chunk=b"Document chunk 1", + vector=[0.1, 0.2] # 2D vector ) - message.chunks = [chunk] - - mock_index_2d = MagicMock() - mock_index_4d = MagicMock() - mock_index_3d = MagicMock() - + chunk2 = ChunkEmbeddings( + chunk=b"Document chunk 2", + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + chunk3 = ChunkEmbeddings( + chunk=b"Document chunk 3", + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.chunks = [chunk1, chunk2, chunk3] + + mock_index = MagicMock() + def mock_index_side_effect(name): # All dimensions now use the same index name pattern - # Different dimensions will be handled within the same index if "test_user" in name and "test_collection" in name: - return mock_index_2d # Just return one mock for all + return mock_index return MagicMock() - + processor.pinecone.Index.side_effect = mock_index_side_effect processor.pinecone.has_index.return_value = True - + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): await processor.store_document_embeddings(message) # Verify all vectors are now stored in the same index - # (Pinecone can handle mixed dimensions in the same index) - assert processor.pinecone.Index.call_count == 3 # Called once per vector - mock_index_2d.upsert.call_count == 3 # All upserts go to same index + # (Each chunk has a single vector, called once per chunk) + assert processor.pinecone.Index.call_count == 3 # Called once per chunk + assert mock_index.upsert.call_count == 3 # All upserts go to same index @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunks_list(self, processor): @@ -346,7 +348,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Document with no vectors", - vectors=[] + vector=[] ) message.chunks = [chunk] @@ -368,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -393,7 +395,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk=b"Test document content", - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -419,7 +421,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: chunk = ChunkEmbeddings( chunk="Document with Unicode: éñ中文🚀".encode('utf-8'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] @@ -447,7 +449,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: large_content = "A" * 10000 # 10KB of content chunk = ChunkEmbeddings( chunk=large_content.encode('utf-8'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.chunks = [chunk] diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index e3498ce9..98d2dab2 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -89,7 +89,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk = MagicMock() mock_chunk.chunk_id = 'doc/c1' # chunk_id instead of chunk bytes - mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions + mock_chunk.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions mock_message.chunks = [mock_chunk] @@ -143,11 +143,11 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk1 = MagicMock() mock_chunk1.chunk_id = 'doc/c1' - mock_chunk1.vectors = [[0.1, 0.2]] + mock_chunk1.vector = [0.1, 0.2] mock_chunk2 = MagicMock() mock_chunk2.chunk_id = 'doc/c2' - mock_chunk2.vectors = [[0.3, 0.4]] + mock_chunk2.vector = [0.3, 0.4] mock_message.chunks = [mock_chunk1, mock_chunk2] @@ -175,8 +175,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client): - """Test storing document embeddings with multiple vectors per chunk""" + async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client): + """Test storing document embeddings with multiple chunks""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True @@ -196,41 +196,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Add collection to known_collections (simulates config push) processor.known_collections[('vector_user', 'vector_collection')] = {} - # Create mock message with chunk having multiple vectors + # Create mock message with multiple chunks, each having a single vector mock_message = MagicMock() mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' - mock_chunk = MagicMock() - mock_chunk.chunk_id = 'doc/multi-vector' - mock_chunk.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] + mock_chunk1 = MagicMock() + mock_chunk1.chunk_id = 'doc/c1' + mock_chunk1.vector = [0.1, 0.2, 0.3] - mock_message.chunks = [mock_chunk] + mock_chunk2 = MagicMock() + mock_chunk2.chunk_id = 'doc/c2' + mock_chunk2.vector = [0.4, 0.5, 0.6] + + mock_chunk3 = MagicMock() + mock_chunk3.chunk_id = 'doc/c3' + mock_chunk3.vector = [0.7, 0.8, 0.9] + + mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3] # Act await processor.store_document_embeddings(mock_message) # Assert - # Should be called 3 times (once per vector) + # Should be called 3 times (once per chunk) assert mock_qdrant_instance.upsert.call_count == 3 # Verify all vectors were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - expected_vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] + expected_data = [ + ([0.1, 0.2, 0.3], 'doc/c1'), + ([0.4, 0.5, 0.6], 'doc/c2'), + ([0.7, 0.8, 0.9], 'doc/c3') ] for i, call in enumerate(upsert_calls): point = call[1]['points'][0] - assert point.vector == expected_vectors[i] - assert point.payload['chunk_id'] == 'doc/multi-vector' + assert point.vector == expected_data[i][0] + assert point.payload['chunk_id'] == expected_data[i][1] @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') async def test_store_document_embeddings_empty_chunk_id(self, mock_qdrant_client): @@ -256,7 +260,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk_empty = MagicMock() mock_chunk_empty.chunk_id = "" # Empty chunk_id - mock_chunk_empty.vectors = [[0.1, 0.2]] + mock_chunk_empty.vector = [0.1, 0.2] mock_message.chunks = [mock_chunk_empty] @@ -299,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk = MagicMock() mock_chunk.chunk_id = 'doc/test-chunk' - mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions + mock_chunk.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5 dimensions mock_message.chunks = [mock_chunk] @@ -351,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk = MagicMock() mock_chunk.chunk_id = 'doc/test-chunk' - mock_chunk.vectors = [[0.1, 0.2]] + mock_chunk.vector = [0.1, 0.2] mock_message.chunks = [mock_chunk] @@ -389,7 +393,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk1 = MagicMock() mock_chunk1.chunk_id = 'doc/c1' - mock_chunk1.vectors = [[0.1, 0.2, 0.3]] + mock_chunk1.vector = [0.1, 0.2, 0.3] mock_message1.chunks = [mock_chunk1] @@ -407,7 +411,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk2 = MagicMock() mock_chunk2.chunk_id = 'doc/c2' - mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3) + mock_chunk2.vector = [0.4, 0.5, 0.6] # Same dimension (3) mock_message2.chunks = [mock_chunk2] @@ -446,19 +450,20 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Add collection to known_collections (simulates config push) processor.known_collections[('dim_user', 'dim_collection')] = {} - # Create mock message with different dimension vectors + # Create mock message with chunks of different dimensions mock_message = MagicMock() mock_message.metadata.user = 'dim_user' mock_message.metadata.collection = 'dim_collection' - mock_chunk = MagicMock() - mock_chunk.chunk_id = 'doc/dim-test' - mock_chunk.vectors = [ - [0.1, 0.2], # 2 dimensions - [0.3, 0.4, 0.5] # 3 dimensions - ] + mock_chunk1 = MagicMock() + mock_chunk1.chunk_id = 'doc/c1' + mock_chunk1.vector = [0.1, 0.2] # 2 dimensions - mock_message.chunks = [mock_chunk] + mock_chunk2 = MagicMock() + mock_chunk2.chunk_id = 'doc/c2' + mock_chunk2.vector = [0.3, 0.4, 0.5] # 3 dimensions + + mock_message.chunks = [mock_chunk1, mock_chunk2] # Act await processor.store_document_embeddings(mock_message) @@ -526,7 +531,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): mock_chunk = MagicMock() mock_chunk.chunk_id = 'https://trustgraph.ai/doc/my-document/p1/c3' - mock_chunk.vectors = [[0.1, 0.2]] + mock_chunk.vector = [0.1, 0.2] mock_message.chunks = [mock_chunk] diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index ebb93b62..e4d60adf 100644 --- a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -23,11 +23,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor: # Create test entities with embeddings entity1 = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity1'), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) entity2 = EntityEmbeddings( entity=Term(type=LITERAL, value='literal entity'), - vectors=[[0.7, 0.8, 0.9]] + vector=[0.7, 0.8, 0.9] ) message.entities = [entity1, entity2] @@ -82,44 +82,37 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity'), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) message.entities = [entity] - + await processor.store_graph_embeddings(message) - - # Verify insert was called for each vector with user/collection parameters - expected_calls = [ - ([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'), - ] - - assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): - actual_call = processor.vecstore.insert.call_args_list[i] - assert actual_call[0][0] == expected_vec - assert actual_call[0][1] == expected_entity - assert actual_call[0][2] == expected_user - assert actual_call[0][3] == expected_collection + + # Verify insert was called once with the full vector + processor.vecstore.insert.assert_called_once() + actual_call = processor.vecstore.insert.call_args_list[0] + assert actual_call[0][0] == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + assert actual_call[0][1] == 'http://example.com/entity' + assert actual_call[0][2] == 'test_user' + assert actual_call[0][3] == 'test_collection' @pytest.mark.asyncio async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): """Test storing graph embeddings for multiple entities""" await processor.store_graph_embeddings(mock_message) - - # Verify insert was called for each vector of each entity with user/collection parameters + + # Verify insert was called once per entity with user/collection parameters expected_calls = [ - # Entity 1 vectors - ([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'), - ([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), - # Entity 2 vectors + # Entity 1 - single vector + ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), + # Entity 2 - single vector ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), ] - - assert processor.vecstore.insert.call_count == 3 + + assert processor.vecstore.insert.call_count == 2 for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec @@ -137,7 +130,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Term(type=LITERAL, value=''), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -156,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Term(type=LITERAL, value=None), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -175,17 +168,17 @@ class TestMilvusGraphEmbeddingsStorageProcessor: valid_entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/valid'), - vectors=[[0.1, 0.2, 0.3]], + vector=[0.1, 0.2, 0.3], chunk_id='' ) empty_entity = EntityEmbeddings( entity=Term(type=LITERAL, value=''), - vectors=[[0.4, 0.5, 0.6]], + vector=[0.4, 0.5, 0.6], chunk_id='' ) none_entity = EntityEmbeddings( entity=Term(type=LITERAL, value=None), - vectors=[[0.7, 0.8, 0.9]], + vector=[0.7, 0.8, 0.9], chunk_id='' ) message.entities = [valid_entity, empty_entity, none_entity] @@ -222,7 +215,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/entity'), - vectors=[] + vector=[] ) message.entities = [entity] @@ -238,26 +231,31 @@ class TestMilvusGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - - entity = EntityEmbeddings( - entity=Term(type=IRI, iri='http://example.com/entity'), - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + + # Each entity has a single vector of different dimensions + entity1 = EntityEmbeddings( + entity=Term(type=IRI, iri='http://example.com/entity1'), + vector=[0.1, 0.2] # 2D vector ) - message.entities = [entity] - + entity2 = EntityEmbeddings( + entity=Term(type=IRI, iri='http://example.com/entity2'), + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + entity3 = EntityEmbeddings( + entity=Term(type=IRI, iri='http://example.com/entity3'), + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.entities = [entity1, entity2, entity3] + await processor.store_graph_embeddings(message) - + # Verify all vectors were inserted regardless of dimension expected_calls = [ - ([0.1, 0.2], 'http://example.com/entity'), - ([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity'), - ([0.7, 0.8, 0.9], 'http://example.com/entity'), + ([0.1, 0.2], 'http://example.com/entity1'), + ([0.3, 0.4, 0.5, 0.6], 'http://example.com/entity2'), + ([0.7, 0.8, 0.9], 'http://example.com/entity3'), ] - + assert processor.vecstore.insert.call_count == 3 for i, (expected_vec, expected_entity) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] @@ -274,11 +272,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor: uri_entity = EntityEmbeddings( entity=Term(type=IRI, iri='http://example.com/uri_entity'), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) literal_entity = EntityEmbeddings( entity=Term(type=LITERAL, value='literal entity text'), - vectors=[[0.4, 0.5, 0.6]] + vector=[0.4, 0.5, 0.6] ) message.entities = [uri_entity, literal_entity] diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py index 0fd0fde3..9ff53f4e 100644 --- a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -24,16 +24,20 @@ class TestPineconeGraphEmbeddingsStorageProcessor: message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - # Create test entity embeddings + # Create test entity embeddings (each entity has a single vector) entity1 = EntityEmbeddings( entity=Value(value="http://example.org/entity1", is_uri=True), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3] ) entity2 = EntityEmbeddings( - entity=Value(value="entity2", is_uri=False), - vectors=[[0.7, 0.8, 0.9]] + entity=Value(value="http://example.org/entity2", is_uri=True), + vector=[0.4, 0.5, 0.6] ) - message.entities = [entity1, entity2] + entity3 = EntityEmbeddings( + entity=Value(value="entity3", is_uri=False), + vector=[0.7, 0.8, 0.9] + ) + message.entities = [entity1, entity2, entity3] return message @@ -122,27 +126,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor: message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + entity = EntityEmbeddings( entity=Value(value="http://example.org/entity1", is_uri=True), - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] - + # Mock index operations mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index processor.pinecone.has_index.return_value = True - - with patch('uuid.uuid4', side_effect=['id1', 'id2']): + + with patch('uuid.uuid4', side_effect=['id1']): await processor.store_graph_embeddings(message) - + # Verify index name and operations (with dimension suffix) expected_index_name = "t-test_user-test_collection-3" # 3 dimensions processor.pinecone.Index.assert_called_with(expected_index_name) - - # Verify upsert was called for each vector - assert mock_index.upsert.call_count == 2 + + # Verify upsert was called for the single vector + assert mock_index.upsert.call_count == 1 # Check first vector upsert first_call = mock_index.upsert.call_args_list[0] @@ -190,7 +194,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -222,7 +226,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -244,7 +248,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value=None, is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -258,23 +262,27 @@ class TestPineconeGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_different_vector_dimensions(self, processor): - """Test storing graph embeddings with different vector dimensions to same index""" + """Test storing graph embeddings with different vector dimensions""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - entity = EntityEmbeddings( - entity=Value(value="test_entity", is_uri=False), - vectors=[ - [0.1, 0.2], # 2D vector - [0.3, 0.4, 0.5, 0.6], # 4D vector - [0.7, 0.8, 0.9] # 3D vector - ] + # Each entity has a single vector of different dimensions + entity1 = EntityEmbeddings( + entity=Value(value="entity1", is_uri=False), + vector=[0.1, 0.2] # 2D vector ) - message.entities = [entity] + entity2 = EntityEmbeddings( + entity=Value(value="entity2", is_uri=False), + vector=[0.3, 0.4, 0.5, 0.6] # 4D vector + ) + entity3 = EntityEmbeddings( + entity=Value(value="entity3", is_uri=False), + vector=[0.7, 0.8, 0.9] # 3D vector + ) + message.entities = [entity1, entity2, entity3] - # All vectors now use the same index (no dimension in name) mock_index = MagicMock() processor.pinecone.Index.return_value = mock_index processor.pinecone.has_index.return_value = True @@ -322,7 +330,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[] + vector=[] ) message.entities = [entity] @@ -344,7 +352,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] @@ -369,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) message.entities = [entity] diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index 8b1a710a..3541ccd4 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -70,7 +70,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_entity = MagicMock() mock_entity.entity.type = IRI mock_entity.entity.iri = 'test_entity' - mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions + mock_entity.vector = [0.1, 0.2, 0.3] # Single vector with 3 dimensions mock_message.entities = [mock_entity] @@ -124,12 +124,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_entity1 = MagicMock() mock_entity1.entity.type = IRI mock_entity1.entity.iri = 'entity_one' - mock_entity1.vectors = [[0.1, 0.2]] + mock_entity1.vector = [0.1, 0.2] mock_entity2 = MagicMock() mock_entity2.entity.type = IRI mock_entity2.entity.iri = 'entity_two' - mock_entity2.vectors = [[0.3, 0.4]] + mock_entity2.vector = [0.3, 0.4] mock_message.entities = [mock_entity1, mock_entity2] @@ -157,14 +157,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') - async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client): - """Test storing graph embeddings with multiple vectors per entity""" + async def test_store_graph_embeddings_three_entities(self, mock_uuid, mock_qdrant_client): + """Test storing graph embeddings with three entities""" # Arrange mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance mock_uuid.uuid4.return_value.return_value = 'test-uuid' - + config = { 'store_uri': 'http://localhost:6333', 'api_key': 'test-api-key', @@ -177,42 +177,48 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Add collection to known_collections (simulates config push) processor.known_collections[('vector_user', 'vector_collection')] = {} - # Create mock message with entity having multiple vectors + # Create mock message with three entities mock_message = MagicMock() mock_message.metadata.user = 'vector_user' mock_message.metadata.collection = 'vector_collection' - - mock_entity = MagicMock() - mock_entity.entity.type = IRI - mock_entity.entity.iri = 'multi_vector_entity' - mock_entity.vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] - ] - - mock_message.entities = [mock_entity] - + + mock_entity1 = MagicMock() + mock_entity1.entity.type = IRI + mock_entity1.entity.iri = 'entity_one' + mock_entity1.vector = [0.1, 0.2, 0.3] + + mock_entity2 = MagicMock() + mock_entity2.entity.type = IRI + mock_entity2.entity.iri = 'entity_two' + mock_entity2.vector = [0.4, 0.5, 0.6] + + mock_entity3 = MagicMock() + mock_entity3.entity.type = IRI + mock_entity3.entity.iri = 'entity_three' + mock_entity3.vector = [0.7, 0.8, 0.9] + + mock_message.entities = [mock_entity1, mock_entity2, mock_entity3] + # Act await processor.store_graph_embeddings(mock_message) # Assert - # Should be called 3 times (once per vector) + # Should be called 3 times (once per entity) assert mock_qdrant_instance.upsert.call_count == 3 - - # Verify all vectors were processed + + # Verify all entities were processed upsert_calls = mock_qdrant_instance.upsert.call_args_list - - expected_vectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - [0.7, 0.8, 0.9] + + expected_data = [ + ([0.1, 0.2, 0.3], 'entity_one'), + ([0.4, 0.5, 0.6], 'entity_two'), + ([0.7, 0.8, 0.9], 'entity_three') ] - + for i, call in enumerate(upsert_calls): point = call[1]['points'][0] - assert point.vector == expected_vectors[i] - assert point.payload['entity'] == 'multi_vector_entity' + assert point.vector == expected_data[i][0] + assert point.payload['entity'] == expected_data[i][1] @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client): @@ -238,11 +244,11 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_entity_empty = MagicMock() mock_entity_empty.entity.type = LITERAL mock_entity_empty.entity.value = "" # Empty string - mock_entity_empty.vectors = [[0.1, 0.2]] + mock_entity_empty.vector = [0.1, 0.2] mock_entity_none = MagicMock() mock_entity_none.entity = None # None entity - mock_entity_none.vectors = [[0.3, 0.4]] + mock_entity_none.vector = [0.3, 0.4] mock_message.entities = [mock_entity_empty, mock_entity_none] diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index b4c5a5b4..e1c8f3b1 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -197,7 +197,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): index_name='customer_id', index_value=['CUST001'], text='CUST001', - vectors=[[0.1, 0.2, 0.3]] + vector=[0.1, 0.2, 0.3] ) embeddings_msg = RowEmbeddings( @@ -227,8 +227,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') - async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client): - """Test processing embeddings with multiple vectors""" + async def test_on_embeddings_single_vector(self, mock_uuid, mock_qdrant_client): + """Test processing embeddings with a single vector""" from trustgraph.storage.row_embeddings.qdrant.write import Processor from trustgraph.schema import RowEmbeddings, RowIndexEmbedding @@ -250,12 +250,12 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): metadata.collection = 'test_collection' metadata.id = 'doc-123' - # Embedding with multiple vectors + # Embedding with a single 6D vector embedding = RowIndexEmbedding( index_name='name', index_value=['John Doe'], text='John Doe', - vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6] ) embeddings_msg = RowEmbeddings( @@ -269,8 +269,8 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) - # Should be called 3 times (once per vector) - assert mock_qdrant_instance.upsert.call_count == 3 + # Should be called once for the single embedding + assert mock_qdrant_instance.upsert.call_count == 1 @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client): @@ -299,7 +299,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): index_name='id', index_value=['123'], text='123', - vectors=[] # Empty vectors + vector=[] # Empty vector ) embeddings_msg = RowEmbeddings( @@ -342,7 +342,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): index_name='id', index_value=['123'], text='123', - vectors=[[0.1, 0.2]] + vector=[0.1, 0.2] ) embeddings_msg = RowEmbeddings( diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index b4d7aac7..2ff37307 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -612,12 +612,12 @@ class AsyncFlowInstance: print(f"{entity['name']}: {entity['score']}") ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request_data = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -810,12 +810,12 @@ class AsyncFlowInstance: print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})") ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request_data = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 843c5979..99938d5b 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -282,12 +282,12 @@ class AsyncSocketFlowInstance: async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -352,12 +352,12 @@ class AsyncSocketFlowInstance: limit: int = 10, **kwargs ): """Query row embeddings for semantic search on structured data""" - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index f2dad323..49e2f9fa 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -602,13 +602,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] # Query graph embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -648,13 +648,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] # Query document embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -1362,13 +1362,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] # Query row embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index d68d9e98..113ebe35 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -649,12 +649,12 @@ class SocketFlowInstance: ) ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -698,12 +698,12 @@ class SocketFlowInstance: # results contains {"chunk_ids": ["doc1/p0/c0", ...]} ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -936,12 +936,12 @@ class SocketFlowInstance: ) ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index d403ff21..dd985eab 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal logger = logging.getLogger(__name__) class DocumentEmbeddingsClient(RequestResponse): - async def query(self, vectors, limit=20, user="trustgraph", + async def query(self, vector, limit=20, user="trustgraph", collection="default", timeout=30): resp = await self.request( DocumentEmbeddingsRequest( - vectors = vectors, + vector = vector, limit = limit, user = user, collection = collection @@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return resp.chunk_ids + # Return ChunkMatch objects with chunk_id and score + return resp.chunks class DocumentEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index 013847d4..b8979776 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): docs = await self.query_document_embeddings(request) logger.debug("Sending document embeddings query response...") - r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None) + r = DocumentEmbeddingsResponse(chunks=docs, error=None) await flow("response").send(r, properties={"id": id}) logger.debug("Document embeddings query request completed") @@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): type = "document-embeddings-query-error", message = str(e), ), - chunk_ids=[], + chunks=[], ) await flow("response").send(r, properties={"id": id}) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py index 07eb2bc7..fec82378 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -19,12 +19,12 @@ def to_value(x): return Literal(x.value or x.iri) class GraphEmbeddingsClient(RequestResponse): - async def query(self, vectors, limit=20, user="trustgraph", + async def query(self, vector, limit=20, user="trustgraph", collection="default", timeout=30): resp = await self.request( GraphEmbeddingsRequest( - vectors = vectors, + vector = vector, limit = limit, user = user, collection = collection @@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return [ - to_value(v) - for v in resp.entities - ] + # Return EntityMatch objects with entity and score + return resp.entities class GraphEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py index 0141da31..811adf40 100644 --- a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py +++ b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py @@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse class RowEmbeddingsQueryClient(RequestResponse): async def row_embeddings_query( - self, vectors, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, user="trustgraph", collection="default", index_name=None, limit=10, timeout=600 ): request = RowEmbeddingsRequest( - vectors=vectors, + vector=vector, schema_name=schema_name, user=user, collection=collection, diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index 124cf3c8..1ab47aab 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient): ) def request( - self, vectors, user="trustgraph", collection="default", + self, vector, user="trustgraph", collection="default", limit=10, timeout=300 ): return self.call( user=user, collection=collection, - vectors=vectors, limit=limit, timeout=timeout + vector=vector, limit=limit, timeout=timeout ).chunks diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index 1a7a9512..f85c91ee 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient): ) def request( - self, vectors, user="trustgraph", collection="default", + self, vector, user="trustgraph", collection="default", limit=10, timeout=300 ): return self.call( user=user, collection=collection, - vectors=vectors, limit=limit, timeout=timeout + vector=vector, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py index 4f911e3c..19d4b338 100644 --- a/trustgraph-base/trustgraph/clients/row_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient): ) def request( - self, vectors, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, user="trustgraph", collection="default", index_name=None, limit=10, timeout=300 ): kwargs = dict( user=user, collection=collection, - vectors=vectors, schema_name=schema_name, + vector=vector, schema_name=schema_name, limit=limit, timeout=timeout ) if index_name: diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index cc5f1534..f10ca4c6 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -10,18 +10,18 @@ from .primitives import ValueTranslator class DocumentEmbeddingsRequestTranslator(MessageTranslator): """Translator for DocumentEmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: return DocumentEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) - + def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: return { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection @@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): class DocumentEmbeddingsResponseTranslator(MessageTranslator): """Translator for DocumentEmbeddingsResponse schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: result = {} - if obj.chunk_ids is not None: - result["chunk_ids"] = list(obj.chunk_ids) + if obj.chunks is not None: + result["chunks"] = [ + { + "chunk_id": chunk.chunk_id, + "score": chunk.score + } + for chunk in obj.chunks + ] return result - + def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True @@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator): class GraphEmbeddingsRequestTranslator(MessageTranslator): """Translator for GraphEmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: return GraphEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) - + def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: return { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection @@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): class GraphEmbeddingsResponseTranslator(MessageTranslator): """Translator for GraphEmbeddingsResponse schema objects""" - + def __init__(self): self.value_translator = ValueTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: result = {} - + if obj.entities is not None: result["entities"] = [ - self.value_translator.from_pulsar(entity) - for entity in obj.entities + { + "entity": self.value_translator.from_pulsar(match.entity), + "score": match.score + } + for match in obj.entities ] - + return result - + def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True @@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: return RowEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), @@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: result = { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection, diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index c7d5b775..a8bae35c 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -11,7 +11,7 @@ from ..core.topic import topic @dataclass class EntityEmbeddings: entity: Term | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) # Provenance: which chunk this embedding was derived from chunk_id: str = "" @@ -28,7 +28,7 @@ class GraphEmbeddings: @dataclass class ChunkEmbeddings: chunk_id: str = "" - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) # This is a 'batching' mechanism for the above data @dataclass @@ -44,7 +44,7 @@ class DocumentEmbeddings: @dataclass class ObjectEmbeddings: metadata: Metadata | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) name: str = "" key_name: str = "" id: str = "" @@ -56,7 +56,7 @@ class ObjectEmbeddings: @dataclass class StructuredObjectEmbedding: metadata: Metadata | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) schema_name: str = "" object_id: str = "" # Primary key value field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings @@ -72,7 +72,7 @@ class RowIndexEmbedding: index_name: str = "" # The indexed field name(s) index_value: list[str] = field(default_factory=list) # The field value(s) text: str = "" # Text that was embedded - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) @dataclass class RowEmbeddings: diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index a9d19e51..681638c3 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -34,7 +34,7 @@ class EmbeddingsRequest: @dataclass class EmbeddingsResponse: error: Error | None = None - vectors: list[list[list[float]]] = field(default_factory=list) + vectors: list[list[float]] = field(default_factory=list) ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 68857e07..67caa2be 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -9,15 +9,21 @@ from ..core.topic import topic @dataclass class GraphEmbeddingsRequest: - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) limit: int = 0 user: str = "" collection: str = "" +@dataclass +class EntityMatch: + """A matching entity from a semantic search with similarity score""" + entity: Term | None = None + score: float = 0.0 + @dataclass class GraphEmbeddingsResponse: error: Error | None = None - entities: list[Term] = field(default_factory=list) + entities: list[EntityMatch] = field(default_factory=list) ############################################################################ @@ -44,15 +50,21 @@ class TriplesQueryResponse: @dataclass class DocumentEmbeddingsRequest: - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) limit: int = 0 user: str = "" collection: str = "" +@dataclass +class ChunkMatch: + """A matching chunk from a semantic search with similarity score""" + chunk_id: str = "" + score: float = 0.0 + @dataclass class DocumentEmbeddingsResponse: error: Error | None = None - chunk_ids: list[str] = field(default_factory=list) + chunks: list[ChunkMatch] = field(default_factory=list) document_embeddings_request_queue = topic( "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' @@ -76,7 +88,7 @@ class RowIndexMatch: @dataclass class RowEmbeddingsRequest: """Request for row embeddings semantic search""" - vectors: list[list[float]] = field(default_factory=list) # Query vectors + vector: list[float] = field(default_factory=list) # Query vector limit: int = 10 # Max results to return user: str = "" # User/keyspace collection: str = "" # Collection name diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 71fe7409..441c8f38 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl: query_text = arguments.get("query") all_vectors = await embeddings_client.embed([query_text]) - vectors = all_vectors[0] if all_vectors else [] + vector = all_vectors[0] if all_vectors else [] # Now query row embeddings client = self.context("row-embeddings-query-request") @@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl: user = getattr(client, '_current_user', self.user or "trustgraph") matches = await client.row_embeddings_query( - vectors=vectors, + vector=vector, schema_name=self.schema_name, user=user, collection=self.collection or "default", diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index f038a9b5..16ca1ad9 100755 --- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -66,13 +66,13 @@ class Processor(FlowProcessor): ) ) - # vectors[0] is the vector set for the first (only) text - vectors = resp.vectors[0] if resp.vectors else [] + # vectors[0] is the vector for the first (only) text + vector = resp.vectors[0] if resp.vectors else [] embeds = [ ChunkEmbeddings( chunk_id=v.document_id, - vectors=vectors, + vector=vector, ) ] diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index ac2c6f49..1a03ac9f 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -59,11 +59,8 @@ class Processor(EmbeddingsService): # FastEmbed processes the full batch efficiently vecs = list(self.embeddings.embed(texts)) - # Return list of vector sets, one per input text - return [ - [v.tolist()] - for v in vecs - ] + # Return list of vectors, one per input text + return [v.tolist() for v in vecs] @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index e83d608b..3b441bd6 100755 --- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -72,10 +72,10 @@ class Processor(FlowProcessor): entities = [ EntityEmbeddings( entity=entity.entity, - vectors=vectors, # Vector set for this entity + vector=vector, chunk_id=entity.chunk_id, # Provenance: source chunk ) - for entity, vectors in zip(v.entities, all_vectors) + for entity, vector in zip(v.entities, all_vectors) ] # Send in batches to avoid oversized messages diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index c95850e2..a65b4ff7 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -43,11 +43,8 @@ class Processor(EmbeddingsService): input = texts ) - # Return list of vector sets, one per input text - return [ - [embedding] - for embedding in embeds.embeddings - ] + # Return list of vectors, one per input text + return list(embeds.embeddings) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index f81e4374..1365cb14 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): all_vectors = await flow("embeddings-request").embed(texts=texts) # Pair results with metadata - for text, (index_name, index_value), vectors in zip( + for text, (index_name, index_value), vector in zip( texts, metadata, all_vectors ): embeddings_list.append( @@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): index_name=index_name, index_value=index_value, text=text, - vectors=vectors # Vector set for this text + vector=vector ) ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 6d897b71..98350961 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -7,7 +7,7 @@ of chunk_ids import logging from .... direct.milvus_doc_embeddings import DocVectors -from .... schema import DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService @@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] - chunk_ids = [] + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit + ) - for vec in msg.vectors: + chunks = [] + for r in resp: + chunk_id = r["entity"]["chunk_id"] + # Milvus returns distance, convert to similarity score + distance = r.get("distance", 0.0) + score = 1.0 - distance if distance else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) - resp = self.vecstore.search( - vec, - msg.user, - msg.collection, - limit=msg.limit - ) - - for r in resp: - chunk_id = r["entity"]["chunk_id"] - chunk_ids.append(chunk_id) - - return chunk_ids + return chunks except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 41857ab0..406f979c 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -11,6 +11,7 @@ import os from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig +from .... schema import ChunkMatch from .... base import DocumentEmbeddingsQueryService # Module logger @@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] - chunk_ids = [] + dim = len(vec) - for vec in msg.vectors: + # Use dimension suffix in index name + index_name = f"d-{msg.user}-{msg.collection}-{dim}" - dim = len(vec) + # Check if index exists - return empty if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist") + return [] - # Use dimension suffix in index name - index_name = f"d-{msg.user}-{msg.collection}-{dim}" + index = self.pinecone.Index(index_name) - # Check if index exists - skip if not - if not self.pinecone.has_index(index_name): - logger.info(f"Index {index_name} does not exist, skipping this vector") - continue + results = index.query( + vector=vec, + top_k=msg.limit, + include_values=False, + include_metadata=True + ) - index = self.pinecone.Index(index_name) + chunks = [] + for r in results.matches: + chunk_id = r.metadata["chunk_id"] + score = r.score if hasattr(r, 'score') else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) - results = index.query( - vector=vec, - top_k=msg.limit, - include_values=False, - include_metadata=True - ) - - for r in results.matches: - chunk_id = r.metadata["chunk_id"] - chunk_ids.append(chunk_id) - - return chunk_ids + return chunks except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 562023c7..f056b1c1 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -from .... schema import DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService @@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService): try: - chunk_ids = [] + vec = msg.vector + if not vec: + return [] - for vec in msg.vectors: + # Use dimension suffix in collection name + dim = len(vec) + collection = f"d_{msg.user}_{msg.collection}_{dim}" - # Use dimension suffix in collection name - dim = len(vec) - collection = f"d_{msg.user}_{msg.collection}_{dim}" + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist, returning empty results") + return [] - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, returning empty results") - continue + search_result = self.qdrant.query_points( + collection_name=collection, + query=vec, + limit=msg.limit, + with_payload=True, + ).points - search_result = self.qdrant.query_points( - collection_name=collection, - query=vec, - limit=msg.limit, - with_payload=True, - ).points + chunks = [] + for r in search_result: + chunk_id = r.payload["chunk_id"] + score = r.score if hasattr(r, 'score') else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) - for r in search_result: - chunk_id = r.payload["chunk_id"] - chunk_ids.append(chunk_id) - - return chunk_ids + return chunks except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index c5cdb6d8..94eee387 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -7,7 +7,7 @@ entities import logging from .... direct.milvus_graph_embeddings import EntityVectors -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService): try: - entity_set = set() - entities = [] + vec = msg.vector + if not vec: + return [] # Handle zero limit case if msg.limit <= 0: return [] - for vec in msg.vectors: + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit * 2 + ) - resp = self.vecstore.search( - vec, - msg.user, - msg.collection, - limit=msg.limit * 2 - ) + entity_set = set() + entities = [] - for r in resp: - ent = r["entity"]["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) + for r in resp: + ent = r["entity"]["entity"] + # Milvus returns distance, convert to similarity score + distance = r.get("distance", 0.0) + score = 1.0 - distance if distance else 0.0 - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break logger.debug("Send response...") return entities diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 5882f21c..ca443a6f 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -11,7 +11,7 @@ import os from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] + dim = len(vec) + + # Use dimension suffix in index name + index_name = f"t-{msg.user}-{msg.collection}-{dim}" + + # Check if index exists - return empty if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist") + return [] + + index = self.pinecone.Index(index_name) + + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) unique entities + results = index.query( + vector=vec, + top_k=msg.limit * 2, + include_values=False, + include_metadata=True + ) + entity_set = set() entities = [] - for vec in msg.vectors: + for r in results.matches: + ent = r.metadata["entity"] + score = r.score if hasattr(r, 'score') else 0.0 - dim = len(vec) - - # Use dimension suffix in index name - index_name = f"t-{msg.user}-{msg.collection}-{dim}" - - # Check if index exists - skip if not - if not self.pinecone.has_index(index_name): - logger.info(f"Index {index_name} does not exist, skipping this vector") - continue - - index = self.pinecone.Index(index_name) - - # Heuristic hack, get (2*limit), so that we have more chance - # of getting (limit) entities - results = index.query( - vector=vec, - top_k=msg.limit * 2, - include_values=False, - include_metadata=True - ) - - for r in results.matches: - - ent = r.metadata["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) - - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break return entities diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index a76059ef..df93ad8b 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + + # Use dimension suffix in collection name + dim = len(vec) + collection = f"t_{msg.user}_{msg.collection}_{dim}" + + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist") + return [] + + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) unique entities + search_result = self.qdrant.query_points( + collection_name=collection, + query=vec, + limit=msg.limit * 2, + with_payload=True, + ).points + entity_set = set() entities = [] - for vec in msg.vectors: + for r in search_result: + ent = r.payload["entity"] + score = r.score if hasattr(r, 'score') else 0.0 - # Use dimension suffix in collection name - dim = len(vec) - collection = f"t_{msg.user}_{msg.collection}_{dim}" - - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, skipping this vector") - continue - - # Heuristic hack, get (2*limit), so that we have more chance - # of getting (limit) entities - search_result = self.qdrant.query_points( - collection_name=collection, - query=vec, - limit=msg.limit * 2, - with_payload=True, - ).points - - for r in search_result: - ent = r.payload["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) - - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break logger.debug("Send response...") return entities diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 7ed6192f..307899d6 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -93,7 +93,9 @@ class Processor(FlowProcessor): async def query_row_embeddings(self, request: RowEmbeddingsRequest): """Execute row embeddings query""" - matches = [] + vec = request.vector + if not vec: + return [] # Find the collection for this user/collection/schema qdrant_collection = self.find_collection( @@ -105,47 +107,47 @@ class Processor(FlowProcessor): f"No Qdrant collection found for " f"{request.user}/{request.collection}/{request.schema_name}" ) + return [] + + try: + # Build optional filter for index_name + query_filter = None + if request.index_name: + query_filter = Filter( + must=[ + FieldCondition( + key="index_name", + match=MatchValue(value=request.index_name) + ) + ] + ) + + # Query Qdrant + search_result = self.qdrant.query_points( + collection_name=qdrant_collection, + query=vec, + limit=request.limit, + with_payload=True, + query_filter=query_filter, + ).points + + # Convert to RowIndexMatch objects + matches = [] + for point in search_result: + payload = point.payload or {} + match = RowIndexMatch( + index_name=payload.get("index_name", ""), + index_value=payload.get("index_value", []), + text=payload.get("text", ""), + score=point.score if hasattr(point, 'score') else 0.0 + ) + matches.append(match) + return matches - for vec in request.vectors: - try: - # Build optional filter for index_name - query_filter = None - if request.index_name: - query_filter = Filter( - must=[ - FieldCondition( - key="index_name", - match=MatchValue(value=request.index_name) - ) - ] - ) - - # Query Qdrant - search_result = self.qdrant.query_points( - collection_name=qdrant_collection, - query=vec, - limit=request.limit, - with_payload=True, - query_filter=query_filter, - ).points - - # Convert to RowIndexMatch objects - for point in search_result: - payload = point.payload or {} - match = RowIndexMatch( - index_name=payload.get("index_name", ""), - index_value=payload.get("index_value", []), - text=payload.get("text", ""), - score=point.score if hasattr(point, 'score') else 0.0 - ) - matches.append(match) - - except Exception as e: - logger.error(f"Failed to query Qdrant: {e}", exc_info=True) - raise - - return matches + except Exception as e: + logger.error(f"Failed to query Qdrant: {e}", exc_info=True) + raise async def on_message(self, msg, consumer, flow): """Handle incoming query request""" diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 6402010a..5e77f733 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -37,26 +37,26 @@ class Query: vectors = await self.get_vector(query) if self.verbose: - logger.debug("Getting chunk_ids from embeddings store...") + logger.debug("Getting chunks from embeddings store...") - # Get chunk_ids from embeddings store - chunk_ids = await self.rag.doc_embeddings_client.query( - vectors, limit=self.doc_limit, + # Get chunk matches from embeddings store + chunk_matches = await self.rag.doc_embeddings_client.query( + vector=vectors, limit=self.doc_limit, user=self.user, collection=self.collection, ) if self.verbose: - logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...") + logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...") # Fetch chunk content from Garage docs = [] - for chunk_id in chunk_ids: - if chunk_id: + for match in chunk_matches: + if match.chunk_id: try: - content = await self.rag.fetch_chunk(chunk_id, self.user) + content = await self.rag.fetch_chunk(match.chunk_id, self.user) docs.append(content) except Exception as e: - logger.warning(f"Failed to fetch chunk {chunk_id}: {e}") + logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}") if self.verbose: logger.debug("Documents fetched:") diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 21d5aed1..2bf6b2ea 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -87,14 +87,14 @@ class Query: if self.verbose: logger.debug("Getting entities...") - entities = await self.rag.graph_embeddings_client.query( - vectors=vectors, limit=self.entity_limit, + entity_matches = await self.rag.graph_embeddings_client.query( + vector=vectors, limit=self.entity_limit, user=self.user, collection=self.collection, ) entities = [ - str(e) - for e in entities + str(e.entity) + for e in entity_matches ] if self.verbose: diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index a4ff0838..e282f876 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if chunk_id == "": continue - for vec in emb.vectors: + vec = emb.vector + if vec: self.vecstore.insert( vec, chunk_id, message.metadata.user, diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index f6393053..ea091d35 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if chunk_id == "": continue - for vec in emb.vectors: + vec = emb.vector + if not vec: + continue - # Create index name with dimension suffix for lazy creation - dim = len(vec) - index_name = ( - f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" - ) + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) - # Lazily create index if it doesn't exist (but only if authorized in config) - if not self.pinecone.has_index(index_name): - logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") - self.create_index(index_name, dim) + # Lazily create index if it doesn't exist (but only if authorized in config) + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) - index = self.pinecone.Index(index_name) + index = self.pinecone.Index(index_name) - # Generate unique ID for each vector - vector_id = str(uuid.uuid4()) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - records = [ - { - "id": vector_id, - "values": vec, - "metadata": { "chunk_id": chunk_id }, - } - ] + records = [ + { + "id": vector_id, + "values": vec, + "metadata": { "chunk_id": chunk_id }, + } + ] - index.upsert( - vectors = records, - ) + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 21ea9a98..a87f2128 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if chunk_id == "": continue - for vec in emb.vectors: + vec = emb.vector + if not vec: + continue - # Create collection name with dimension suffix for lazy creation - dim = len(vec) - collection = ( - f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" - ) + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) - - self.qdrant.upsert( + # Lazily create collection if it doesn't exist (but only if authorized in config) + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "chunk_id": chunk_id, - } - ) - ] + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) ) + self.qdrant.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "chunk_id": chunk_id, + } + ) + ] + ) + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 8e1c4485..0f27adf9 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): entity_value = get_term_value(entity.entity) if entity_value != "" and entity_value is not None: - for vec in entity.vectors: + vec = entity.vector + if vec: self.vecstore.insert( vec, entity_value, message.metadata.user, diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index f4de7f82..d907e873 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity_value == "" or entity_value is None: continue - for vec in entity.vectors: + vec = entity.vector + if not vec: + continue - # Create index name with dimension suffix for lazy creation - dim = len(vec) - index_name = ( - f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" - ) + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) - # Lazily create index if it doesn't exist (but only if authorized in config) - if not self.pinecone.has_index(index_name): - logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") - self.create_index(index_name, dim) + # Lazily create index if it doesn't exist (but only if authorized in config) + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) - index = self.pinecone.Index(index_name) + index = self.pinecone.Index(index_name) - # Generate unique ID for each vector - vector_id = str(uuid.uuid4()) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - metadata = {"entity": entity_value} - if entity.chunk_id: - metadata["chunk_id"] = entity.chunk_id + metadata = {"entity": entity_value} + if entity.chunk_id: + metadata["chunk_id"] = entity.chunk_id - records = [ - { - "id": vector_id, - "values": vec, - "metadata": metadata, - } - ] + records = [ + { + "id": vector_id, + "values": vec, + "metadata": metadata, + } + ] - index.upsert( - vectors = records, - ) + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 4877ae96..f887d487 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity_value == "" or entity_value is None: continue - for vec in entity.vectors: + vec = entity.vector + if not vec: + continue - # Create collection name with dimension suffix for lazy creation - dim = len(vec) - collection = ( - f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" - ) + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) - - payload = { - "entity": entity_value, - } - if entity.chunk_id: - payload["chunk_id"] = entity.chunk_id - - self.qdrant.upsert( + # Lazily create collection if it doesn't exist (but only if authorized in config) + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload=payload, - ) - ] + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) ) + payload = { + "entity": entity_value, + } + if entity.chunk_id: + payload["chunk_id"] = entity.chunk_id + + self.qdrant.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload=payload, + ) + ] + ) + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 29848c4c..42e59012 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor): qdrant_collection = None for row_emb in embeddings.embeddings: - if not row_emb.vectors: + vector = row_emb.vector + if not vector: logger.warning( - f"No vectors for index {row_emb.index_name} - skipping" + f"No vector for index {row_emb.index_name} - skipping" ) continue - # Use first vector (there may be multiple from different models) - for vector in row_emb.vectors: - dimension = len(vector) + dimension = len(vector) - # Create/get collection name (lazily on first vector) - if qdrant_collection is None: - qdrant_collection = self.get_collection_name( - user, collection, schema_name, dimension - ) - self.ensure_collection(qdrant_collection, dimension) - - # Write to Qdrant - self.qdrant.upsert( - collection_name=qdrant_collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vector, - payload={ - "index_name": row_emb.index_name, - "index_value": row_emb.index_value, - "text": row_emb.text - } - ) - ] + # Create/get collection name (lazily on first vector) + if qdrant_collection is None: + qdrant_collection = self.get_collection_name( + user, collection, schema_name, dimension ) - embeddings_written += 1 + self.ensure_collection(qdrant_collection, dimension) + + # Write to Qdrant + self.qdrant.upsert( + collection_name=qdrant_collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "index_name": row_emb.index_name, + "index_value": row_emb.index_value, + "text": row_emb.text + } + ) + ] + ) + embeddings_written += 1 logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")