Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:17:48 +00:00
parent df12467510
commit 7bcced37f2
4 changed files with 107 additions and 127 deletions

View file

@@ -131,7 +131,7 @@ class TestGraphRagIntegration:
# 2. Should query graph embeddings to find relevant entities # 2. Should query graph embeddings to find relevant entities
mock_graph_embeddings_client.query.assert_called_once() mock_graph_embeddings_client.query.assert_called_once()
call_args = mock_graph_embeddings_client.query.call_args call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vector'] == [0.1, 0.2, 0.3, 0.4, 0.5] assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection assert call_args.kwargs['collection'] == collection

View file

@@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = ["test_chunk"] mock_response.chunks = ["test_chunk"]
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3]] vector = [0.1, 0.2, 0.3]
# Act # Act
result = client.request(vectors=vectors) result = client.request(vector=vector)
# Assert # Assert
assert result == ["test_chunk"] assert result == ["test_chunk"]
client.call.assert_called_once_with( client.call.assert_called_once_with(
user="trustgraph", user="trustgraph",
collection="default", collection="default",
vectors=vectors, vector=vector,
limit=10, limit=10,
timeout=300 timeout=300
) )
@@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = [] mock_response.chunks = []
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
# Act # Act
result = client.request(vectors=[[0.1, 0.2, 0.3]]) result = client.request(vector=[0.1, 0.2, 0.3])
# Assert # Assert
assert result == [] assert result == []
@@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = None mock_response.chunks = None
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
# Act # Act
result = client.request(vectors=[[0.1, 0.2, 0.3]]) result = client.request(vector=[0.1, 0.2, 0.3])
# Assert # Assert
assert result is None assert result is None
@@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.chunks = ["chunk1"] mock_response.chunks = ["chunk1"]
client.call = MagicMock(return_value=mock_response) client.call = MagicMock(return_value=mock_response)
# Act # Act
client.request( client.request(
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
timeout=600 timeout=600
) )
# Assert # Assert
assert client.call.call_args[1]["timeout"] == 600 assert client.call.call_args[1]["timeout"] == 600

View file

