trustgraph-base .chunks / .documents confusion in the API (#485)

Back-ported to 1.2 from 1.3 (#481)

* trustgraph-base .chunks / .documents confusion in the API

* Added tests, fixed test failures in code

* Fix file dup error

* Fix contract error
This commit is contained in:
cybermaggedon 2025-09-03 16:50:54 +01:00 committed by GitHub
parent 056c702515
commit 0ae39aef7f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 639 additions and 9 deletions

View file

@ -0,0 +1,261 @@
"""
Contract tests for document embeddings message schemas and translators
Ensures that message formats remain consistent across services
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
from trustgraph.messaging.translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator,
DocumentEmbeddingsResponseTranslator
)
class TestDocumentEmbeddingsRequestContract:
    """Pin the DocumentEmbeddingsRequest schema and its dict <-> Pulsar translator."""

    def test_request_schema_fields(self):
        """A fully populated request exposes all four contract fields."""
        request = DocumentEmbeddingsRequest(
            vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            limit=10,
            user="test_user",
            collection="test_collection",
        )

        # Presence first, so a renamed field fails before any value check.
        for field_name in ('vectors', 'limit', 'user', 'collection'):
            assert hasattr(request, field_name)

        # Values must come back exactly as supplied.
        assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        assert request.limit == 10
        assert request.user == "test_user"
        assert request.collection == "test_collection"

    def test_request_translator_to_pulsar(self):
        """to_pulsar turns a complete dict into an equivalent schema object."""
        payload = {
            "vectors": [[0.1, 0.2], [0.3, 0.4]],
            "limit": 5,
            "user": "custom_user",
            "collection": "custom_collection",
        }

        message = DocumentEmbeddingsRequestTranslator().to_pulsar(payload)

        assert isinstance(message, DocumentEmbeddingsRequest)
        assert message.vectors == [[0.1, 0.2], [0.3, 0.4]]
        assert message.limit == 5
        assert message.user == "custom_user"
        assert message.collection == "custom_collection"

    def test_request_translator_to_pulsar_with_defaults(self):
        """to_pulsar fills in limit, user and collection when they are omitted."""
        # Only the vectors are supplied; everything else must be defaulted.
        payload = {"vectors": [[0.1, 0.2]]}

        message = DocumentEmbeddingsRequestTranslator().to_pulsar(payload)

        assert isinstance(message, DocumentEmbeddingsRequest)
        assert message.vectors == [[0.1, 0.2]]
        # Expected translator defaults.
        assert message.limit == 10
        assert message.user == "trustgraph"
        assert message.collection == "default"

    def test_request_translator_from_pulsar(self):
        """from_pulsar turns a schema object back into a plain dict."""
        message = DocumentEmbeddingsRequest(
            vectors=[[0.5, 0.6]],
            limit=20,
            user="test_user",
            collection="test_collection",
        )

        payload = DocumentEmbeddingsRequestTranslator().from_pulsar(message)

        assert isinstance(payload, dict)
        assert payload["vectors"] == [[0.5, 0.6]]
        assert payload["limit"] == 20
        assert payload["user"] == "test_user"
        assert payload["collection"] == "test_collection"
class TestDocumentEmbeddingsResponseContract:
    """Pin the DocumentEmbeddingsResponse schema and its response translator."""

    def test_response_schema_fields(self):
        """A successful response carries an `error` and a `chunks` field."""
        response = DocumentEmbeddingsResponse(
            error=None,
            chunks=["chunk1", "chunk2", "chunk3"],
        )

        # Both contract fields must be present ...
        for field_name in ('error', 'chunks'):
            assert hasattr(response, field_name)

        # ... and hold what was passed in.
        assert response.error is None
        assert response.chunks == ["chunk1", "chunk2", "chunk3"]

    def test_response_schema_with_error(self):
        """A failed response carries the Error object and no chunks."""
        failure = Error(
            type="query_error",
            message="Database connection failed",
        )

        response = DocumentEmbeddingsResponse(error=failure, chunks=None)

        assert response.error == failure
        assert response.chunks is None

    def test_response_translator_from_pulsar_with_chunks(self):
        """String chunks are passed through to the dict under the "chunks" key."""
        response = DocumentEmbeddingsResponse(
            error=None,
            chunks=["doc1", "doc2", "doc3"],
        )

        payload = DocumentEmbeddingsResponseTranslator().from_pulsar(response)

        assert isinstance(payload, dict)
        assert "chunks" in payload
        assert payload["chunks"] == ["doc1", "doc2", "doc3"]

    def test_response_translator_from_pulsar_with_bytes(self):
        """Byte chunks are expected to come back as text strings."""
        # A bare mock stands in for the schema object here.
        response = MagicMock(chunks=[b"byte_chunk1", b"byte_chunk2"])

        payload = DocumentEmbeddingsResponseTranslator().from_pulsar(response)

        assert isinstance(payload, dict)
        assert "chunks" in payload
        assert payload["chunks"] == ["byte_chunk1", "byte_chunk2"]

    def test_response_translator_from_pulsar_with_empty_chunks(self):
        """An empty chunk list is still reported, as an empty list."""
        response = MagicMock(chunks=[])

        payload = DocumentEmbeddingsResponseTranslator().from_pulsar(response)

        assert isinstance(payload, dict)
        assert "chunks" in payload
        assert payload["chunks"] == []

    def test_response_translator_from_pulsar_with_none_chunks(self):
        """None chunks yield a dict with no usable "chunks" entry."""
        response = MagicMock(chunks=None)

        payload = DocumentEmbeddingsResponseTranslator().from_pulsar(response)

        assert isinstance(payload, dict)
        # Either the key is omitted or it maps to None; both are accepted.
        assert "chunks" not in payload or payload.get("chunks") is None

    def test_response_translator_from_response_with_completion(self):
        """from_response_with_completion returns the dict plus a final flag."""
        response = DocumentEmbeddingsResponse(
            error=None,
            chunks=["chunk1", "chunk2"],
        )
        translator = DocumentEmbeddingsResponseTranslator()

        payload, is_final = translator.from_response_with_completion(response)

        assert isinstance(payload, dict)
        assert "chunks" in payload
        assert payload["chunks"] == ["chunk1", "chunk2"]
        # Document embeddings responses are expected to always be final.
        assert is_final is True

    def test_response_translator_to_pulsar_not_implemented(self):
        """Responses are not translated dict -> Pulsar; to_pulsar must refuse."""
        translator = DocumentEmbeddingsResponseTranslator()

        with pytest.raises(NotImplementedError):
            translator.to_pulsar({"chunks": ["test"]})
class TestDocumentEmbeddingsMessageCompatibility:
    """Check that request and response messages work together end to end."""

    def test_request_response_flow(self):
        """A dict request and a chunk response both survive translation."""
        # Inbound side: dict -> Pulsar request.
        incoming = {
            "vectors": [[0.1, 0.2, 0.3]],
            "limit": 5,
            "user": "test_user",
            "collection": "test_collection",
        }
        pulsar_request = DocumentEmbeddingsRequestTranslator().to_pulsar(incoming)

        # Stand-in for what the query service would send back.
        service_reply = DocumentEmbeddingsResponse(
            error=None,
            chunks=["relevant chunk 1", "relevant chunk 2"],
        )

        # Outbound side: Pulsar response -> dict.
        outgoing = DocumentEmbeddingsResponseTranslator().from_pulsar(service_reply)

        assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
        assert isinstance(outgoing, dict)
        assert "chunks" in outgoing
        assert len(outgoing["chunks"]) == 2

    def test_error_response_flow(self):
        """An error response translates to a dict without usable chunks."""
        failure = Error(
            type="vector_db_error",
            message="Collection not found",
        )
        service_reply = DocumentEmbeddingsResponse(error=failure, chunks=None)

        outgoing = DocumentEmbeddingsResponseTranslator().from_pulsar(service_reply)

        assert isinstance(outgoing, dict)
        # The translator only ever emits "chunks"; the error itself is not
        # copied into the dict, so all that can be checked is its absence.
        assert "chunks" not in outgoing or outgoing.get("chunks") is None

View file

@ -0,0 +1,190 @@
"""
Unit tests for trustgraph.base.document_embeddings_client
Testing async document embeddings client functionality
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.document_embeddings_client import DocumentEmbeddingsClient, DocumentEmbeddingsClientSpec
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
    """Exercise DocumentEmbeddingsClient.query against a stubbed request().

    Every test patches RequestResponse.__init__ so the client can be built
    without running the parent constructor.
    """

    @staticmethod
    def _chunk_reply(chunks):
        """Build an error-free response double carrying `chunks`."""
        reply = MagicMock(spec=DocumentEmbeddingsResponse)
        reply.error = None
        reply.chunks = chunks
        return reply

    @staticmethod
    def _client_answering(reply):
        """Build a client whose request() coroutine resolves to `reply`."""
        client = DocumentEmbeddingsClient()
        client.request = AsyncMock(return_value=reply)
        return client

    @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
    async def test_query_success_with_chunks(self, mock_parent_init):
        """query returns the response chunks and sends a matching request."""
        # Arrange
        mock_parent_init.return_value = None
        client = self._client_answering(
            self._chunk_reply(["chunk1", "chunk2", "chunk3"])
        )
        vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        # Act
        chunks = await client.query(
            vectors=vectors,
            limit=10,
            user="test_user",
            collection="test_collection",
            timeout=30,
        )

        # Assert
        assert chunks == ["chunk1", "chunk2", "chunk3"]
        client.request.assert_called_once()
        # The request object is the first positional argument of request().
        sent = client.request.call_args[0][0]
        assert isinstance(sent, DocumentEmbeddingsRequest)
        assert sent.vectors == vectors
        assert sent.limit == 10
        assert sent.user == "test_user"
        assert sent.collection == "test_collection"

    @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
    async def test_query_with_error_raises_exception(self, mock_parent_init):
        """query raises RuntimeError carrying the response's error message."""
        # Arrange
        mock_parent_init.return_value = None
        failing_reply = MagicMock(spec=DocumentEmbeddingsResponse)
        failing_reply.error = MagicMock()
        failing_reply.error.message = "Database connection failed"
        client = self._client_answering(failing_reply)

        # Act & Assert
        with pytest.raises(RuntimeError, match="Database connection failed"):
            await client.query(
                vectors=[[0.1, 0.2, 0.3]],
                limit=5,
            )

    @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
    async def test_query_with_empty_chunks(self, mock_parent_init):
        """An empty chunk list is returned unchanged."""
        # Arrange
        mock_parent_init.return_value = None
        client = self._client_answering(self._chunk_reply([]))

        # Act
        chunks = await client.query(vectors=[[0.1, 0.2, 0.3]])

        # Assert
        assert chunks == []

    @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
    async def test_query_with_default_parameters(self, mock_parent_init):
        """Omitted limit, user and collection fall back to the client defaults."""
        # Arrange
        mock_parent_init.return_value = None
        client = self._client_answering(self._chunk_reply(["test_chunk"]))

        # Act
        await client.query(vectors=[[0.1, 0.2, 0.3]])

        # Assert
        client.request.assert_called_once()
        sent = client.request.call_args[0][0]
        assert sent.limit == 20  # Default limit
        assert sent.user == "trustgraph"  # Default user
        assert sent.collection == "default"  # Default collection

    @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
    async def test_query_with_custom_timeout(self, mock_parent_init):
        """A caller-supplied timeout is forwarded to request() as a keyword."""
        # Arrange
        mock_parent_init.return_value = None
        client = self._client_answering(self._chunk_reply(["chunk1"]))

        # Act
        await client.query(
            vectors=[[0.1, 0.2, 0.3]],
            timeout=60,
        )

        # Assert
        forwarded_kwargs = client.request.call_args[1]
        assert forwarded_kwargs["timeout"] == 60

    @patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
    async def test_query_logging(self, mock_parent_init):
        """query emits exactly one debug log line describing the response."""
        # Arrange
        mock_parent_init.return_value = None
        client = self._client_answering(self._chunk_reply(["test_chunk"]))

        # Act
        with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
            chunks = await client.query(vectors=[[0.1, 0.2, 0.3]])

            # Assert
            mock_logger.debug.assert_called_once()
            logged_call = str(mock_logger.debug.call_args)
            assert "Document embeddings response" in logged_call
            assert chunks == ["test_chunk"]
class TestDocumentEmbeddingsClientSpec(IsolatedAsyncioTestCase):
    """Check how DocumentEmbeddingsClientSpec wires names, schemas and impl."""

    def test_spec_initialization(self):
        """The spec keeps the queue names and binds the embeddings types."""
        # Act
        spec = DocumentEmbeddingsClientSpec(
            request_name="test-request",
            response_name="test-response",
        )

        # Assert: names are stored as given ...
        assert spec.request_name == "test-request"
        assert spec.response_name == "test-response"
        # ... and the schema/implementation classes are fixed by the spec.
        assert spec.request_schema == DocumentEmbeddingsRequest
        assert spec.response_schema == DocumentEmbeddingsResponse
        assert spec.impl == DocumentEmbeddingsClient

    @patch('trustgraph.base.request_response_spec.RequestResponseSpec.__init__')
    def test_spec_calls_parent_init(self, mock_parent_init):
        """The spec delegates to RequestResponseSpec.__init__ with full wiring."""
        # Arrange
        mock_parent_init.return_value = None
        expected_kwargs = dict(
            request_name="test-request",
            request_schema=DocumentEmbeddingsRequest,
            response_name="test-response",
            response_schema=DocumentEmbeddingsResponse,
            impl=DocumentEmbeddingsClient,
        )

        # Act: only the constructor's side effect matters here.
        DocumentEmbeddingsClientSpec(
            request_name="test-request",
            response_name="test-response",
        )

        # Assert
        mock_parent_init.assert_called_once_with(**expected_kwargs)

View file

@ -0,0 +1,172 @@
"""
Unit tests for trustgraph.clients.document_embeddings_client
Testing synchronous document embeddings client functionality
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.clients.document_embeddings_client import DocumentEmbeddingsClient
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
class TestSyncDocumentEmbeddingsClient:
    """Exercise the synchronous DocumentEmbeddingsClient with BaseClient stubbed.

    Every test patches BaseClient.__init__ so the client can be built without
    running the base-class constructor.
    """

    @staticmethod
    def _client_with_reply(chunks):
        """Build a client whose call() returns a mock reply carrying `chunks`."""
        client = DocumentEmbeddingsClient()
        reply = MagicMock()
        reply.chunks = chunks
        client.call = MagicMock(return_value=reply)
        return client

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_client_initialization(self, mock_base_init):
        """Explicit constructor arguments reach BaseClient along with the schemas."""
        # Arrange
        mock_base_init.return_value = None
        ctor_kwargs = dict(
            log_level=1,
            subscriber="test-subscriber",
            input_queue="test-input",
            output_queue="test-output",
            pulsar_host="pulsar://test:6650",
            pulsar_api_key="test-key",
        )

        # Act: only the call made to the base constructor is inspected.
        DocumentEmbeddingsClient(**ctor_kwargs)

        # Assert
        mock_base_init.assert_called_once_with(
            input_schema=DocumentEmbeddingsRequest,
            output_schema=DocumentEmbeddingsResponse,
            **ctor_kwargs,
        )

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_client_initialization_with_defaults(self, mock_base_init):
        """With no arguments the client still supplies queues and schemas."""
        # Arrange
        mock_base_init.return_value = None

        # Act
        DocumentEmbeddingsClient()

        # Assert
        base_kwargs = mock_base_init.call_args[1]
        # Some default queue must be passed down, whatever its value.
        assert base_kwargs['input_queue'] is not None
        assert base_kwargs['output_queue'] is not None
        assert base_kwargs['input_schema'] == DocumentEmbeddingsRequest
        assert base_kwargs['output_schema'] == DocumentEmbeddingsResponse

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_request_returns_chunks(self, mock_base_init):
        """request returns the reply's chunks and forwards every argument to call()."""
        # Arrange
        mock_base_init.return_value = None
        client = self._client_with_reply(["chunk1", "chunk2", "chunk3"])
        vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        # Act
        chunks = client.request(
            vectors=vectors,
            user="test_user",
            collection="test_collection",
            limit=10,
            timeout=300,
        )

        # Assert
        assert chunks == ["chunk1", "chunk2", "chunk3"]
        client.call.assert_called_once_with(
            user="test_user",
            collection="test_collection",
            vectors=vectors,
            limit=10,
            timeout=300,
        )

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_request_with_default_parameters(self, mock_base_init):
        """Omitted request arguments are replaced by the client's defaults."""
        # Arrange
        mock_base_init.return_value = None
        client = self._client_with_reply(["test_chunk"])
        vectors = [[0.1, 0.2, 0.3]]

        # Act
        chunks = client.request(vectors=vectors)

        # Assert
        assert chunks == ["test_chunk"]
        client.call.assert_called_once_with(
            user="trustgraph",
            collection="default",
            vectors=vectors,
            limit=10,
            timeout=300,
        )

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_request_with_empty_chunks(self, mock_base_init):
        """An empty chunk list is returned unchanged."""
        # Arrange
        mock_base_init.return_value = None
        client = self._client_with_reply([])

        # Act
        chunks = client.request(vectors=[[0.1, 0.2, 0.3]])

        # Assert
        assert chunks == []

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_request_with_none_chunks(self, mock_base_init):
        """A reply whose chunks are None is passed through as None."""
        # Arrange
        mock_base_init.return_value = None
        client = self._client_with_reply(None)

        # Act
        chunks = client.request(vectors=[[0.1, 0.2, 0.3]])

        # Assert
        assert chunks is None

    @patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
    def test_request_with_custom_timeout(self, mock_base_init):
        """A caller-supplied timeout is forwarded to call() as a keyword."""
        # Arrange
        mock_base_init.return_value = None
        client = self._client_with_reply(["chunk1"])

        # Act
        client.request(
            vectors=[[0.1, 0.2, 0.3]],
            timeout=600,
        )

        # Assert
        forwarded_kwargs = client.call.call_args[1]
        assert forwarded_kwargs["timeout"] == 600

View file

@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
return resp.documents return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec): class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request) docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...") logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(documents=docs, error=None) r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed") logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error", type = "document-embeddings-query-error",
message = str(e), message = str(e),
), ),
response=None, chunks=None,
) )
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})

View file

@ -47,5 +47,5 @@ class DocumentEmbeddingsClient(BaseClient):
return self.call( return self.call(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout vectors=vectors, limit=limit, timeout=timeout
).documents ).chunks

View file

@ -36,10 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.documents: if obj.chunks is not None:
result["documents"] = [ result["chunks"] = [
doc.decode("utf-8") if isinstance(doc, bytes) else doc chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
for doc in obj.documents for chunk in obj.chunks
] ]
return result return result

View file

@ -45,4 +45,11 @@ class DocumentEmbeddingsRequest(Record):
class DocumentEmbeddingsResponse(Record): class DocumentEmbeddingsResponse(Record):
error = Error() error = Error()
chunks = Array(String()) chunks = Array(String())
document_embeddings_request_queue = topic(
"non-persistent://trustgraph/document-embeddings-request"
)
document_embeddings_response_queue = topic(
"non-persistent://trustgraph/document-embeddings-response"
)