mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
trustgraph-base .chunks / .documents confusion in the API (#485)
Back-ported to 1.2 from 1.3 (#481) * trustgraph-base .chunks / .documents confusion in the API * Added tests, fixed test failures in code * Fix file dup error * Fix contract error
This commit is contained in:
parent
056c702515
commit
0ae39aef7f
8 changed files with 639 additions and 9 deletions
261
tests/contract/test_document_embeddings_contract.py
Normal file
261
tests/contract/test_document_embeddings_contract.py
Normal file
|
|
@ -0,0 +1,261 @@
|
||||||
|
"""
|
||||||
|
Contract tests for document embeddings message schemas and translators
|
||||||
|
Ensures that message formats remain consistent across services
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
|
||||||
|
from trustgraph.messaging.translators.embeddings_query import (
|
||||||
|
DocumentEmbeddingsRequestTranslator,
|
||||||
|
DocumentEmbeddingsResponseTranslator
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentEmbeddingsRequestContract:
|
||||||
|
"""Test DocumentEmbeddingsRequest schema contract"""
|
||||||
|
|
||||||
|
def test_request_schema_fields(self):
|
||||||
|
"""Test that DocumentEmbeddingsRequest has expected fields"""
|
||||||
|
# Create a request
|
||||||
|
request = DocumentEmbeddingsRequest(
|
||||||
|
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||||
|
limit=10,
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify all expected fields exist
|
||||||
|
assert hasattr(request, 'vectors')
|
||||||
|
assert hasattr(request, 'limit')
|
||||||
|
assert hasattr(request, 'user')
|
||||||
|
assert hasattr(request, 'collection')
|
||||||
|
|
||||||
|
# Verify field values
|
||||||
|
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
assert request.limit == 10
|
||||||
|
assert request.user == "test_user"
|
||||||
|
assert request.collection == "test_collection"
|
||||||
|
|
||||||
|
def test_request_translator_to_pulsar(self):
|
||||||
|
"""Test request translator converts dict to Pulsar schema"""
|
||||||
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"vectors": [[0.1, 0.2], [0.3, 0.4]],
|
||||||
|
"limit": 5,
|
||||||
|
"user": "custom_user",
|
||||||
|
"collection": "custom_collection"
|
||||||
|
}
|
||||||
|
|
||||||
|
result = translator.to_pulsar(data)
|
||||||
|
|
||||||
|
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||||
|
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
||||||
|
assert result.limit == 5
|
||||||
|
assert result.user == "custom_user"
|
||||||
|
assert result.collection == "custom_collection"
|
||||||
|
|
||||||
|
def test_request_translator_to_pulsar_with_defaults(self):
|
||||||
|
"""Test request translator uses correct defaults"""
|
||||||
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"vectors": [[0.1, 0.2]]
|
||||||
|
# No limit, user, or collection provided
|
||||||
|
}
|
||||||
|
|
||||||
|
result = translator.to_pulsar(data)
|
||||||
|
|
||||||
|
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||||
|
assert result.vectors == [[0.1, 0.2]]
|
||||||
|
assert result.limit == 10 # Default
|
||||||
|
assert result.user == "trustgraph" # Default
|
||||||
|
assert result.collection == "default" # Default
|
||||||
|
|
||||||
|
def test_request_translator_from_pulsar(self):
|
||||||
|
"""Test request translator converts Pulsar schema to dict"""
|
||||||
|
translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
|
||||||
|
request = DocumentEmbeddingsRequest(
|
||||||
|
vectors=[[0.5, 0.6]],
|
||||||
|
limit=20,
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = translator.from_pulsar(request)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["vectors"] == [[0.5, 0.6]]
|
||||||
|
assert result["limit"] == 20
|
||||||
|
assert result["user"] == "test_user"
|
||||||
|
assert result["collection"] == "test_collection"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentEmbeddingsResponseContract:
|
||||||
|
"""Test DocumentEmbeddingsResponse schema contract"""
|
||||||
|
|
||||||
|
def test_response_schema_fields(self):
|
||||||
|
"""Test that DocumentEmbeddingsResponse has expected fields"""
|
||||||
|
# Create a response with chunks
|
||||||
|
response = DocumentEmbeddingsResponse(
|
||||||
|
error=None,
|
||||||
|
chunks=["chunk1", "chunk2", "chunk3"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify all expected fields exist
|
||||||
|
assert hasattr(response, 'error')
|
||||||
|
assert hasattr(response, 'chunks')
|
||||||
|
|
||||||
|
# Verify field values
|
||||||
|
assert response.error is None
|
||||||
|
assert response.chunks == ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
|
def test_response_schema_with_error(self):
|
||||||
|
"""Test response schema with error"""
|
||||||
|
error = Error(
|
||||||
|
type="query_error",
|
||||||
|
message="Database connection failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = DocumentEmbeddingsResponse(
|
||||||
|
error=error,
|
||||||
|
chunks=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.error == error
|
||||||
|
assert response.chunks is None
|
||||||
|
|
||||||
|
def test_response_translator_from_pulsar_with_chunks(self):
|
||||||
|
"""Test response translator converts Pulsar schema with chunks to dict"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
response = DocumentEmbeddingsResponse(
|
||||||
|
error=None,
|
||||||
|
chunks=["doc1", "doc2", "doc3"]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "chunks" in result
|
||||||
|
assert result["chunks"] == ["doc1", "doc2", "doc3"]
|
||||||
|
|
||||||
|
def test_response_translator_from_pulsar_with_bytes(self):
|
||||||
|
"""Test response translator handles byte chunks correctly"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.chunks = [b"byte_chunk1", b"byte_chunk2"]
|
||||||
|
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "chunks" in result
|
||||||
|
assert result["chunks"] == ["byte_chunk1", "byte_chunk2"]
|
||||||
|
|
||||||
|
def test_response_translator_from_pulsar_with_empty_chunks(self):
|
||||||
|
"""Test response translator handles empty chunks list"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.chunks = []
|
||||||
|
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "chunks" in result
|
||||||
|
assert result["chunks"] == []
|
||||||
|
|
||||||
|
def test_response_translator_from_pulsar_with_none_chunks(self):
|
||||||
|
"""Test response translator handles None chunks"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.chunks = None
|
||||||
|
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "chunks" not in result or result.get("chunks") is None
|
||||||
|
|
||||||
|
def test_response_translator_from_response_with_completion(self):
|
||||||
|
"""Test response translator with completion flag"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
response = DocumentEmbeddingsResponse(
|
||||||
|
error=None,
|
||||||
|
chunks=["chunk1", "chunk2"]
|
||||||
|
)
|
||||||
|
|
||||||
|
result, is_final = translator.from_response_with_completion(response)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "chunks" in result
|
||||||
|
assert result["chunks"] == ["chunk1", "chunk2"]
|
||||||
|
assert is_final is True # Document embeddings responses are always final
|
||||||
|
|
||||||
|
def test_response_translator_to_pulsar_not_implemented(self):
|
||||||
|
"""Test that to_pulsar raises NotImplementedError for responses"""
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
translator.to_pulsar({"chunks": ["test"]})
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentEmbeddingsMessageCompatibility:
|
||||||
|
"""Test compatibility between request and response messages"""
|
||||||
|
|
||||||
|
def test_request_response_flow(self):
|
||||||
|
"""Test complete request-response flow maintains data integrity"""
|
||||||
|
# Create request
|
||||||
|
request_data = {
|
||||||
|
"vectors": [[0.1, 0.2, 0.3]],
|
||||||
|
"limit": 5,
|
||||||
|
"user": "test_user",
|
||||||
|
"collection": "test_collection"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert to Pulsar request
|
||||||
|
req_translator = DocumentEmbeddingsRequestTranslator()
|
||||||
|
pulsar_request = req_translator.to_pulsar(request_data)
|
||||||
|
|
||||||
|
# Simulate service processing and creating response
|
||||||
|
response = DocumentEmbeddingsResponse(
|
||||||
|
error=None,
|
||||||
|
chunks=["relevant chunk 1", "relevant chunk 2"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert response back to dict
|
||||||
|
resp_translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
response_data = resp_translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Verify data integrity
|
||||||
|
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
||||||
|
assert isinstance(response_data, dict)
|
||||||
|
assert "chunks" in response_data
|
||||||
|
assert len(response_data["chunks"]) == 2
|
||||||
|
|
||||||
|
def test_error_response_flow(self):
|
||||||
|
"""Test error response flow"""
|
||||||
|
# Create error response
|
||||||
|
error = Error(
|
||||||
|
type="vector_db_error",
|
||||||
|
message="Collection not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = DocumentEmbeddingsResponse(
|
||||||
|
error=error,
|
||||||
|
chunks=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert response to dict
|
||||||
|
translator = DocumentEmbeddingsResponseTranslator()
|
||||||
|
response_data = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Verify error handling
|
||||||
|
assert isinstance(response_data, dict)
|
||||||
|
# The translator doesn't include error in the dict, only chunks
|
||||||
|
assert "chunks" not in response_data or response_data.get("chunks") is None
|
||||||
190
tests/unit/test_base/test_document_embeddings_client.py
Normal file
190
tests/unit/test_base/test_document_embeddings_client.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
"""
|
||||||
|
Unit tests for trustgraph.base.document_embeddings_client
|
||||||
|
Testing async document embeddings client functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
|
||||||
|
from trustgraph.base.document_embeddings_client import DocumentEmbeddingsClient, DocumentEmbeddingsClientSpec
|
||||||
|
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
||||||
|
"""Test async document embeddings client functionality"""
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||||
|
async def test_query_success_with_chunks(self, mock_parent_init):
|
||||||
|
"""Test successful query returning chunks"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
|
mock_response.error = None
|
||||||
|
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
|
# Mock the request method
|
||||||
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await client.query(
|
||||||
|
vectors=vectors,
|
||||||
|
limit=10,
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||||
|
client.request.assert_called_once()
|
||||||
|
call_args = client.request.call_args[0][0]
|
||||||
|
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||||
|
assert call_args.vectors == vectors
|
||||||
|
assert call_args.limit == 10
|
||||||
|
assert call_args.user == "test_user"
|
||||||
|
assert call_args.collection == "test_collection"
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||||
|
async def test_query_with_error_raises_exception(self, mock_parent_init):
|
||||||
|
"""Test query raises RuntimeError when response contains error"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
|
mock_response.error = MagicMock()
|
||||||
|
mock_response.error.message = "Database connection failed"
|
||||||
|
|
||||||
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||||
|
await client.query(
|
||||||
|
vectors=[[0.1, 0.2, 0.3]],
|
||||||
|
limit=5
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||||
|
async def test_query_with_empty_chunks(self, mock_parent_init):
|
||||||
|
"""Test query with empty chunks list"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
|
mock_response.error = None
|
||||||
|
mock_response.chunks = []
|
||||||
|
|
||||||
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||||
|
async def test_query_with_default_parameters(self, mock_parent_init):
|
||||||
|
"""Test query uses correct default parameters"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
|
mock_response.error = None
|
||||||
|
mock_response.chunks = ["test_chunk"]
|
||||||
|
|
||||||
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
client.request.assert_called_once()
|
||||||
|
call_args = client.request.call_args[0][0]
|
||||||
|
assert call_args.limit == 20 # Default limit
|
||||||
|
assert call_args.user == "trustgraph" # Default user
|
||||||
|
assert call_args.collection == "default" # Default collection
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||||
|
async def test_query_with_custom_timeout(self, mock_parent_init):
|
||||||
|
"""Test query passes custom timeout to request"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
|
mock_response.error = None
|
||||||
|
mock_response.chunks = ["chunk1"]
|
||||||
|
|
||||||
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await client.query(
|
||||||
|
vectors=[[0.1, 0.2, 0.3]],
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert client.request.call_args[1]["timeout"] == 60
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponse.__init__')
|
||||||
|
async def test_query_logging(self, mock_parent_init):
|
||||||
|
"""Test query logs response for debugging"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||||
|
mock_response.error = None
|
||||||
|
mock_response.chunks = ["test_chunk"]
|
||||||
|
|
||||||
|
client.request = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
|
||||||
|
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_logger.debug.assert_called_once()
|
||||||
|
assert "Document embeddings response" in str(mock_logger.debug.call_args)
|
||||||
|
assert result == ["test_chunk"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentEmbeddingsClientSpec(IsolatedAsyncioTestCase):
|
||||||
|
"""Test DocumentEmbeddingsClientSpec configuration"""
|
||||||
|
|
||||||
|
def test_spec_initialization(self):
|
||||||
|
"""Test DocumentEmbeddingsClientSpec initialization"""
|
||||||
|
# Act
|
||||||
|
spec = DocumentEmbeddingsClientSpec(
|
||||||
|
request_name="test-request",
|
||||||
|
response_name="test-response"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert spec.request_name == "test-request"
|
||||||
|
assert spec.response_name == "test-response"
|
||||||
|
assert spec.request_schema == DocumentEmbeddingsRequest
|
||||||
|
assert spec.response_schema == DocumentEmbeddingsResponse
|
||||||
|
assert spec.impl == DocumentEmbeddingsClient
|
||||||
|
|
||||||
|
@patch('trustgraph.base.request_response_spec.RequestResponseSpec.__init__')
|
||||||
|
def test_spec_calls_parent_init(self, mock_parent_init):
|
||||||
|
"""Test spec properly calls parent class initialization"""
|
||||||
|
# Arrange
|
||||||
|
mock_parent_init.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
spec = DocumentEmbeddingsClientSpec(
|
||||||
|
request_name="test-request",
|
||||||
|
response_name="test-response"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_parent_init.assert_called_once_with(
|
||||||
|
request_name="test-request",
|
||||||
|
request_schema=DocumentEmbeddingsRequest,
|
||||||
|
response_name="test-response",
|
||||||
|
response_schema=DocumentEmbeddingsResponse,
|
||||||
|
impl=DocumentEmbeddingsClient
|
||||||
|
)
|
||||||
172
tests/unit/test_clients/test_sync_document_embeddings_client.py
Normal file
172
tests/unit/test_clients/test_sync_document_embeddings_client.py
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
"""
|
||||||
|
Unit tests for trustgraph.clients.document_embeddings_client
|
||||||
|
Testing synchronous document embeddings client functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from trustgraph.clients.document_embeddings_client import DocumentEmbeddingsClient
|
||||||
|
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncDocumentEmbeddingsClient:
|
||||||
|
"""Test synchronous document embeddings client functionality"""
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_client_initialization(self, mock_base_init):
|
||||||
|
"""Test client initialization with correct parameters"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
client = DocumentEmbeddingsClient(
|
||||||
|
log_level=1,
|
||||||
|
subscriber="test-subscriber",
|
||||||
|
input_queue="test-input",
|
||||||
|
output_queue="test-output",
|
||||||
|
pulsar_host="pulsar://test:6650",
|
||||||
|
pulsar_api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_base_init.assert_called_once_with(
|
||||||
|
log_level=1,
|
||||||
|
subscriber="test-subscriber",
|
||||||
|
input_queue="test-input",
|
||||||
|
output_queue="test-output",
|
||||||
|
pulsar_host="pulsar://test:6650",
|
||||||
|
pulsar_api_key="test-key",
|
||||||
|
input_schema=DocumentEmbeddingsRequest,
|
||||||
|
output_schema=DocumentEmbeddingsResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_client_initialization_with_defaults(self, mock_base_init):
|
||||||
|
"""Test client initialization uses default queues when not specified"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
call_args = mock_base_init.call_args[1]
|
||||||
|
# Check that default queues are used
|
||||||
|
assert call_args['input_queue'] is not None
|
||||||
|
assert call_args['output_queue'] is not None
|
||||||
|
assert call_args['input_schema'] == DocumentEmbeddingsRequest
|
||||||
|
assert call_args['output_schema'] == DocumentEmbeddingsResponse
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_request_returns_chunks(self, mock_base_init):
|
||||||
|
"""Test request method returns chunks from response"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
|
||||||
|
# Mock the call method to return a response with chunks
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
|
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = client.request(
|
||||||
|
vectors=vectors,
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
limit=10,
|
||||||
|
timeout=300
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||||
|
client.call.assert_called_once_with(
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
vectors=vectors,
|
||||||
|
limit=10,
|
||||||
|
timeout=300
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_request_with_default_parameters(self, mock_base_init):
|
||||||
|
"""Test request uses correct default parameters"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.chunks = ["test_chunk"]
|
||||||
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
|
vectors = [[0.1, 0.2, 0.3]]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = client.request(vectors=vectors)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == ["test_chunk"]
|
||||||
|
client.call.assert_called_once_with(
|
||||||
|
user="trustgraph",
|
||||||
|
collection="default",
|
||||||
|
vectors=vectors,
|
||||||
|
limit=10,
|
||||||
|
timeout=300
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_request_with_empty_chunks(self, mock_base_init):
|
||||||
|
"""Test request handles empty chunks list"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.chunks = []
|
||||||
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = client.request(vectors=[[0.1, 0.2, 0.3]])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_request_with_none_chunks(self, mock_base_init):
|
||||||
|
"""Test request handles None chunks gracefully"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.chunks = None
|
||||||
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = client.request(vectors=[[0.1, 0.2, 0.3]])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch('trustgraph.clients.document_embeddings_client.BaseClient.__init__')
|
||||||
|
def test_request_with_custom_timeout(self, mock_base_init):
|
||||||
|
"""Test request passes custom timeout correctly"""
|
||||||
|
# Arrange
|
||||||
|
mock_base_init.return_value = None
|
||||||
|
client = DocumentEmbeddingsClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.chunks = ["chunk1"]
|
||||||
|
client.call = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
client.request(
|
||||||
|
vectors=[[0.1, 0.2, 0.3]],
|
||||||
|
timeout=600
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert client.call.call_args[1]["timeout"] == 600
|
||||||
|
|
@ -27,7 +27,7 @@ class DocumentEmbeddingsClient(RequestResponse):
|
||||||
if resp.error:
|
if resp.error:
|
||||||
raise RuntimeError(resp.error.message)
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
return resp.documents
|
return resp.chunks
|
||||||
|
|
||||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||||
docs = await self.query_document_embeddings(request)
|
docs = await self.query_document_embeddings(request)
|
||||||
|
|
||||||
logger.debug("Sending document embeddings query response...")
|
logger.debug("Sending document embeddings query response...")
|
||||||
r = DocumentEmbeddingsResponse(documents=docs, error=None)
|
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
||||||
logger.debug("Document embeddings query request completed")
|
logger.debug("Document embeddings query request completed")
|
||||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||||
type = "document-embeddings-query-error",
|
type = "document-embeddings-query-error",
|
||||||
message = str(e),
|
message = str(e),
|
||||||
),
|
),
|
||||||
response=None,
|
chunks=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
|
||||||
|
|
@ -47,5 +47,5 @@ class DocumentEmbeddingsClient(BaseClient):
|
||||||
return self.call(
|
return self.call(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, limit=limit, timeout=timeout
|
vectors=vectors, limit=limit, timeout=timeout
|
||||||
).documents
|
).chunks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,10 +36,10 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if obj.documents:
|
if obj.chunks is not None:
|
||||||
result["documents"] = [
|
result["chunks"] = [
|
||||||
doc.decode("utf-8") if isinstance(doc, bytes) else doc
|
chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
|
||||||
for doc in obj.documents
|
for chunk in obj.chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -45,4 +45,11 @@ class DocumentEmbeddingsRequest(Record):
|
||||||
|
|
||||||
class DocumentEmbeddingsResponse(Record):
|
class DocumentEmbeddingsResponse(Record):
|
||||||
error = Error()
|
error = Error()
|
||||||
chunks = Array(String())
|
chunks = Array(String())
|
||||||
|
|
||||||
|
document_embeddings_request_queue = topic(
|
||||||
|
"non-persistent://trustgraph/document-embeddings-request"
|
||||||
|
)
|
||||||
|
document_embeddings_response_queue = topic(
|
||||||
|
"non-persistent://trustgraph/document-embeddings-response"
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue