mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 01:19:38 +02:00
Merge remote-tracking branch 'origin/master' into ts-port
This commit is contained in:
commit
f4d6e49217
270 changed files with 19608 additions and 4096 deletions
|
|
@ -5,7 +5,7 @@ Tests for Gateway Config Receiver
|
|||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, patch, Mock, MagicMock
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.config.receiver import ConfigReceiver
|
||||
|
|
@ -23,174 +23,237 @@ class TestConfigReceiver:
|
|||
def test_config_receiver_initialization(self):
|
||||
"""Test ConfigReceiver initialization"""
|
||||
mock_backend = Mock()
|
||||
|
||||
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
||||
assert config_receiver.backend == mock_backend
|
||||
assert config_receiver.flow_handlers == []
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.config_version == 0
|
||||
|
||||
def test_add_handler(self):
|
||||
"""Test adding flow handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
|
||||
handler1 = Mock()
|
||||
handler2 = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
assert len(config_receiver.flow_handlers) == 2
|
||||
assert handler1 in config_receiver.flow_handlers
|
||||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_new_flows(self):
|
||||
"""Test on_config method with new flows"""
|
||||
async def test_on_config_notify_new_version(self):
|
||||
"""Test on_config_notify triggers fetch for newer version"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
start_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with flows
|
||||
config_receiver.config_version = 1
|
||||
|
||||
# Mock fetch_and_apply
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with newer version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}',
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows were added
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
|
||||
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
|
||||
|
||||
# Verify start_flow was called for each new flow
|
||||
assert len(start_flow_calls) == 2
|
||||
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
|
||||
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
|
||||
mock_msg.value.return_value = Mock(version=2, types=["flow"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_removed_flows(self):
|
||||
"""Test on_config method with removed flows"""
|
||||
async def test_on_config_notify_old_version_ignored(self):
|
||||
"""Test on_config_notify ignores older versions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message with only flow1 (flow2 removed)
|
||||
config_receiver.config_version = 5
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with older version
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flow2 was removed
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
|
||||
# Verify stop_flow was called for removed flow
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
|
||||
mock_msg.value.return_value = Mock(version=3, types=["flow"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_no_flows(self):
|
||||
"""Test on_config method with no flows in config"""
|
||||
async def test_on_config_notify_irrelevant_types_ignored(self):
|
||||
"""Test on_config_notify ignores types the gateway doesn't care about"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow and stop_flow methods with async functions
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
async def mock_stop_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message without flows
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
# Create notify message with non-flow type
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify no flows were added
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
# Since no flows were in the config, the flow methods shouldn't be called
|
||||
# (We can't easily assert this with simple async functions, but the test
|
||||
# passes if no exceptions are thrown)
|
||||
mock_msg.value.return_value = Mock(version=2, types=["prompt"])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
# Version should be updated but no fetch
|
||||
assert len(fetch_calls) == 0
|
||||
assert config_receiver.config_version == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_exception_handling(self):
|
||||
"""Test on_config method handles exceptions gracefully"""
|
||||
async def test_on_config_notify_flow_type_triggers_fetch(self):
|
||||
"""Test on_config_notify fetches for flow-related types"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create mock message that will cause an exception
|
||||
config_receiver.config_version = 1
|
||||
|
||||
fetch_calls = []
|
||||
async def mock_fetch(**kwargs):
|
||||
fetch_calls.append(kwargs)
|
||||
config_receiver.fetch_and_apply = mock_fetch
|
||||
|
||||
for type_name in ["flow", "active-flow"]:
|
||||
fetch_calls.clear()
|
||||
config_receiver.config_version = 1
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(version=2, types=[type_name])
|
||||
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
assert len(fetch_calls) == 1, f"Expected fetch for type {type_name}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_notify_exception_handling(self):
|
||||
"""Test on_config_notify handles exceptions gracefully"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Create notify message that causes an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows remain empty
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.on_config_notify(mock_msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_new_flows(self):
|
||||
"""Test fetch_and_apply starts new flows"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock _create_config_client to return a mock client
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}',
|
||||
"flow2": '{"name": "test_flow_2"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_flow_calls = []
|
||||
async def mock_start_flow(id, flow):
|
||||
start_flow_calls.append((id, flow))
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert config_receiver.config_version == 5
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert len(start_flow_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_removed_flows(self):
|
||||
"""Test fetch_and_apply stops removed flows"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
}
|
||||
|
||||
# Config now only has flow1
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow1": '{"name": "test_flow_1"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
stop_flow_calls = []
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_flow_calls.append((id, flow))
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0][0] == "flow2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_apply_with_no_flows(self):
|
||||
"""Test fetch_and_apply with empty config"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 1
|
||||
mock_resp.config = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert config_receiver.flows == {}
|
||||
assert config_receiver.config_version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handlers
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
|
|
@ -199,19 +262,17 @@ class TestConfigReceiver:
|
|||
"""Test start_flow method handles handler exceptions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -219,21 +280,19 @@ class TestConfigReceiver:
|
|||
"""Test stop_flow method with multiple handlers"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handlers
|
||||
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
|
|
@ -242,167 +301,77 @@ class TestConfigReceiver:
|
|||
"""Test stop_flow method handles handler exceptions"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
|
||||
# Should not raise
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_loader_creates_consumer(self):
|
||||
"""Test config_loader method creates Pulsar consumer"""
|
||||
mock_backend = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
# Temporarily restore the real config_loader for this test
|
||||
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
|
||||
|
||||
# Mock Consumer class
|
||||
with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_consumer = Mock()
|
||||
async def mock_start():
|
||||
pass
|
||||
mock_consumer.start = mock_start
|
||||
mock_consumer_class.return_value = mock_consumer
|
||||
|
||||
# Create a task that will complete quickly
|
||||
async def quick_task():
|
||||
await config_receiver.config_loader()
|
||||
|
||||
# Run the task with a timeout to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(quick_task(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
# This is expected since the method runs indefinitely
|
||||
pass
|
||||
|
||||
# Verify Consumer was created with correct parameters
|
||||
mock_consumer_class.assert_called_once()
|
||||
call_args = mock_consumer_class.call_args
|
||||
|
||||
assert call_args[1]['backend'] == mock_backend
|
||||
assert call_args[1]['subscriber'] == "gateway-test-uuid"
|
||||
assert call_args[1]['handler'] == config_receiver.on_config
|
||||
assert call_args[1]['start_of_messages'] is True
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_config_loader_task(self, mock_create_task):
|
||||
"""Test start method creates config loader task"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock create_task to avoid actually creating tasks with real coroutines
|
||||
|
||||
mock_task = Mock()
|
||||
mock_create_task.return_value = mock_task
|
||||
|
||||
|
||||
await config_receiver.start()
|
||||
|
||||
# Verify task was created
|
||||
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
# Verify the argument passed to create_task is a coroutine
|
||||
call_args = mock_create_task.call_args[0]
|
||||
assert len(call_args) == 1 # Should have one argument (the coroutine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_mixed_flow_operations(self):
|
||||
"""Test on_config with mixed add/remove operations"""
|
||||
async def test_fetch_and_apply_mixed_flow_operations(self):
|
||||
"""Test fetch_and_apply with mixed add/remove operations"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
|
||||
# Pre-populate
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
"flow1": {"name": "test_flow_1"},
|
||||
"flow2": {"name": "test_flow_2"}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using Mock
|
||||
start_flow_calls = []
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
# Directly assign to avoid patch.object detecting async methods
|
||||
original_start_flow = config_receiver.start_flow
|
||||
original_stop_flow = config_receiver.stop_flow
|
||||
|
||||
# Config removes flow1, keeps flow2, adds flow3
|
||||
mock_resp = Mock()
|
||||
mock_resp.error = None
|
||||
mock_resp.version = 5
|
||||
mock_resp.config = {
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2"}',
|
||||
"flow3": '{"name": "test_flow_3"}'
|
||||
}
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = mock_resp
|
||||
config_receiver._create_config_client = Mock(return_value=mock_client)
|
||||
|
||||
start_calls = []
|
||||
stop_calls = []
|
||||
|
||||
async def mock_start_flow(id, flow):
|
||||
start_calls.append((id, flow))
|
||||
async def mock_stop_flow(id, flow):
|
||||
stop_calls.append((id, flow))
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
try:
|
||||
|
||||
# Create mock message with flow1 removed and flow3 added
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}',
|
||||
"flow3": '{"name": "test_flow_3", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify final state
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
|
||||
# Verify operations
|
||||
assert len(start_flow_calls) == 1
|
||||
assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
config_receiver.start_flow = original_start_flow
|
||||
config_receiver.stop_flow = original_stop_flow
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_invalid_json_flow_data(self):
|
||||
"""Test on_config handles invalid JSON in flow data"""
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow method with an async function
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with invalid JSON
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flow": {
|
||||
"flow1": '{"invalid": json}', # Invalid JSON
|
||||
"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# This should handle the exception gracefully
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# The entire operation should fail due to JSON parsing error
|
||||
# So no flows should be added
|
||||
assert config_receiver.flows == {}
|
||||
await config_receiver.fetch_and_apply()
|
||||
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
assert len(start_calls) == 1
|
||||
assert start_calls[0][0] == "flow3"
|
||||
assert len(stop_calls) == 1
|
||||
assert stop_calls[0][0] == "flow1"
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class TestConfigRequestor:
|
|||
mock_translator_registry.get_response_translator.return_value = Mock()
|
||||
|
||||
# Setup translator response
|
||||
mock_request_translator.to_pulsar.return_value = "translated_request"
|
||||
mock_request_translator.decode.return_value = "translated_request"
|
||||
|
||||
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
|
||||
with patch.object(ServiceRequestor, 'start', return_value=None), \
|
||||
|
|
@ -64,7 +64,7 @@ class TestConfigRequestor:
|
|||
result = requestor.to_request({"test": "body"})
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
|
||||
mock_request_translator.decode.assert_called_once_with({"test": "body"})
|
||||
assert result == "translated_request"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
|
|
@ -76,7 +76,7 @@ class TestConfigRequestor:
|
|||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Setup translator response
|
||||
mock_response_translator.from_response_with_completion.return_value = "translated_response"
|
||||
mock_response_translator.encode_with_completion.return_value = "translated_response"
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
backend=Mock(),
|
||||
|
|
@ -89,5 +89,5 @@ class TestConfigRequestor:
|
|||
result = requestor.from_response(mock_message)
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
|
||||
mock_response_translator.encode_with_completion.assert_called_once_with(mock_message)
|
||||
assert result == "translated_response"
|
||||
359
tests/unit/test_gateway/test_explain_triples.py
Normal file
359
tests/unit/test_gateway/test_explain_triples.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""
|
||||
Tests for inline explainability triples in response translators
|
||||
and ProvenanceEvent parsing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from trustgraph.schema import (
|
||||
GraphRagResponse, DocumentRagResponse, AgentResponse,
|
||||
Term, Triple, IRI, LITERAL, Error,
|
||||
)
|
||||
from trustgraph.messaging.translators.retrieval import (
|
||||
GraphRagResponseTranslator,
|
||||
DocumentRagResponseTranslator,
|
||||
)
|
||||
from trustgraph.messaging.translators.agent import (
|
||||
AgentResponseTranslator,
|
||||
)
|
||||
from trustgraph.api.types import ProvenanceEvent
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def make_triple(s_iri, p_iri, o_value, o_type=LITERAL):
|
||||
"""Create a Triple with IRI subject/predicate and typed object."""
|
||||
o = Term(type=IRI, iri=o_value) if o_type == IRI else Term(type=LITERAL, value=o_value)
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s_iri),
|
||||
p=Term(type=IRI, iri=p_iri),
|
||||
o=o,
|
||||
)
|
||||
|
||||
|
||||
def sample_triples():
|
||||
"""A few provenance triples for a question entity."""
|
||||
return [
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
"https://trustgraph.ai/ns/GraphRagQuestion",
|
||||
o_type=IRI,
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"https://trustgraph.ai/ns/query",
|
||||
"What is the internet?",
|
||||
),
|
||||
make_triple(
|
||||
"urn:trustgraph:question:abc123",
|
||||
"http://www.w3.org/ns/prov#startedAtTime",
|
||||
"2026-04-07T09:00:00Z",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# --- GraphRag Translator ---
|
||||
|
||||
class TestGraphRagExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
triples = sample_triples()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=triples,
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
# Check first triple is properly encoded
|
||||
t = result["explain_triples"][0]
|
||||
assert t["s"]["t"] == "i"
|
||||
assert t["s"]["i"] == "urn:trustgraph:question:abc123"
|
||||
assert t["p"]["t"] == "i"
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="Some answer text",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" not in result
|
||||
|
||||
def test_explain_with_completion_returns_not_final(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_session=False,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is False
|
||||
|
||||
def test_explain_id_and_graph_included(self):
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
response = GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert result["explain_id"] == "urn:trustgraph:question:abc123"
|
||||
assert result["explain_graph"] == "urn:graph:retrieval"
|
||||
|
||||
|
||||
# --- DocumentRag Translator ---
|
||||
|
||||
class TestDocumentRagExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
||||
response = DocumentRagResponse(
|
||||
response=None,
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:docrag:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
||||
response = DocumentRagResponse(
|
||||
response="Answer text",
|
||||
message_type="chunk",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert "explain_triples" not in result
|
||||
|
||||
|
||||
# --- Agent Translator ---
|
||||
|
||||
class TestAgentExplainTriples:
|
||||
|
||||
def test_explain_triples_encoded(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
explain_triples=sample_triples(),
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
|
||||
assert "explain_triples" in result
|
||||
assert len(result["explain_triples"]) == 3
|
||||
|
||||
t = result["explain_triples"][1]
|
||||
assert t["p"]["i"] == "https://trustgraph.ai/ns/query"
|
||||
assert t["o"]["t"] == "l"
|
||||
assert t["o"]["v"] == "What is the internet?"
|
||||
|
||||
def test_explain_triples_empty_not_included(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
content="I need to think...",
|
||||
)
|
||||
|
||||
result = translator.encode(response)
|
||||
assert "explain_triples" not in result
|
||||
|
||||
def test_explain_with_completion_not_final(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is False
|
||||
|
||||
def test_explain_with_completion_final(self):
|
||||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
content="The answer is...",
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
assert is_final is True
|
||||
|
||||
|
||||
# --- ProvenanceEvent ---
|
||||
|
||||
class TestProvenanceEvent:
|
||||
|
||||
def test_question_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
)
|
||||
assert event.event_type == "question"
|
||||
|
||||
def test_exploration_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:exploration:abc123",
|
||||
)
|
||||
assert event.event_type == "exploration"
|
||||
|
||||
def test_focus_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:focus:abc123",
|
||||
)
|
||||
assert event.event_type == "focus"
|
||||
|
||||
def test_synthesis_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:synthesis:abc123",
|
||||
)
|
||||
assert event.event_type == "synthesis"
|
||||
|
||||
def test_grounding_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:grounding:abc123",
|
||||
)
|
||||
assert event.event_type == "grounding"
|
||||
|
||||
def test_session_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
)
|
||||
assert event.event_type == "session"
|
||||
|
||||
def test_iteration_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:iteration:abc123:1",
|
||||
)
|
||||
assert event.event_type == "iteration"
|
||||
|
||||
def test_observation_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:observation:abc123:1",
|
||||
)
|
||||
assert event.event_type == "observation"
|
||||
|
||||
def test_conclusion_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:conclusion:abc123",
|
||||
)
|
||||
assert event.event_type == "conclusion"
|
||||
|
||||
def test_decomposition_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:decomposition:abc123",
|
||||
)
|
||||
assert event.event_type == "decomposition"
|
||||
|
||||
def test_finding_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:finding:abc123:0",
|
||||
)
|
||||
assert event.event_type == "finding"
|
||||
|
||||
def test_plan_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:plan:abc123",
|
||||
)
|
||||
assert event.event_type == "plan"
|
||||
|
||||
def test_step_result_event_type(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:agent:step-result:abc123:0",
|
||||
)
|
||||
assert event.event_type == "step-result"
|
||||
|
||||
def test_defaults(self):
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
)
|
||||
assert event.entity is None
|
||||
assert event.triples == []
|
||||
assert event.explain_graph == ""
|
||||
|
||||
def test_with_triples(self):
|
||||
raw = [{"s": {"t": "i", "i": "urn:x"}, "p": {"t": "i", "i": "urn:y"}, "o": {"t": "l", "v": "z"}}]
|
||||
event = ProvenanceEvent(
|
||||
explain_id="urn:trustgraph:question:abc123",
|
||||
triples=raw,
|
||||
)
|
||||
assert len(event.triples) == 1
|
||||
|
||||
|
||||
# --- Build ProvenanceEvent with entity parsing ---
|
||||
|
||||
class TestBuildProvenanceEvent:
|
||||
|
||||
def _make_client(self):
|
||||
"""Create a minimal WebSocketClient-like object with _build_provenance_event."""
|
||||
from trustgraph.api.socket_client import WebSocketClient
|
||||
# We can't instantiate WebSocketClient easily, so test the method logic directly
|
||||
return None
|
||||
|
||||
def test_entity_parsed_from_wire_triples(self):
|
||||
"""Test that wire-format triples are parsed into an ExplainEntity."""
|
||||
from trustgraph.api.explainability import ExplainEntity
|
||||
|
||||
wire_triples = [
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
|
||||
"p": {"t": "i", "i": "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"},
|
||||
"o": {"t": "i", "i": "https://trustgraph.ai/ns/GraphRagQuestion"},
|
||||
},
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:trustgraph:question:abc123"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/query"},
|
||||
"o": {"t": "l", "v": "What is the internet?"},
|
||||
},
|
||||
]
|
||||
|
||||
# Parse triples the same way _build_provenance_event does
|
||||
parsed = []
|
||||
for t in wire_triples:
|
||||
s = t.get("s", {}).get("i", "")
|
||||
p = t.get("p", {}).get("i", "")
|
||||
o_term = t.get("o", {})
|
||||
if o_term.get("t") == "i":
|
||||
o = o_term.get("i", "")
|
||||
else:
|
||||
o = o_term.get("v", "")
|
||||
parsed.append((s, p, o))
|
||||
|
||||
entity = ExplainEntity.from_triples(
|
||||
"urn:trustgraph:question:abc123", parsed
|
||||
)
|
||||
|
||||
assert entity.entity_type == "question"
|
||||
assert entity.query == "What is the internet?"
|
||||
assert entity.question_type == "graph-rag"
|
||||
|
|
@ -25,7 +25,7 @@ from trustgraph.schema import (
|
|||
class TestGraphRagResponseTranslator:
|
||||
"""Test GraphRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
def test_encode_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -36,14 +36,14 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - Empty string should be included in result
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
def test_encode_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -54,13 +54,13 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Some text"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_response(self):
|
||||
def test_encode_with_none_response(self):
|
||||
"""Test that None response is handled correctly"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
|
@ -71,14 +71,14 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - None should not be included
|
||||
assert "response" not in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_response_with_completion_returns_correct_flag(self):
|
||||
"""Test that from_response_with_completion returns correct is_final flag"""
|
||||
def test_encode_with_completion_returns_correct_flag(self):
|
||||
"""Test that encode_with_completion returns correct is_final flag"""
|
||||
# Arrange
|
||||
translator = GraphRagResponseTranslator()
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response_chunk)
|
||||
result, is_final = translator.encode_with_completion(response_chunk)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
|
|
@ -105,7 +105,7 @@ class TestGraphRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
result, is_final = translator.encode_with_completion(final_response)
|
||||
|
||||
# Assert - is_final is based on end_of_session, not end_of_stream
|
||||
assert is_final is True
|
||||
|
|
@ -116,7 +116,7 @@ class TestGraphRagResponseTranslator:
|
|||
class TestDocumentRagResponseTranslator:
|
||||
"""Test DocumentRagResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_response(self):
|
||||
def test_encode_with_empty_response(self):
|
||||
"""Test that empty response strings are preserved"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
|
@ -127,14 +127,14 @@ class TestDocumentRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_response(self):
|
||||
def test_encode_with_non_empty_response(self):
|
||||
"""Test that non-empty responses work correctly"""
|
||||
# Arrange
|
||||
translator = DocumentRagResponseTranslator()
|
||||
|
|
@ -145,7 +145,7 @@ class TestDocumentRagResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["response"] == "Document content"
|
||||
|
|
@ -155,7 +155,7 @@ class TestDocumentRagResponseTranslator:
|
|||
class TestPromptResponseTranslator:
|
||||
"""Test PromptResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_with_empty_text(self):
|
||||
def test_encode_with_empty_text(self):
|
||||
"""Test that empty text strings are preserved"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -167,14 +167,14 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "text" in result
|
||||
assert result["text"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_with_non_empty_text(self):
|
||||
def test_encode_with_non_empty_text(self):
|
||||
"""Test that non-empty text works correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -186,13 +186,13 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert result["text"] == "Some prompt response"
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
def test_from_pulsar_with_none_text(self):
|
||||
def test_encode_with_none_text(self):
|
||||
"""Test that None text is handled correctly"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -204,14 +204,14 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "text" not in result
|
||||
assert "object" in result
|
||||
assert result["end_of_stream"] is True
|
||||
|
||||
def test_from_pulsar_includes_end_of_stream(self):
|
||||
def test_encode_includes_end_of_stream(self):
|
||||
"""Test that end_of_stream flag is always included"""
|
||||
# Arrange
|
||||
translator = PromptResponseTranslator()
|
||||
|
|
@ -225,7 +225,7 @@ class TestPromptResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result
|
||||
|
|
@ -235,7 +235,7 @@ class TestPromptResponseTranslator:
|
|||
class TestTextCompletionResponseTranslator:
|
||||
"""Test TextCompletionResponseTranslator streaming behavior"""
|
||||
|
||||
def test_from_pulsar_always_includes_response(self):
|
||||
def test_encode_always_includes_response(self):
|
||||
"""Test that response field is always included, even if empty"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
|
|
@ -249,13 +249,13 @@ class TestTextCompletionResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert - Response should always be present
|
||||
assert "response" in result
|
||||
assert result["response"] == ""
|
||||
|
||||
def test_from_response_with_completion_with_empty_final(self):
|
||||
def test_encode_with_completion_with_empty_final(self):
|
||||
"""Test that empty final response is handled correctly"""
|
||||
# Arrange
|
||||
translator = TextCompletionResponseTranslator()
|
||||
|
|
@ -269,7 +269,7 @@ class TestTextCompletionResponseTranslator:
|
|||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
result, is_final = translator.encode_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True
|
||||
|
|
@ -297,7 +297,7 @@ class TestStreamingProtocolCompliance:
|
|||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
|
||||
|
|
@ -320,7 +320,7 @@ class TestStreamingProtocolCompliance:
|
|||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
result = translator.encode(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
|
||||
|
|
|
|||
54
tests/unit/test_gateway/test_text_document_translator.py
Normal file
54
tests/unit/test_gateway/test_text_document_translator.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Unit tests for text document gateway translation compatibility.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from trustgraph.messaging.translators.document_loading import TextDocumentTranslator
|
||||
|
||||
|
||||
class TestTextDocumentTranslator:
|
||||
def test_decode_decodes_base64_text(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "Cancer survival: 2.74× higher hazard ratio"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "research",
|
||||
"charset": "utf-8",
|
||||
"text": base64.b64encode(payload.encode("utf-8")).decode("ascii"),
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.metadata.id == "doc-1"
|
||||
assert msg.metadata.user == "alice"
|
||||
assert msg.metadata.collection == "research"
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
def test_decode_accepts_raw_utf8_text(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "Cancer survival: 2.74× higher hazard ratio"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"text": payload,
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
|
||||
def test_decode_falls_back_to_raw_non_base64_ascii(self):
|
||||
translator = TextDocumentTranslator()
|
||||
payload = "plain-text payload"
|
||||
|
||||
msg = translator.decode(
|
||||
{
|
||||
"charset": "utf-8",
|
||||
"text": payload,
|
||||
}
|
||||
)
|
||||
|
||||
assert msg.text == payload.encode("utf-8")
|
||||
Loading…
Add table
Add a link
Reference in a new issue