Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:30:46 +00:00
parent b7894c7088
commit f261604991
2 changed files with 54 additions and 62 deletions

View file

@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify results are converted to Term objects
# Verify results are converted to EntityMatch objects
assert len(result) == 3
assert isinstance(result[0], Term)
assert result[0].iri == "http://example.com/entity1"
assert result[0].type == IRI
assert isinstance(result[1], Term)
assert result[1].iri == "http://example.com/entity2"
assert result[1].type == IRI
assert isinstance(result[2], Term)
assert result[2].value == "literal entity"
assert result[2].type == LITERAL
assert isinstance(result[0], EntityMatch)
assert result[0].entity.iri == "http://example.com/entity1"
assert result[0].entity.type == IRI
assert isinstance(result[1], EntityMatch)
assert result[1].entity.iri == "http://example.com/entity2"
assert result[1].entity.type == IRI
assert isinstance(result[2], EntityMatch)
assert result[2].entity.value == "literal entity"
assert result[2].entity.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify results are deduplicated and limited
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.iri if r.type == IRI else r.value for r in result]
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
@ -344,18 +344,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify all results are properly typed
assert len(result) == 4
# Check URI entities
uri_results = [r for r in result if r.type == IRI]
uri_results = [r for r in result if r.entity.type == IRI]
assert len(uri_results) == 2
uri_values = [r.iri for r in uri_results]
uri_values = [r.entity.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
literal_results = [r for r in result if not r.type == IRI]
literal_results = [r for r in result if not r.entity.type == IRI]
assert len(literal_results) == 2
literal_values = [r.value for r in literal_results]
literal_values = [r.entity.value for r in literal_results]
assert "literal entity text" in literal_values
assert "another literal" in literal_values
@ -483,6 +483,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify results
assert len(result) == 2
entity_values = [r.iri if r.type == IRI else r.value for r in result]
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values

View file

@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor
from trustgraph.schema import IRI, LITERAL
from trustgraph.schema import IRI, LITERAL, EntityMatch
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
# Verify result contains expected entities
# Verify result contains expected EntityMatch objects
assert len(result) == 2
assert all(hasattr(entity, 'value') for entity in result)
entity_values = [entity.value for entity in result]
assert all(isinstance(entity, EntityMatch) for entity in result)
entity_values = [entity.entity.value for entity in result]
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with multiple vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
# Verify both collections were queried (both 2-dimensional vectors)
# Verify collection was queried
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
assert calls[1][1]['query'] == [0.3, 0.4]
# Verify deduplication - entity2 appears in both results but should only appear once
entity_values = [entity.value for entity in result]
# Verify results with EntityMatch structure
entity_values = [entity.entity.value for entity in result]
assert len(set(entity_values)) == len(entity_values) # All unique
assert 'entity1' in entity_values
assert 'entity2' in entity_values
assert 'entity3' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with different dimension vectors
# Create mock message with single vector
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2] # 2D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_graph_embeddings(mock_message)
# Assert
# Verify query was called twice with different collections
assert mock_qdrant_instance.query_points.call_count == 2
# Verify query was called once
assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
# First call should use 2D collection
# Call should use 2D collection
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
assert calls[0][1]['query'] == [0.1, 0.2]
# Second call should use 3D collection
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
# Verify results
entity_values = [entity.value for entity in result]
# Verify results with EntityMatch structure
entity_values = [entity.entity.value for entity in result]
assert 'entity2d' in entity_values
assert 'entity3d' in entity_values
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'uri_user'
mock_message.collection = 'uri_collection'
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 3
# Check URI entities
uri_entities = [entity for entity in result if entity.type == IRI]
uri_entities = [entity for entity in result if entity.entity.type == IRI]
assert len(uri_entities) == 2
uri_values = [entity.iri for entity in uri_entities]
uri_values = [entity.entity.iri for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities
regular_entities = [entity for entity in result if entity.type == LITERAL]
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity'
assert regular_entities[0].entity.value == 'regular entity'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2]]
mock_message.vector = [0.1, 0.2]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# With zero limit, the logic still adds one entity before checking the limit
# So it returns one result (current behavior, not ideal but actual)
assert len(result) == 1
assert result[0].value == 'entity1'
assert result[0].entity.value == 'entity1'
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')