@@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10 limit=10
) )
return query return query
@@ -71,7 +71,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -97,43 +97,35 @@ class TestMilvusDocEmbeddingsQueryProcessor:
assert result[2] == "Third document chunk" assert result[2] == "Third document chunk"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor): async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with multiple vectors""" """Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3 limit=3
) )
# Mock search results - different results for each vector # Mock search results
mock_results_1 = [ mock_results = [
{"entity": {"chunk_id": "Document from first vector"}}, {"entity": {"chunk_id": "First document"}},
{"entity": {"chunk_id": "Another doc from first vector"}}, {"entity": {"chunk_id": "Second document"}},
{"entity": {"chunk_id": "Third document"}},
] ]
mock_results_2 = [ processor.vecstore.search.return_value = mock_results
{"entity": {"chunk_id": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters including user/collection # Verify search was called once with the full vector
expected_calls = [ processor.vecstore.search.assert_called_once_with(
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}), )
]
assert processor.vecstore.search.call_count == 2 # Verify results
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results from all vectors are combined
assert len(result) == 3 assert len(result) == 3
assert "Document from first vector" in result assert "First document" in result
assert "Another doc from first vector" in result assert "Second document" in result
assert "Document from second vector" in result assert "Third document" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor): async def test_query_document_embeddings_with_limit(self, processor):
@@ -141,7 +133,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=2 limit=2
) )
@@ -170,7 +162,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[], vector=[],
limit=5 limit=5
) )
@@ -188,7 +180,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -211,7 +203,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -237,7 +229,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -262,7 +254,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -288,7 +280,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=0 limit=0
) )
@@ -306,7 +298,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=-1 limit=-1
) )
@@ -324,7 +316,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -341,59 +333,51 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[ vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
limit=5 limit=5
) )
# Mock search results for each vector # Mock search results
mock_results_1 = [{"entity": {"chunk_id": "Document from 2D vector"}}] mock_results = [
mock_results_2 = [{"entity": {"chunk_id": "Document from 4D vector"}}] {"entity": {"chunk_id": "Document 1"}},
mock_results_3 = [{"entity": {"chunk_id": "Document from 3D vector"}}] {"entity": {"chunk_id": "Document 2"}},
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] ]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Verify all vectors were searched # Verify search was called with the vector
assert processor.vecstore.search.call_count == 3 processor.vecstore.search.assert_called_once()
# Verify results from all dimensions # Verify results
assert len(result) == 3 assert len(result) == 2
assert "Document from 2D vector" in result assert "Document 1" in result
assert "Document from 4D vector" in result assert "Document 2" in result
assert "Document from 3D vector" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_document_embeddings_duplicate_documents(self, processor): async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with duplicate documents in results""" """Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest( query = DocumentEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5 limit=5
) )
# Mock search results with duplicates across vectors # Mock search results with multiple documents
mock_results_1 = [ mock_results = [
{"entity": {"chunk_id": "Document A"}}, {"entity": {"chunk_id": "Document A"}},
{"entity": {"chunk_id": "Document B"}}, {"entity": {"chunk_id": "Document B"}},
]
mock_results_2 = [
{"entity": {"chunk_id": "Document B"}}, # Duplicate
{"entity": {"chunk_id": "Document C"}}, {"entity": {"chunk_id": "Document C"}},
] ]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2] processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query) result = await processor.query_document_embeddings(query)
# Note: Unlike graph embeddings, doc embeddings don't deduplicate # Verify results
# This preserves ranking and allows multiple occurrences assert len(result) == 3
assert len(result) == 4
assert result.count("Document B") == 2 # Should appear twice
assert "Document A" in result assert "Document A" in result
assert "Document B" in result
assert "Document C" in result assert "Document C" in result
def test_add_args_method(self): def test_add_args_method(self):

View file

@@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10 limit=10
) )
return query return query
@@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -156,7 +156,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3 limit=3
) )
@@ -197,7 +197,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=2 limit=2
) )
@@ -226,7 +226,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5 limit=5
) )
@@ -258,7 +258,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=2 limit=2
) )
@@ -286,7 +286,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[], vector=[],
limit=5 limit=5
) )
@@ -304,7 +304,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -327,7 +327,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -365,7 +365,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=5 limit=5
) )
@@ -447,7 +447,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[[0.1, 0.2, 0.3]], vector=[0.1, 0.2, 0.3],
limit=0 limit=0
) )
@@ -460,33 +460,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 0 assert len(result) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor): async def test_query_graph_embeddings_longer_vector(self, processor):
"""Test querying graph embeddings with different vector dimensions""" """Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest( query = GraphEmbeddingsRequest(
user='test_user', user='test_user',
collection='test_collection', collection='test_collection',
vectors=[ vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
limit=5 limit=5
) )
# Mock search results for each vector # Mock search results
mock_results_1 = [{"entity": {"entity": "entity_2d"}}] mock_results = [
mock_results_2 = [{"entity": {"entity": "entity_4d"}}] {"entity": {"entity": "http://example.com/entity1"}},
mock_results_3 = [{"entity": {"entity": "entity_3d"}}] {"entity": {"entity": "http://example.com/entity2"}},
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3] ]
processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query) result = await processor.query_graph_embeddings(query)
# Verify all vectors were searched # Verify search was called once with the full vector
assert processor.vecstore.search.call_count == 3 processor.vecstore.search.assert_called_once()
# Verify results from all dimensions # Verify results
assert len(result) == 3 assert len(result) == 2
entity_values = [r.iri if r.type == IRI else r.value for r in result] entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "entity_2d" in entity_values assert "http://example.com/entity1" in entity_values
assert "entity_4d" in entity_values assert "http://example.com/entity2" in entity_values
assert "entity_3d" in entity_values