200 lines
7.5 KiB
Python
200 lines
7.5 KiB
Python
"""Unit tests for message transformation functions."""
|
|
import json
from unittest.mock import MagicMock

import pytest

import router
|
|
|
|
|
|
class TestStripAssistantPrefill:
    """Tests for ``router._strip_assistant_prefill``.

    The cases below check that a trailing assistant message is removed,
    while assistant turns earlier in the conversation are left in place.
    """

    def test_removes_trailing_assistant(self):
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "prefill"},
        ]
        stripped = router._strip_assistant_prefill(conversation)
        # Only the leading user turn should survive.
        assert len(stripped) == 1
        assert stripped[0]["role"] == "user"

    def test_keeps_non_trailing_assistant(self):
        # An assistant turn followed by a user turn is history, not prefill.
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "response"},
            {"role": "user", "content": "follow-up"},
        ]
        assert len(router._strip_assistant_prefill(conversation)) == 3

    def test_empty_list_unchanged(self):
        stripped = router._strip_assistant_prefill([])
        assert stripped == []

    def test_single_user_message_unchanged(self):
        conversation = [{"role": "user", "content": "hi"}]
        stripped = router._strip_assistant_prefill(conversation)
        assert stripped == conversation
|
|
|
|
|
|
class TestTransformToolCallsToOpenAI:
    """Tests for ``router.transform_tool_calls_to_openai``.

    The input tool calls below carry only a ``function`` entry. The tests
    check that the transform:

    - adds an OpenAI-style ``type`` and ``id``;
    - serializes dict arguments to a JSON string;
    - leaves string arguments untouched;
    - links a following ``tool`` message to the generated call id.
    """

    def test_adds_type_function(self):
        msgs = [{"role": "assistant", "tool_calls": [
            {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
        ]}]
        result = router.transform_tool_calls_to_openai(msgs)
        tc = result[0]["tool_calls"][0]
        assert tc["type"] == "function"

    def test_adds_id_when_missing(self):
        msgs = [{"role": "assistant", "tool_calls": [
            {"function": {"name": "fn", "arguments": {}}}
        ]}]
        result = router.transform_tool_calls_to_openai(msgs)
        assert "id" in result[0]["tool_calls"][0]

    def test_converts_dict_arguments_to_string(self):
        msgs = [{"role": "assistant", "tool_calls": [
            {"function": {"name": "fn", "arguments": {"key": "val"}}}
        ]}]
        result = router.transform_tool_calls_to_openai(msgs)
        args = result[0]["tool_calls"][0]["function"]["arguments"]
        assert isinstance(args, str)
        # Decode with the stdlib parser so the check only asserts "valid JSON
        # with the same content", independent of the serializer router uses.
        assert json.loads(args) == {"key": "val"}

    def test_keeps_string_arguments_unchanged(self):
        msgs = [{"role": "assistant", "tool_calls": [
            {"function": {"name": "fn", "arguments": '{"key": "val"}'}}
        ]}]
        result = router.transform_tool_calls_to_openai(msgs)
        args = result[0]["tool_calls"][0]["function"]["arguments"]
        assert args == '{"key": "val"}'

    def test_links_tool_call_id_to_tool_response(self):
        msgs = [
            {"role": "assistant", "tool_calls": [
                {"function": {"name": "get_weather", "arguments": {}}}
            ]},
            {"role": "tool", "name": "get_weather", "content": "sunny"},
        ]
        result = router.transform_tool_calls_to_openai(msgs)
        tc_id = result[0]["tool_calls"][0]["id"]
        # The tool response should reference the id generated for the call.
        assert result[1].get("tool_call_id") == tc_id

    def test_non_tool_messages_unchanged(self):
        msgs = [{"role": "user", "content": "hello"}]
        result = router.transform_tool_calls_to_openai(msgs)
        assert result == msgs
|
|
|
|
|
|
class TestStripImagesFromMessages:
    """Tests for ``router._strip_images_from_messages``.

    ``image_url`` parts are removed from list-style content. The shape of
    what remains depends on how many text parts survive, as the cases
    below pin down.
    """

    def test_removes_image_url_parts(self):
        message = {"role": "user", "content": [
            {"type": "text", "text": "what is this?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
        ]}
        stripped = router._strip_images_from_messages([message])
        # A single surviving text part is collapsed to a plain string.
        assert stripped[0]["content"] == "what is this?"

    def test_keeps_text_only_messages(self):
        stripped = router._strip_images_from_messages(
            [{"role": "user", "content": "plain text"}]
        )
        assert stripped[0]["content"] == "plain text"

    def test_multiple_text_parts_kept_as_list(self):
        message = {"role": "user", "content": [
            {"type": "text", "text": "part one"},
            {"type": "text", "text": "part two"},
            {"type": "image_url", "image_url": {"url": "data:..."}},
        ]}
        remaining = router._strip_images_from_messages([message])[0]["content"]
        # Several text parts stay in list form; only the image is dropped.
        assert isinstance(remaining, list)
        assert len(remaining) == 2

    def test_all_images_removed_empty_list(self):
        message = {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": "data:..."}},
        ]}
        remaining = router._strip_images_from_messages([message])[0]["content"]
        # Content made only of images ends up as an empty list.
        assert remaining == []
|
|
|
|
|
|
class TestAccumulateOpenAITcDelta:
    """Tests for ``router._accumulate_openai_tc_delta``.

    Streaming chunks are faked with ``MagicMock`` objects shaped like
    ``chunk.choices[0].delta.tool_calls[i]``, where each tool-call delta
    exposes ``index``, ``id``, ``function.name`` and ``function.arguments``.
    """

    @staticmethod
    def _make_tool_call(index, name=None, args_fragment="", tc_id=None):
        """Build one fake streamed tool-call delta.

        ``function.name`` is assigned after construction on purpose: passing
        ``name=`` to the ``MagicMock`` constructor sets the mock's repr name,
        not its ``.name`` attribute.
        """
        tc = MagicMock()
        tc.index = index
        tc.id = tc_id
        tc.function = MagicMock()
        tc.function.name = name
        tc.function.arguments = args_fragment
        return tc

    @staticmethod
    def _make_chunk_from(tool_calls):
        """Wrap fake tool-call deltas in a chunk with a single choice."""
        delta = MagicMock()
        delta.tool_calls = list(tool_calls)
        chunk = MagicMock()
        chunk.choices = [MagicMock(delta=delta)]
        return chunk

    def _make_chunk(self, index, name=None, args_fragment="", tc_id=None):
        """Build a chunk carrying exactly one tool-call delta."""
        return self._make_chunk_from(
            [self._make_tool_call(index, name, args_fragment, tc_id)]
        )

    def test_first_delta_creates_entry(self):
        acc = {}
        chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"')
        router._accumulate_openai_tc_delta(chunk, acc)
        assert 0 in acc
        assert acc[0]["name"] == "my_fn"
        assert acc[0]["arguments"] == '{"k"'

    def test_subsequent_deltas_concatenate_args(self):
        acc = {}
        router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc)
        router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc)
        assert acc[0]["arguments"] == '{"k": "v"}'

    def test_multiple_tool_calls_tracked_separately(self):
        acc = {}
        # Two tool calls at different indices arrive in the same chunk.
        chunk = self._make_chunk_from([
            self._make_tool_call(0, name="fn1", args_fragment="{}", tc_id="id1"),
            self._make_tool_call(1, name="fn2", args_fragment="{}", tc_id="id2"),
        ])
        router._accumulate_openai_tc_delta(chunk, acc)
        assert 0 in acc and 1 in acc
        # Each index must keep its own name/arguments rather than mixing.
        assert acc[0]["name"] == "fn1"
        assert acc[1]["name"] == "fn2"
        assert acc[0]["arguments"] == "{}"
        assert acc[1]["arguments"] == "{}"

    def test_no_choices_is_noop(self):
        acc = {}
        chunk = MagicMock(choices=[])
        router._accumulate_openai_tc_delta(chunk, acc)
        assert acc == {}
|
|
|
|
|
|
class TestBuildOllamaToolCalls:
    """Tests for ``router._build_ollama_tool_calls``.

    The accumulator maps a stream index to a dict with ``id``, ``name``
    and a raw JSON ``arguments`` string. The builder's output exposes
    ``.function.name`` and ``.function.arguments``.
    """

    def test_builds_from_accumulator(self):
        acc = {0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}}
        built = router._build_ollama_tool_calls(acc)
        assert built is not None
        assert len(built) == 1
        fn = built[0].function
        assert fn.name == "get_weather"
        # The JSON argument string is decoded into a dict.
        assert fn.arguments == {"city": "Berlin"}

    def test_invalid_json_args_becomes_empty_dict(self):
        built = router._build_ollama_tool_calls(
            {0: {"id": "c1", "name": "fn", "arguments": "not-json"}}
        )
        assert built[0].function.arguments == {}

    def test_empty_accumulator_returns_none(self):
        built = router._build_ollama_tool_calls({})
        assert built is None

    def test_preserves_order_by_index(self):
        # Insertion order is deliberately reversed relative to the indices.
        acc = {
            1: {"id": "c2", "name": "fn2", "arguments": "{}"},
            0: {"id": "c1", "name": "fn1", "arguments": "{}"},
        }
        built = router._build_ollama_tool_calls(acc)
        assert [built[0].function.name, built[1].function.name] == ["fn1", "fn2"]
|