diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py index c35dde4b..c7d6369a 100644 --- a/tests/contract/test_document_embeddings_contract.py +++ b/tests/contract/test_document_embeddings_contract.py @@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services import pytest from unittest.mock import MagicMock -from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error +from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error from trustgraph.messaging.translators.embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator @@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract: """Test that DocumentEmbeddingsRequest has expected fields""" # Create a request request = DocumentEmbeddingsRequest( - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3], limit=10, user="test_user", collection="test_collection" ) # Verify all expected fields exist - assert hasattr(request, 'vectors') + assert hasattr(request, 'vector') assert hasattr(request, 'limit') assert hasattr(request, 'user') assert hasattr(request, 'collection') # Verify field values - assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + assert request.vector == [0.1, 0.2, 0.3] assert request.limit == 10 assert request.user == "test_user" assert request.collection == "test_collection" @@ -43,7 +43,7 @@ class TestDocumentEmbeddingsRequestContract: translator = DocumentEmbeddingsRequestTranslator() data = { - "vectors": [[0.1, 0.2], [0.3, 0.4]], + "vector": [0.1, 0.2, 0.3, 0.4], "limit": 5, "user": "custom_user", "collection": "custom_collection" @@ -52,7 +52,7 @@ class TestDocumentEmbeddingsRequestContract: result = translator.to_pulsar(data) assert isinstance(result, DocumentEmbeddingsRequest) - assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] + assert result.vector == [0.1, 0.2, 0.3, 0.4] assert result.limit == 5 assert result.user == "custom_user" assert result.collection == "custom_collection" @@ -62,14 +62,14 @@ class TestDocumentEmbeddingsRequestContract: translator = DocumentEmbeddingsRequestTranslator() data = { - "vectors": [[0.1, 0.2]] + "vector": [0.1, 0.2] # No limit, user, or collection provided } result = translator.to_pulsar(data) assert isinstance(result, DocumentEmbeddingsRequest) - assert result.vectors == [[0.1, 0.2]] + assert result.vector == [0.1, 0.2] assert result.limit == 10 # Default assert result.user == "trustgraph" # Default assert result.collection == "default" # Default @@ -79,7 +79,7 @@ class TestDocumentEmbeddingsRequestContract: translator = DocumentEmbeddingsRequestTranslator() request = DocumentEmbeddingsRequest( - vectors=[[0.5, 0.6]], + vector=[0.5, 0.6], limit=20, user="test_user", collection="test_collection" @@ -88,7 +88,7 @@ class TestDocumentEmbeddingsRequestContract: result = translator.from_pulsar(request) assert isinstance(result, dict) - assert result["vectors"] == [[0.5, 0.6]] + assert result["vector"] == [0.5, 0.6] assert result["limit"] == 20 assert result["user"] == "test_user" assert result["collection"] == "test_collection" @@ -99,19 +99,25 @@ class TestDocumentEmbeddingsResponseContract: def test_response_schema_fields(self): """Test that DocumentEmbeddingsResponse has expected fields""" - # Create a response with chunk_ids + # Create a response with chunks response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["chunk1", "chunk2", "chunk3"] + chunks=[ + ChunkMatch(chunk_id="chunk1", score=0.9), + ChunkMatch(chunk_id="chunk2", score=0.8), + ChunkMatch(chunk_id="chunk3", score=0.7) + ] ) # Verify all expected fields exist assert hasattr(response, 'error') - assert hasattr(response, 'chunk_ids') + assert hasattr(response, 'chunks') # Verify field values assert response.error is None - assert response.chunk_ids == ["chunk1", "chunk2", "chunk3"] + assert len(response.chunks) == 3 + assert response.chunks[0].chunk_id == "chunk1" + assert response.chunks[0].score == 0.9 def test_response_schema_with_error(self): """Test response schema with error""" @@ -122,53 +128,59 @@ class TestDocumentEmbeddingsResponseContract: response = DocumentEmbeddingsResponse( error=error, - chunk_ids=[] + chunks=[] ) assert response.error == error - assert response.chunk_ids == [] + assert response.chunks == [] - def test_response_translator_from_pulsar_with_chunk_ids(self): - """Test response translator converts Pulsar schema with chunk_ids to dict""" + def test_response_translator_from_pulsar_with_chunks(self): + """Test response translator converts Pulsar schema with chunks to dict""" translator = DocumentEmbeddingsResponseTranslator() response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["doc1/c1", "doc2/c2", "doc3/c3"] + chunks=[ + ChunkMatch(chunk_id="doc1/c1", score=0.95), + ChunkMatch(chunk_id="doc2/c2", score=0.85), + ChunkMatch(chunk_id="doc3/c3", score=0.75) + ] ) result = translator.from_pulsar(response) assert isinstance(result, dict) - assert "chunk_ids" in result - assert result["chunk_ids"] == ["doc1/c1", "doc2/c2", "doc3/c3"] + assert "chunks" in result + assert len(result["chunks"]) == 3 + assert result["chunks"][0]["chunk_id"] == "doc1/c1" + assert result["chunks"][0]["score"] == 0.95 - def test_response_translator_from_pulsar_with_empty_chunk_ids(self): - """Test response translator handles empty chunk_ids list""" + def test_response_translator_from_pulsar_with_empty_chunks(self): + """Test response translator handles empty chunks list""" translator = DocumentEmbeddingsResponseTranslator() response = DocumentEmbeddingsResponse( error=None, - chunk_ids=[] + chunks=[] ) result = translator.from_pulsar(response) assert isinstance(result, dict) - assert "chunk_ids" in result - assert result["chunk_ids"] == [] + assert "chunks" in result + assert result["chunks"] == [] - def test_response_translator_from_pulsar_with_none_chunk_ids(self): - """Test response translator handles None chunk_ids""" + def test_response_translator_from_pulsar_with_none_chunks(self): + """Test response translator handles None chunks""" translator = DocumentEmbeddingsResponseTranslator() response = MagicMock() - response.chunk_ids = None + response.chunks = None result = translator.from_pulsar(response) assert isinstance(result, dict) - assert "chunk_ids" not in result or result.get("chunk_ids") is None + assert "chunks" not in result or result.get("chunks") is None def test_response_translator_from_response_with_completion(self): """Test response translator with completion flag""" @@ -176,14 +188,18 @@ class TestDocumentEmbeddingsResponseContract: response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["chunk1", "chunk2"] + chunks=[ + ChunkMatch(chunk_id="chunk1", score=0.9), + ChunkMatch(chunk_id="chunk2", score=0.8) + ] ) result, is_final = translator.from_response_with_completion(response) assert isinstance(result, dict) - assert "chunk_ids" in result - assert result["chunk_ids"] == ["chunk1", "chunk2"] + assert "chunks" in result + assert len(result["chunks"]) == 2 + assert result["chunks"][0]["chunk_id"] == "chunk1" assert is_final is True # Document embeddings responses are always final def test_response_translator_to_pulsar_not_implemented(self): @@ -191,7 +207,7 @@ class TestDocumentEmbeddingsResponseContract: translator = DocumentEmbeddingsResponseTranslator() with pytest.raises(NotImplementedError): - translator.to_pulsar({"chunk_ids": ["test"]}) + translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]}) class TestDocumentEmbeddingsMessageCompatibility: @@ -201,7 +217,7 @@ class TestDocumentEmbeddingsMessageCompatibility: """Test complete request-response flow maintains data integrity""" # Create request request_data = { - "vectors": [[0.1, 0.2, 0.3]], + "vector": [0.1, 0.2, 0.3], "limit": 5, "user": "test_user", "collection": "test_collection" @@ -214,7 +230,10 @@ class TestDocumentEmbeddingsMessageCompatibility: # Simulate service processing and creating response response = DocumentEmbeddingsResponse( error=None, - chunk_ids=["doc1/c1", "doc2/c2"] + chunks=[ + ChunkMatch(chunk_id="doc1/c1", score=0.95), + ChunkMatch(chunk_id="doc2/c2", score=0.85) + ] ) # Convert response back to dict @@ -224,8 +243,8 @@ class TestDocumentEmbeddingsMessageCompatibility: # Verify data integrity assert isinstance(pulsar_request, DocumentEmbeddingsRequest) assert isinstance(response_data, dict) - assert "chunk_ids" in response_data - assert len(response_data["chunk_ids"]) == 2 + assert "chunks" in response_data + assert len(response_data["chunks"]) == 2 def test_error_response_flow(self): """Test error response flow""" @@ -237,7 +256,7 @@ class TestDocumentEmbeddingsMessageCompatibility: response = DocumentEmbeddingsResponse( error=error, - chunk_ids=[] + chunks=[] ) # Convert response to dict @@ -246,6 +265,6 @@ class TestDocumentEmbeddingsMessageCompatibility: # Verify error handling assert isinstance(response_data, dict) - # The translator doesn't include error in the dict, only chunk_ids - assert "chunk_ids" in response_data - assert response_data["chunk_ids"] == [] + # The translator doesn't include error in the dict, only chunks + assert "chunks" in response_data + assert response_data["chunks"] == [] diff --git a/tests/contract/test_structured_data_contracts.py b/tests/contract/test_structured_data_contracts.py index 71ccd787..97197f13 100644 --- a/tests/contract/test_structured_data_contracts.py +++ b/tests/contract/test_structured_data_contracts.py @@ -285,11 +285,11 @@ class TestStructuredEmbeddingsContracts: collection="test_collection", metadata=[] ) - + # Act embedding = StructuredObjectEmbedding( metadata=metadata, - vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + vector=[0.1, 0.2, 0.3], schema_name="customer_records", object_id="customer_123", field_embeddings={ @@ -301,7 +301,7 @@ class TestStructuredEmbeddingsContracts: # Assert assert embedding.schema_name == "customer_records" assert embedding.object_id == "customer_123" - assert len(embedding.vectors) == 2 + assert len(embedding.vector) == 3 assert len(embedding.field_embeddings) == 2 assert "name" in embedding.field_embeddings