116 lines
4.6 KiB
Python
116 lines
4.6 KiB
Python
"""Unit tests for context-window trimming logic."""
|
|
import pytest
|
|
import router
|
|
|
|
|
|
def _msgs(roles_contents):
|
|
return [{"role": r, "content": c} for r, c in roles_contents]
|
|
|
|
|
|
class TestCountMessageTokens:
    """Checks for ``router._count_message_tokens``."""

    def test_returns_int(self):
        conversation = [{"role": "user", "content": "hello"}]
        count = router._count_message_tokens(conversation)
        assert isinstance(count, int)

    def test_empty_list(self):
        # An empty history must still yield a non-negative count.
        count = router._count_message_tokens([])
        assert count >= 0

    def test_longer_content_more_tokens(self):
        count_tokens = router._count_message_tokens
        verbose = count_tokens([{"role": "user", "content": "a " * 500}])
        brief = count_tokens([{"role": "user", "content": "hi"}])
        assert verbose > brief

    def test_list_content(self):
        # Content given as a list of typed parts (here a single text part).
        text_part = {"type": "text", "text": "what do you see?"}
        conversation = [{"role": "user", "content": [text_part]}]
        count = router._count_message_tokens(conversation)
        assert count > 0

    def test_multiple_messages(self):
        conversation = [
            {"role": "system", "content": "you are helpful"},
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi!"},
        ]
        count = router._count_message_tokens(conversation)
        assert count > 10
|
|
|
|
|
|
class TestTrimMessagesForContext:
    """Checks for ``router._trim_messages_for_context``."""

    def test_short_history_unchanged(self):
        msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
        result = router._trim_messages_for_context(msgs, n_ctx=4096)
        assert result == msgs

    def test_system_messages_always_kept(self):
        msgs = (
            _msgs([("system", "you are helpful")])
            + _msgs([("user", f"msg {i}") for i in range(50)])
            + _msgs([("user", "final question")])
        )
        result = router._trim_messages_for_context(msgs, n_ctx=512)
        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert system_msgs[0]["content"] == "you are helpful"

    def test_last_user_message_always_kept(self):
        msgs = _msgs(
            [("user", f"old msg {i}") for i in range(100)]
            + [("user", "very important last question")]
        )
        result = router._trim_messages_for_context(msgs, n_ctx=256)
        assert result[-1]["content"] == "very important last question"

    def test_oldest_dropped_first(self):
        history = [
            ("user", "oldest msg"),
            ("assistant", "oldest reply"),
            ("user", "newer msg"),
            ("assistant", "newer reply"),
            ("user", "newest"),
        ]
        msgs = _msgs(history)
        # Use very small target to force trimming
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
        contents = [m["content"] for m in result]
        assert "newest" in contents
        # Dropping oldest-first means the survivors form a suffix of the
        # history: once any message is kept, every newer one must be too.
        # (The old if/else asserted the same thing in both branches and so
        # never checked the ordering.)
        kept = [content in contents for _, content in history]
        first_kept = kept.index(True)
        assert all(kept[first_kept:]), contents

    def test_result_starts_with_user(self):
        msgs = _msgs([
            ("assistant", "leftover assistant"),
            ("user", "question"),
        ])
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
        # The last user message is always kept, so an empty result is itself
        # a failure rather than something to skip over.
        assert result
        assert result[0]["role"] == "user"

    def test_target_tokens_overrides_safety_margin(self):
        # Ten medium-sized turns: together far below n_ctx=8192, so only an
        # explicit target_tokens can force trimming. A single message (as
        # before) made both results length 1 regardless of the target.
        msgs = _msgs([("user", f"turn {i} " + "a " * 50) for i in range(10)])
        result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
        result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
        # Both should return at least the last message
        assert len(result_small) >= 1
        assert len(result_large) >= 1
        assert result_small[-1]["content"] == msgs[-1]["content"]
        assert result_large[-1]["content"] == msgs[-1]["content"]
        # A tighter explicit target must keep strictly less history.
        assert len(result_small) < len(result_large)
|
|
|
|
|
|
class TestCalibratedTrimTarget:
    """Checks for ``router._calibrated_trim_target``."""

    def test_returns_positive_int(self):
        history = [{"role": "user", "content": "hello " * 100}]
        value = router._calibrated_trim_target(history, n_ctx=4096, actual_tokens=3000)
        assert isinstance(value, int)
        assert value >= 1

    def test_over_limit_reduces_target(self):
        history = [{"role": "user", "content": "a " * 500}]
        # actual_tokens exceeds n_ctx, so more has to be shed and the target
        # must land below the current token count.
        value = router._calibrated_trim_target(history, n_ctx=2048, actual_tokens=2500)
        current = router._count_message_tokens(history)
        assert value < current

    def test_well_within_limit_returns_current(self):
        history = [{"role": "user", "content": "hi"}]
        # actual_tokens is far below n_ctx: there is nothing to shed.
        value = router._calibrated_trim_target(history, n_ctx=16384, actual_tokens=50)
        # With nothing to shed, the target is the current count (floored at 1).
        expected = max(1, router._count_message_tokens(history))
        assert value == expected

    def test_minimum_is_one(self):
        # Even when everything would have to be shed, the target stays >= 1.
        history = [{"role": "user", "content": "hello"}]
        value = router._calibrated_trim_target(history, n_ctx=100, actual_tokens=99999)
        assert value >= 1
|