Fixing tests

This commit is contained in:
Cyber MacGeddon 2026-03-09 10:05:39 +00:00
parent dcee1b8de2
commit 4e3db11323
2 changed files with 64 additions and 45 deletions

View file

@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
from trustgraph.messaging.translators.embeddings_query import ( from trustgraph.messaging.translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsRequestTranslator,
DocumentEmbeddingsResponseTranslator DocumentEmbeddingsResponseTranslator
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
"""Test that DocumentEmbeddingsRequest has expected fields""" """Test that DocumentEmbeddingsRequest has expected fields"""
# Create a request # Create a request
request = DocumentEmbeddingsRequest( request = DocumentEmbeddingsRequest(
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3],
limit=10, limit=10,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
) )
# Verify all expected fields exist # Verify all expected fields exist
assert hasattr(request, 'vectors') assert hasattr(request, 'vector')
assert hasattr(request, 'limit') assert hasattr(request, 'limit')
assert hasattr(request, 'user') assert hasattr(request, 'user')
assert hasattr(request, 'collection') assert hasattr(request, 'collection')
# Verify field values # Verify field values
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10 assert request.limit == 10
assert request.user == "test_user" assert request.user == "test_user"
assert request.collection == "test_collection" assert request.collection == "test_collection"
@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator() translator = DocumentEmbeddingsRequestTranslator()
data = { data = {
"vectors": [[0.1, 0.2], [0.3, 0.4]], "vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5, "limit": 5,
"user": "custom_user", "user": "custom_user",
"collection": "custom_collection" "collection": "custom_collection"
@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.to_pulsar(data) result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest) assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5 assert result.limit == 5
assert result.user == "custom_user" assert result.user == "custom_user"
assert result.collection == "custom_collection" assert result.collection == "custom_collection"
@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator() translator = DocumentEmbeddingsRequestTranslator()
data = { data = {
"vectors": [[0.1, 0.2]] "vector": [0.1, 0.2]
# No limit, user, or collection provided # No limit, user, or collection provided
} }
result = translator.to_pulsar(data) result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest) assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2]] assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default assert result.limit == 10 # Default
assert result.user == "trustgraph" # Default assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default assert result.collection == "default" # Default
@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract:
translator = DocumentEmbeddingsRequestTranslator() translator = DocumentEmbeddingsRequestTranslator()
request = DocumentEmbeddingsRequest( request = DocumentEmbeddingsRequest(
vectors=[[0.5, 0.6]], vector=[0.5, 0.6],
limit=20, limit=20,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"
@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract:
result = translator.from_pulsar(request) result = translator.from_pulsar(request)
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["vectors"] == [[0.5, 0.6]] assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20 assert result["limit"] == 20
assert result["user"] == "test_user" assert result["user"] == "test_user"
assert result["collection"] == "test_collection" assert result["collection"] == "test_collection"
@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_schema_fields(self): def test_response_schema_fields(self):
"""Test that DocumentEmbeddingsResponse has expected fields""" """Test that DocumentEmbeddingsResponse has expected fields"""
# Create a response with chunk_ids # Create a response with chunks
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["chunk1", "chunk2", "chunk3"] chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8),
ChunkMatch(chunk_id="chunk3", score=0.7)
]
) )
# Verify all expected fields exist # Verify all expected fields exist
assert hasattr(response, 'error') assert hasattr(response, 'error')
assert hasattr(response, 'chunk_ids') assert hasattr(response, 'chunks')
# Verify field values # Verify field values
assert response.error is None assert response.error is None
assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"] assert len(response.chunks) == 3
assert response.chunks[0].chunk_id == "chunk1"
assert response.chunks[0].score == 0.9
def test_response_schema_with_error(self): def test_response_schema_with_error(self):
"""Test response schema with error""" """Test response schema with error"""
@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=error, error=error,
chunk_ids=[] chunks=[]
) )
assert response.error == error assert response.error == error
assert response.chunk_ids == [] assert response.chunks == []
def test_response_translator_from_pulsar_with_chunk_ids(self): def test_response_translator_from_pulsar_with_chunks(self):
"""Test response translator converts Pulsar schema with chunk_ids to dict""" """Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"] chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85),
ChunkMatch(chunk_id="doc3/c3", score=0.75)
]
) )
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" in result assert "chunks" in result
assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"] assert len(result["chunks"]) == 3
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunk_ids(self): def test_response_translator_from_pulsar_with_empty_chunks(self):
"""Test response translator handles empty chunk_ids list""" """Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=[] chunks=[]
) )
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" in result assert "chunks" in result
assert result["chunk_ids"] == [] assert result["chunks"] == []
def test_response_translator_from_pulsar_with_none_chunk_ids(self): def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunk_ids""" """Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock() response = MagicMock()
response.chunk_ids = None response.chunks = None
result = translator.from_pulsar(response) result = translator.from_pulsar(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" not in result or result.get("chunk_ids") is None assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self): def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag""" """Test response translator with completion flag"""
@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["chunk1", "chunk2"] chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8)
]
) )
result, is_final = translator.from_response_with_completion(response) result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict) assert isinstance(result, dict)
assert "chunk_ids" in result assert "chunks" in result
assert result["chunk_ids"] == ["chunk1", "chunk2"] assert len(result["chunks"]) == 2
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self): def test_response_translator_to_pulsar_not_implemented(self):
@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract:
translator = DocumentEmbeddingsResponseTranslator() translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunk_ids": ["test"]}) translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility: class TestDocumentEmbeddingsMessageCompatibility:
@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
"""Test complete request-response flow maintains data integrity""" """Test complete request-response flow maintains data integrity"""
# Create request # Create request
request_data = { request_data = {
"vectors": [[0.1, 0.2, 0.3]], "vector": [0.1, 0.2, 0.3],
"limit": 5, "limit": 5,
"user": "test_user", "user": "test_user",
"collection": "test_collection" "collection": "test_collection"
@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Simulate service processing and creating response # Simulate service processing and creating response
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=None, error=None,
chunk_ids=["doc1/c1", "doc2/c2"] chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85)
]
) )
# Convert response back to dict # Convert response back to dict
@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify data integrity # Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict) assert isinstance(response_data, dict)
assert "chunk_ids" in response_data assert "chunks" in response_data
assert len(response_data["chunk_ids"]) == 2 assert len(response_data["chunks"]) == 2
def test_error_response_flow(self): def test_error_response_flow(self):
"""Test error response flow""" """Test error response flow"""
@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility:
response = DocumentEmbeddingsResponse( response = DocumentEmbeddingsResponse(
error=error, error=error,
chunk_ids=[] chunks=[]
) )
# Convert response to dict # Convert response to dict
@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility:
# Verify error handling # Verify error handling
assert isinstance(response_data, dict) assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunk_ids # The translator doesn't include error in the dict, only chunks
assert "chunk_ids" in response_data assert "chunks" in response_data
assert response_data["chunk_ids"] == [] assert response_data["chunks"] == []

View file

@ -289,7 +289,7 @@ class TestStructuredEmbeddingsContracts:
# Act # Act
embedding = StructuredObjectEmbedding( embedding = StructuredObjectEmbedding(
metadata=metadata, metadata=metadata,
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], vector=[0.1, 0.2, 0.3],
schema_name="customer_records", schema_name="customer_records",
object_id="customer_123", object_id="customer_123",
field_embeddings={ field_embeddings={
@ -301,7 +301,7 @@ class TestStructuredEmbeddingsContracts:
# Assert # Assert
assert embedding.schema_name == "customer_records" assert embedding.schema_name == "customer_records"
assert embedding.object_id == "customer_123" assert embedding.object_id == "customer_123"
assert len(embedding.vectors) == 2 assert len(embedding.vector) == 3
assert len(embedding.field_embeddings) == 2 assert len(embedding.field_embeddings) == 2
assert "name" in embedding.field_embeddings assert "name" in embedding.field_embeddings