diff --git a/tests/contract/test_document_embeddings_contract.py b/tests/contract/test_document_embeddings_contract.py new file mode 100644 index 00000000..e0939aaa --- /dev/null +++ b/tests/contract/test_document_embeddings_contract.py @@ -0,0 +1,261 @@ +""" +Contract tests for document embeddings message schemas and translators +Ensures that message formats remain consistent across services +""" + +import pytest +from unittest.mock import MagicMock + +from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error +from trustgraph.messaging.translators.embeddings_query import ( + DocumentEmbeddingsRequestTranslator, + DocumentEmbeddingsResponseTranslator +) + + +class TestDocumentEmbeddingsRequestContract: + """Test DocumentEmbeddingsRequest schema contract""" + + def test_request_schema_fields(self): + """Test that DocumentEmbeddingsRequest has expected fields""" + # Create a request + request = DocumentEmbeddingsRequest( + vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + limit=10, + user="test_user", + collection="test_collection" + ) + + # Verify all expected fields exist + assert hasattr(request, 'vectors') + assert hasattr(request, 'limit') + assert hasattr(request, 'user') + assert hasattr(request, 'collection') + + # Verify field values + assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + assert request.limit == 10 + assert request.user == "test_user" + assert request.collection == "test_collection" + + def test_request_translator_to_pulsar(self): + """Test request translator converts dict to Pulsar schema""" + translator = DocumentEmbeddingsRequestTranslator() + + data = { + "vectors": [[0.1, 0.2], [0.3, 0.4]], + "limit": 5, + "user": "custom_user", + "collection": "custom_collection" + } + + result = translator.to_pulsar(data) + + assert isinstance(result, DocumentEmbeddingsRequest) + assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] + assert result.limit == 5 + assert result.user == "custom_user" + assert result.collection == "custom_collection" + + def test_request_translator_to_pulsar_with_defaults(self): + """Test request translator uses correct defaults""" + translator = DocumentEmbeddingsRequestTranslator() + + data = { + "vectors": [[0.1, 0.2]] + # No limit, user, or collection provided + } + + result = translator.to_pulsar(data) + + assert isinstance(result, DocumentEmbeddingsRequest) + assert result.vectors == [[0.1, 0.2]] + assert result.limit == 10 # Default + assert result.user == "trustgraph" # Default + assert result.collection == "default" # Default + + def test_request_translator_from_pulsar(self): + """Test request translator converts Pulsar schema to dict""" + translator = DocumentEmbeddingsRequestTranslator() + + request = DocumentEmbeddingsRequest( + vectors=[[0.5, 0.6]], + limit=20, + user="test_user", + collection="test_collection" + ) + + result = translator.from_pulsar(request) + + assert isinstance(result, dict) + assert result["vectors"] == [[0.5, 0.6]] + assert result["limit"] == 20 + assert result["user"] == "test_user" + assert result["collection"] == "test_collection" + + +class TestDocumentEmbeddingsResponseContract: + """Test DocumentEmbeddingsResponse schema contract""" + + def test_response_schema_fields(self): + """Test that DocumentEmbeddingsResponse has expected fields""" + # Create a response with chunks + response = DocumentEmbeddingsResponse( + error=None, + chunks=["chunk1", "chunk2", "chunk3"] + ) + + # Verify all expected fields exist + assert hasattr(response, 'error') + assert hasattr(response, 'chunks') + + # Verify field values + assert response.error is None + assert response.chunks == ["chunk1", "chunk2", "chunk3"] + + def test_response_schema_with_error(self): + """Test response schema with error""" + error = Error( + type="query_error", + message="Database connection failed" + ) + + response = DocumentEmbeddingsResponse( + error=error, + chunks=None + ) + + assert response.error == error + assert response.chunks is None + + def test_response_translator_from_pulsar_with_chunks(self): + """Test response translator converts Pulsar schema with chunks to dict""" + translator = DocumentEmbeddingsResponseTranslator() + + response = DocumentEmbeddingsResponse( + error=None, + chunks=["doc1", "doc2", "doc3"] + ) + + result = translator.from_pulsar(response) + + assert isinstance(result, dict) + assert "chunks" in result + assert result["chunks"] == ["doc1", "doc2", "doc3"] + + def test_response_translator_from_pulsar_with_bytes(self): + """Test response translator handles byte chunks correctly""" + translator = DocumentEmbeddingsResponseTranslator() + + response = MagicMock() + response.chunks = [b"byte_chunk1", b"byte_chunk2"] + + result = translator.from_pulsar(response) + + assert isinstance(result, dict) + assert "chunks" in result + assert result["chunks"] == ["byte_chunk1", "byte_chunk2"] + + def test_response_translator_from_pulsar_with_empty_chunks(self): + """Test response translator handles empty chunks list""" + translator = DocumentEmbeddingsResponseTranslator() + + response = MagicMock() + response.chunks = [] + + result = translator.from_pulsar(response) + + assert isinstance(result, dict) + assert "chunks" in result + assert result["chunks"] == [] + + def test_response_translator_from_pulsar_with_none_chunks(self): + """Test response translator handles None chunks""" + translator = DocumentEmbeddingsResponseTranslator() + + response = MagicMock() + response.chunks = None + + result = translator.from_pulsar(response) + + assert isinstance(result, dict) + assert "chunks" not in result or result.get("chunks") is None + + def test_response_translator_from_response_with_completion(self): + """Test response translator with completion flag""" + translator = DocumentEmbeddingsResponseTranslator() + + response = DocumentEmbeddingsResponse( + error=None, + chunks=["chunk1", "chunk2"] + ) + + result, is_final = translator.from_response_with_completion(response) + + assert isinstance(result, dict) + assert "chunks" in result + assert result["chunks"] == ["chunk1", "chunk2"] + assert is_final is True # Document embeddings responses are always final + + def test_response_translator_to_pulsar_not_implemented(self): + """Test that to_pulsar raises NotImplementedError for responses""" + translator = DocumentEmbeddingsResponseTranslator() + + with pytest.raises(NotImplementedError): + translator.to_pulsar({"chunks": ["test"]}) + + +class TestDocumentEmbeddingsMessageCompatibility: + """Test compatibility between request and response messages""" + + def test_request_response_flow(self): + """Test complete request-response flow maintains data integrity""" + # Create request + request_data = { + "vectors": [[0.1, 0.2, 0.3]], + "limit": 5, + "user": "test_user", + "collection": "test_collection" + } + + # Convert to Pulsar request + req_translator = DocumentEmbeddingsRequestTranslator() + pulsar_request = req_translator.to_pulsar(request_data) + + # Simulate service processing and creating response + response = DocumentEmbeddingsResponse( + error=None, + chunks=["relevant chunk 1", "relevant chunk 2"] + ) + + # Convert response back to dict + resp_translator = DocumentEmbeddingsResponseTranslator() + response_data = resp_translator.from_pulsar(response) + + # Verify data integrity + assert isinstance(pulsar_request, DocumentEmbeddingsRequest) + assert isinstance(response_data, dict) + assert "chunks" in response_data + assert len(response_data["chunks"]) == 2 + + def test_error_response_flow(self): + """Test error response flow""" + # Create error response + error = Error( + type="vector_db_error", + message="Collection not found" + ) + + response = DocumentEmbeddingsResponse( + error=error, + chunks=None + ) + + # Convert response to dict + translator = DocumentEmbeddingsResponseTranslator() + response_data = translator.from_pulsar(response) + + # Verify error handling + assert isinstance(response_data, dict) + # The translator doesn't include error in the dict, only chunks + assert "chunks" not in response_data or response_data.get("chunks") is None \ No newline at end of file diff --git a/tests/unit/test_base/test_document_embeddings_client.py b/tests/unit/test_base/test_document_embeddings_client.py new file mode 100644 index 00000000..1c91408d --- /dev/null +++ b/tests/unit/test_base/test_document_embeddings_client.py @@ -0,0 +1,190 @@ +""" +Unit tests for trustgraph.base.document_embeddings_client +Testing async document embeddings client functionality +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.base.document_embeddings_client import DocumentEmbeddingsClient, DocumentEmbeddingsClientSpec +from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error + + +class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase): + """Test async document embeddings client functionality""" + + @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') + async def test_query_success_with_chunks(self, mock_parent_init): + """Test successful query returning chunks""" + # Arrange + mock_parent_init.return_value = None + client = DocumentEmbeddingsClient() + mock_response = MagicMock(spec=DocumentEmbeddingsResponse) + mock_response.error = None + mock_response.chunks = ["chunk1", "chunk2", "chunk3"] + + # Mock the request method + client.request = AsyncMock(return_value=mock_response) + + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + # Act + result = await client.query( + vectors=vectors, + limit=10, + user="test_user", + collection="test_collection", + timeout=30 + ) + + # Assert + assert result == ["chunk1", "chunk2", "chunk3"] + client.request.assert_called_once() + call_args = client.request.call_args[0][0] + assert isinstance(call_args, DocumentEmbeddingsRequest) + assert call_args.vectors == vectors + assert call_args.limit == 10 + assert call_args.user == "test_user" + assert call_args.collection == "test_collection" + + @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') + async def test_query_with_error_raises_exception(self, mock_parent_init): + """Test query raises RuntimeError when response contains error""" + # Arrange + mock_parent_init.return_value = None + client = DocumentEmbeddingsClient() + mock_response = MagicMock(spec=DocumentEmbeddingsResponse) + mock_response.error = MagicMock() + mock_response.error.message = "Database connection failed" + + client.request = AsyncMock(return_value=mock_response) + + # Act & Assert + with pytest.raises(RuntimeError, match="Database connection failed"): + await client.query( + vectors=[[0.1, 0.2, 0.3]], + limit=5 + ) + + @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') + async def test_query_with_empty_chunks(self, mock_parent_init): + """Test query with empty chunks list""" + # Arrange + mock_parent_init.return_value = None + client = DocumentEmbeddingsClient() + mock_response = MagicMock(spec=DocumentEmbeddingsResponse) + mock_response.error = None + mock_response.chunks = [] + + client.request = AsyncMock(return_value=mock_response) + + # Act + result = await client.query(vectors=[[0.1, 0.2, 0.3]]) + + # Assert + assert result == [] + + @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') + async def test_query_with_default_parameters(self, mock_parent_init): + """Test query uses correct default parameters""" + # Arrange + mock_parent_init.return_value = None + client = DocumentEmbeddingsClient() + mock_response = MagicMock(spec=DocumentEmbeddingsResponse) + mock_response.error = None + mock_response.chunks = ["test_chunk"] + + client.request = AsyncMock(return_value=mock_response) + + # Act + result = await client.query(vectors=[[0.1, 0.2, 0.3]]) + + # Assert + client.request.assert_called_once() + call_args = client.request.call_args[0][0] + assert call_args.limit == 20 # Default limit + assert call_args.user == "trustgraph" # Default user + assert call_args.collection == "default" # Default collection + + @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') + async def test_query_with_custom_timeout(self, mock_parent_init): + """Test query passes custom timeout to request""" + # Arrange + mock_parent_init.return_value = None + client = DocumentEmbeddingsClient() + mock_response = MagicMock(spec=DocumentEmbeddingsResponse) + mock_response.error = None + mock_response.chunks = ["chunk1"] + + client.request = AsyncMock(return_value=mock_response) + + # Act + await client.query( + vectors=[[0.1, 0.2, 0.3]], + timeout=60 + ) + + # Assert + assert client.request.call_args[1]["timeout"] == 60 + + @patch('trustgraph.base.request_response_spec.RequestResponse.__init__') + async def test_query_logging(self, mock_parent_init): + """Test query logs response for debugging""" + # Arrange + mock_parent_init.return_value = None + client = DocumentEmbeddingsClient() + mock_response = MagicMock(spec=DocumentEmbeddingsResponse) + mock_response.error = None + mock_response.chunks = ["test_chunk"] + + client.request = AsyncMock(return_value=mock_response) + + # Act + with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger: + result = await client.query(vectors=[[0.1, 0.2, 0.3]]) + + # Assert + mock_logger.debug.assert_called_once() + assert "Document embeddings response" in str(mock_logger.debug.call_args) + assert result == ["test_chunk"] + + +class TestDocumentEmbeddingsClientSpec(IsolatedAsyncioTestCase): + """Test DocumentEmbeddingsClientSpec configuration""" + + def test_spec_initialization(self): + """Test DocumentEmbeddingsClientSpec initialization""" + # Act + spec = DocumentEmbeddingsClientSpec( + request_name="test-request", + response_name="test-response" + ) + + # Assert + assert spec.request_name == "test-request" + assert spec.response_name == "test-response" + assert spec.request_schema == DocumentEmbeddingsRequest + assert spec.response_schema == DocumentEmbeddingsResponse + assert spec.impl == DocumentEmbeddingsClient + + @patch('trustgraph.base.request_response_spec.RequestResponseSpec.__init__') + def test_spec_calls_parent_init(self, mock_parent_init): + """Test spec properly calls parent class initialization""" + # Arrange + mock_parent_init.return_value = None + + # Act + spec = DocumentEmbeddingsClientSpec( + request_name="test-request", + response_name="test-response" + ) + + # Assert + mock_parent_init.assert_called_once_with( + request_name="test-request", + request_schema=DocumentEmbeddingsRequest, + response_name="test-response", + response_schema=DocumentEmbeddingsResponse, + impl=DocumentEmbeddingsClient + ) \ No newline at end of file diff --git a/tests/unit/test_clients/test_sync_document_embeddings_client.py b/tests/unit/test_clients/test_sync_document_embeddings_client.py new file mode 100644 index 00000000..5873d81c --- /dev/null +++ b/tests/unit/test_clients/test_sync_document_embeddings_client.py @@ -0,0 +1,172 @@ +""" +Unit tests for trustgraph.clients.document_embeddings_client +Testing synchronous document embeddings client functionality +""" + +import pytest +from unittest.mock import MagicMock, patch + +from trustgraph.clients.document_embeddings_client import DocumentEmbeddingsClient +from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse + + +class TestSyncDocumentEmbeddingsClient: + """Test synchronous document embeddings client functionality""" + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_client_initialization(self, mock_base_init): + """Test client initialization with correct parameters""" + # Arrange + mock_base_init.return_value = None + + # Act + client = DocumentEmbeddingsClient( + log_level=1, + subscriber="test-subscriber", + input_queue="test-input", + output_queue="test-output", + pulsar_host="pulsar://test:6650", + pulsar_api_key="test-key" + ) + + # Assert + mock_base_init.assert_called_once_with( + log_level=1, + subscriber="test-subscriber", + input_queue="test-input", + output_queue="test-output", + pulsar_host="pulsar://test:6650", + pulsar_api_key="test-key", + input_schema=DocumentEmbeddingsRequest, + output_schema=DocumentEmbeddingsResponse + ) + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_client_initialization_with_defaults(self, mock_base_init): + """Test client initialization uses default queues when not specified""" + # Arrange + mock_base_init.return_value = None + + # Act + client = DocumentEmbeddingsClient() + + # Assert + call_args = mock_base_init.call_args[1] + # Check that default queues are used + assert call_args['input_queue'] is not None + assert call_args['output_queue'] is not None + assert call_args['input_schema'] == DocumentEmbeddingsRequest + assert call_args['output_schema'] == DocumentEmbeddingsResponse + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_request_returns_chunks(self, mock_base_init): + """Test request method returns chunks from response""" + # Arrange + mock_base_init.return_value = None + client = DocumentEmbeddingsClient() + + # Mock the call method to return a response with chunks + mock_response = MagicMock() + mock_response.chunks = ["chunk1", "chunk2", "chunk3"] + client.call = MagicMock(return_value=mock_response) + + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + # Act + result = client.request( + vectors=vectors, + user="test_user", + collection="test_collection", + limit=10, + timeout=300 + ) + + # Assert + assert result == ["chunk1", "chunk2", "chunk3"] + client.call.assert_called_once_with( + user="test_user", + collection="test_collection", + vectors=vectors, + limit=10, + timeout=300 + ) + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_request_with_default_parameters(self, mock_base_init): + """Test request uses correct default parameters""" + # Arrange + mock_base_init.return_value = None + client = DocumentEmbeddingsClient() + + mock_response = MagicMock() + mock_response.chunks = ["test_chunk"] + client.call = MagicMock(return_value=mock_response) + + vectors = [[0.1, 0.2, 0.3]] + + # Act + result = client.request(vectors=vectors) + + # Assert + assert result == ["test_chunk"] + client.call.assert_called_once_with( + user="trustgraph", + collection="default", + vectors=vectors, + limit=10, + timeout=300 + ) + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_request_with_empty_chunks(self, mock_base_init): + """Test request handles empty chunks list""" + # Arrange + mock_base_init.return_value = None + client = DocumentEmbeddingsClient() + + mock_response = MagicMock() + mock_response.chunks = [] + client.call = MagicMock(return_value=mock_response) + + # Act + result = client.request(vectors=[[0.1, 0.2, 0.3]]) + + # Assert + assert result == [] + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_request_with_none_chunks(self, mock_base_init): + """Test request handles None chunks gracefully""" + # Arrange + mock_base_init.return_value = None + client = DocumentEmbeddingsClient() + + mock_response = MagicMock() + mock_response.chunks = None + client.call = MagicMock(return_value=mock_response) + + # Act + result = client.request(vectors=[[0.1, 0.2, 0.3]]) + + # Assert + assert result is None + + @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__') + def test_request_with_custom_timeout(self, mock_base_init): + """Test request passes custom timeout correctly""" + # Arrange + mock_base_init.return_value = None + client = DocumentEmbeddingsClient() + + mock_response = MagicMock() + mock_response.chunks = ["chunk1"] + client.call = MagicMock(return_value=mock_response) + + # Act + client.request( + vectors=[[0.1, 0.2, 0.3]], + timeout=600 + ) + + # Assert + assert client.call.call_args[1]["timeout"] == 600 \ No newline at end of file diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index 80c9d789..e76a6da6 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return resp.documents + return resp.chunks class DocumentEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index b8e7be4c..bca915e0 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): docs = await self.query_document_embeddings(request) logger.debug("Sending document embeddings query response...") - r = DocumentEmbeddingsResponse(documents=docs, error=None) + r = DocumentEmbeddingsResponse(chunks=docs, error=None) await flow("response").send(r, properties={"id": id}) logger.debug("Document embeddings query request completed") @@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): type = "document-embeddings-query-error", message = str(e), ), - response=None, + chunks=None, ) await flow("response").send(r, properties={"id": id}) diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index 14547595..124cf3c8 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -47,5 +47,5 @@ class DocumentEmbeddingsClient(BaseClient): return self.call( user=user, collection=collection, vectors=vectors, limit=limit, timeout=timeout - ).documents + ).chunks diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index d69e7bef..5f310fd0 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -36,10 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator): def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: result = {} - if obj.documents: - result["documents"] = [ - doc.decode("utf-8") if isinstance(doc, bytes) else doc - for doc in obj.documents + if obj.chunks is not None: + result["chunks"] = [ + chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + for chunk in obj.chunks ] return result diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 214a1d4b..91231ade 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -45,4 +45,11 @@ class DocumentEmbeddingsRequest(Record): class DocumentEmbeddingsResponse(Record): error = Error() - chunks = Array(String()) \ No newline at end of file + chunks = Array(String()) + +document_embeddings_request_queue = topic( + "non-persistent://trustgraph/document-embeddings-request" +) +document_embeddings_response_queue = topic( + "non-persistent://trustgraph/document-embeddings-response" +) \ No newline at end of file