mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-03 12:22:37 +02:00
Fixing tests
This commit is contained in:
parent
b7894c7088
commit
f261604991
2 changed files with 54 additions and 62 deletions
|
|
@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are converted to Term objects
|
||||
# Verify results are converted to EntityMatch objects
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], Term)
|
||||
assert result[0].iri == "http://example.com/entity1"
|
||||
assert result[0].type == IRI
|
||||
assert isinstance(result[1], Term)
|
||||
assert result[1].iri == "http://example.com/entity2"
|
||||
assert result[1].type == IRI
|
||||
assert isinstance(result[2], Term)
|
||||
assert result[2].value == "literal entity"
|
||||
assert result[2].type == LITERAL
|
||||
assert isinstance(result[0], EntityMatch)
|
||||
assert result[0].entity.iri == "http://example.com/entity1"
|
||||
assert result[0].entity.type == IRI
|
||||
assert isinstance(result[1], EntityMatch)
|
||||
assert result[1].entity.iri == "http://example.com/entity2"
|
||||
assert result[1].entity.type == IRI
|
||||
assert isinstance(result[2], EntityMatch)
|
||||
assert result[2].entity.value == "literal entity"
|
||||
assert result[2].entity.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
||||
|
|
@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify results are deduplicated and limited
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
|
@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify duplicates are removed
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
|
|
@ -344,18 +344,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
||||
|
||||
# Check URI entities
|
||||
uri_results = [r for r in result if r.type == IRI]
|
||||
uri_results = [r for r in result if r.entity.type == IRI]
|
||||
assert len(uri_results) == 2
|
||||
uri_values = [r.iri for r in uri_results]
|
||||
uri_values = [r.entity.iri for r in uri_results]
|
||||
assert "http://example.com/uri_entity" in uri_values
|
||||
assert "https://example.com/another_uri" in uri_values
|
||||
|
||||
|
||||
# Check literal entities
|
||||
literal_results = [r for r in result if not r.type == IRI]
|
||||
literal_results = [r for r in result if not r.entity.type == IRI]
|
||||
assert len(literal_results) == 2
|
||||
literal_values = [r.value for r in literal_results]
|
||||
literal_values = [r.entity.value for r in literal_results]
|
||||
assert "literal entity text" in literal_values
|
||||
assert "another literal" in literal_values
|
||||
|
||||
|
|
@ -483,6 +483,6 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
|
|
@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
from trustgraph.schema import IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected entities
|
||||
# Verify result contains expected EntityMatch objects
|
||||
assert len(result) == 2
|
||||
assert all(hasattr(entity, 'value') for entity in result)
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert all(isinstance(entity, EntityMatch) for entity in result)
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
||||
|
|
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
# Verify collection was queried
|
||||
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
||||
entity_values = [entity.value for entity in result]
|
||||
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
|
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
|
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.vector = [0.1, 0.2] # 2D vector
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
# Call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
entity_values = [entity.value for entity in result]
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert 'entity2d' in entity_values
|
||||
assert 'entity3d' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'uri_user'
|
||||
mock_message.collection = 'uri_collection'
|
||||
|
|
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if entity.type == IRI]
|
||||
uri_entities = [entity for entity in result if entity.entity.type == IRI]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.iri for entity in uri_entities]
|
||||
uri_values = [entity.entity.iri for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if entity.type == LITERAL]
|
||||
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
assert regular_entities[0].entity.value == 'regular entity'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
|
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
|
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# With zero limit, the logic still adds one entity before checking the limit
|
||||
# So it returns one result (current behavior, not ideal but actual)
|
||||
assert len(result) == 1
|
||||
assert result[0].value == 'entity1'
|
||||
assert result[0].entity.value == 'entity1'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue