nomyo-router/test/test_unit_context.py

116 lines
4.6 KiB
Python

"""Unit tests for context-window trimming logic."""
import pytest
import router
def _msgs(roles_contents):
return [{"role": r, "content": c} for r, c in roles_contents]
class TestCountMessageTokens:
    """Sanity checks for ``router._count_message_tokens``."""

    def test_returns_int(self):
        count = router._count_message_tokens(_msgs([("user", "hello")]))
        assert isinstance(count, int)

    def test_empty_list(self):
        # An empty history must still yield a non-negative count.
        assert router._count_message_tokens([]) >= 0

    def test_longer_content_more_tokens(self):
        tiny = _msgs([("user", "hi")])
        big = _msgs([("user", "a " * 500)])
        big_count = router._count_message_tokens(big)
        tiny_count = router._count_message_tokens(tiny)
        assert big_count > tiny_count

    def test_list_content(self):
        # Content supplied as a list of typed parts instead of a plain string.
        parts = [{"type": "text", "text": "what do you see?"}]
        history = [{"role": "user", "content": parts}]
        assert router._count_message_tokens(history) > 0

    def test_multiple_messages(self):
        conversation = _msgs([
            ("system", "you are helpful"),
            ("user", "hello"),
            ("assistant", "hi!"),
        ])
        assert router._count_message_tokens(conversation) > 10
class TestTrimMessagesForContext:
    """Behavioural tests for ``router._trim_messages_for_context``."""

    def test_short_history_unchanged(self):
        msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
        result = router._trim_messages_for_context(msgs, n_ctx=4096)
        assert result == msgs

    def test_system_messages_always_kept(self):
        msgs = (
            _msgs([("system", "you are helpful")])
            + _msgs([("user", f"msg {i}") for i in range(50)])
            + _msgs([("user", "final question")])
        )
        result = router._trim_messages_for_context(msgs, n_ctx=512)
        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert system_msgs[0]["content"] == "you are helpful"

    def test_last_user_message_always_kept(self):
        msgs = _msgs(
            [("user", f"old msg {i}") for i in range(100)]
            + [("user", "very important last question")]
        )
        result = router._trim_messages_for_context(msgs, n_ctx=256)
        assert result[-1]["content"] == "very important last question"

    def test_oldest_dropped_first(self):
        msgs = _msgs([
            ("user", "oldest msg"),
            ("assistant", "oldest reply"),
            ("user", "newer msg"),
            ("assistant", "newer reply"),
            ("user", "newest"),
        ])
        # Use a very small target to force trimming.
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
        contents = [m["content"] for m in result]
        original = [m["content"] for m in msgs]
        assert "newest" in contents
        # Trimming may only remove messages from the front. Whatever survives
        # must be a contiguous tail of the original history, so an older
        # message can never be kept while a newer one is dropped.
        assert contents == original[len(original) - len(contents):]

    def test_result_starts_with_user(self):
        msgs = _msgs([
            ("assistant", "leftover assistant"),
            ("user", "question"),
        ])
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
        # The last user message is always kept (see the tests above), so an
        # empty result is itself a failure. Asserting non-emptiness up front
        # stops the role check from passing vacuously.
        assert result
        assert result[0]["role"] == "user"

    def test_target_tokens_overrides_safety_margin(self):
        # Several sizeable turns, so that a tiny target has something to
        # shed while a generous one does not.
        msgs = _msgs(
            [("user", "a " * 200), ("assistant", "b " * 200)] * 3
            + [("user", "final question")]
        )
        result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
        result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
        # Both should return at least the last message.
        assert len(result_small) >= 1
        assert len(result_large) >= 1
        # n_ctx is identical in both calls, so any difference comes from
        # target_tokens alone. The generous target keeps everything; the
        # tiny one must trim.
        assert result_large == msgs
        assert len(result_small) < len(result_large)
class TestCalibratedTrimTarget:
    """Tests for ``router._calibrated_trim_target``.

    Each case supplies a reported ``actual_tokens`` count together with the
    context window ``n_ctx`` and inspects the returned trim target.
    """

    def test_returns_positive_int(self):
        history = [{"role": "user", "content": "hello " * 100}]
        target = router._calibrated_trim_target(history, n_ctx=4096, actual_tokens=3000)
        assert isinstance(target, int)
        assert target >= 1

    def test_over_limit_reduces_target(self):
        history = [{"role": "user", "content": "a " * 500}]
        # The reported usage exceeds the window, so the target must fall below
        # the history's size as measured by _count_message_tokens.
        target = router._calibrated_trim_target(history, n_ctx=2048, actual_tokens=2500)
        current = router._count_message_tokens(history)
        assert target < current

    def test_well_within_limit_returns_current(self):
        history = [{"role": "user", "content": "hi"}]
        # The reported usage is far below the window, so nothing needs
        # shedding. The target should equal the current count, floored at 1.
        target = router._calibrated_trim_target(history, n_ctx=16384, actual_tokens=50)
        expected = max(1, router._count_message_tokens(history))
        assert target == expected

    def test_minimum_is_one(self):
        # Even when the reported usage dwarfs the window, the target never
        # drops below 1.
        history = [{"role": "user", "content": "hello"}]
        target = router._calibrated_trim_target(history, n_ctx=100, actual_tokens=99999)
        assert target >= 1