diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index c27895fb..458d613d 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor: [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 ) - # Verify results are converted to Term objects + # Verify results are converted to EntityMatch objects assert len(result) == 3 - assert isinstance(result[0], Term) - assert result[0].iri == "http://example.com/entity1" - assert result[0].type == IRI - assert isinstance(result[1], Term) - assert result[1].iri == "http://example.com/entity2" - assert result[1].type == IRI - assert isinstance(result[2], Term) - assert result[2].value == "literal entity" - assert result[2].type == LITERAL + assert isinstance(result[0], EntityMatch) + assert result[0].entity.iri == "http://example.com/entity1" + assert result[0].entity.type == IRI + assert isinstance(result[1], EntityMatch) + assert result[1].entity.iri == "http://example.com/entity2" + assert result[1].entity.type == IRI + assert isinstance(result[2], EntityMatch) + assert result[2].entity.value == "literal entity" + assert result[2].entity.type == LITERAL @pytest.mark.asyncio async def test_query_graph_embeddings_multiple_vectors(self, processor): @@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify results are deduplicated and limited assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] assert "http://example.com/entity1" in entity_values assert "http://example.com/entity2" in entity_values assert "http://example.com/entity3" in entity_values @@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify duplicates are removed assert len(result) == 3 - entity_values = [r.iri if r.type == IRI else r.value for r in result] + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] assert len(set(entity_values)) == 3 # All unique assert "http://example.com/entity1" in entity_values assert "http://example.com/entity2" in entity_values @@ -344,18 +344,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify all results are properly typed assert len(result) == 4 - + # Check URI entities - uri_results = [r for r in result if r.type == IRI] + uri_results = [r for r in result if r.entity.type == IRI] assert len(uri_results) == 2 - uri_values = [r.iri for r in uri_results] + uri_values = [r.entity.iri for r in uri_results] assert "http://example.com/uri_entity" in uri_values assert "https://example.com/another_uri" in uri_values - + # Check literal entities - literal_results = [r for r in result if not r.type == IRI] + literal_results = [r for r in result if not r.entity.type == IRI] assert len(literal_results) == 2 - literal_values = [r.value for r in literal_results] + literal_values = [r.entity.value for r in literal_results] assert "literal entity text" in literal_values assert "another literal" in literal_values @@ -483,6 +483,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor: # Verify results assert len(result) == 2 - entity_values = [r.iri if r.type == IRI else r.value for r in result] + entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result] assert "http://example.com/entity1" in entity_values assert "http://example.com/entity2" in entity_values \ No newline at end of file diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 1760c4c1..9362a8dd 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase # Import the service under test from trustgraph.query.graph_embeddings.qdrant.service import Processor -from trustgraph.schema import IRI, LITERAL +from trustgraph.schema import IRI, LITERAL, EntityMatch class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): @@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 5 mock_message.user = 'test_user' mock_message.collection = 'test_collection' @@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): with_payload=True ) - # Verify result contains expected entities + # Verify result contains expected EntityMatch objects assert len(result) == 2 - assert all(hasattr(entity, 'value') for entity in result) - entity_values = [entity.value for entity in result] + assert all(isinstance(entity, EntityMatch) for entity in result) + entity_values = [entity.entity.value for entity in result] assert 'entity1' in entity_values assert 'entity2' in entity_values @@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): } processor = Processor(**config) - - # Create mock message with multiple vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 3 mock_message.user = 'multi_user' mock_message.collection = 'multi_collection' - + # Act result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called twice - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 - # Verify both collections were queried (both 2-dimensional vectors) + # Verify collection was queried expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection - assert calls[1][1]['collection_name'] == expected_collection assert calls[0][1]['query'] == [0.1, 0.2] - assert calls[1][1]['query'] == [0.3, 0.4] - - # Verify deduplication - entity2 appears in both results but should only appear once - entity_values = [entity.value for entity in result] + + # Verify results with EntityMatch structure + entity_values = [entity.entity.value for entity in result] assert len(set(entity_values)) == len(entity_values) # All unique assert 'entity1' in entity_values assert 'entity2' in entity_values - assert 'entity3' in entity_values @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2, 0.3]] + mock_message.vector = [0.1, 0.2, 0.3] mock_message.limit = 3 # Should only return 3 results mock_message.user = 'limit_user' mock_message.collection = 'limit_collection' @@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'empty_user' mock_message.collection = 'empty_collection' @@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): } processor = Processor(**config) - - # Create mock message with different dimension vectors + + # Create mock message with single vector mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D + mock_message.vector = [0.1, 0.2] # 2D vector mock_message.limit = 5 mock_message.user = 'dim_user' mock_message.collection = 'dim_collection' - + # Act result = await processor.query_graph_embeddings(mock_message) # Assert - # Verify query was called twice with different collections - assert mock_qdrant_instance.query_points.call_count == 2 + # Verify query was called once + assert mock_qdrant_instance.query_points.call_count == 1 calls = mock_qdrant_instance.query_points.call_args_list - # First call should use 2D collection + # Call should use 2D collection assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions assert calls[0][1]['query'] == [0.1, 0.2] - # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions - assert calls[1][1]['query'] == [0.3, 0.4, 0.5] - - # Verify results - entity_values = [entity.value for entity in result] + # Verify results with EntityMatch structure + entity_values = [entity.entity.value for entity in result] assert 'entity2d' in entity_values - assert 'entity3d' in entity_values @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'uri_user' mock_message.collection = 'uri_collection' @@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert assert len(result) == 3 - + # Check URI entities - uri_entities = [entity for entity in result if entity.type == IRI] + uri_entities = [entity for entity in result if entity.entity.type == IRI] assert len(uri_entities) == 2 - uri_values = [entity.iri for entity in uri_entities] + uri_values = [entity.entity.iri for entity in uri_entities] assert 'http://example.com/entity1' in uri_values assert 'https://secure.example.com/entity2' in uri_values # Check regular entities - regular_entities = [entity for entity in result if entity.type == LITERAL] + regular_entities = [entity for entity in result if entity.entity.type == LITERAL] assert len(regular_entities) == 1 - assert regular_entities[0].value == 'regular entity' + assert regular_entities[0].entity.value == 'regular entity' @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__') @@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 5 mock_message.user = 'error_user' mock_message.collection = 'error_collection' @@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Create mock message with zero limit mock_message = MagicMock() - mock_message.vectors = [[0.1, 0.2]] + mock_message.vector = [0.1, 0.2] mock_message.limit = 0 mock_message.user = 'zero_user' mock_message.collection = 'zero_collection' @@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # With zero limit, the logic still adds one entity before checking the limit # So it returns one result (current behavior, not ideal but actual) assert len(result) == 1 - assert result[0].value == 'entity1' + assert result[0].entity.value == 'entity1' @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient') @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')