feat: adding automated tests
This commit is contained in:
parent
e484f12228
commit
29ee360082
18 changed files with 2886 additions and 4 deletions
116
test/test_unit_context.py
Normal file
116
test/test_unit_context.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Unit tests for context-window trimming logic."""
|
||||
import pytest
|
||||
import router
|
||||
|
||||
|
||||
def _msgs(roles_contents):
|
||||
return [{"role": r, "content": c} for r, c in roles_contents]
|
||||
|
||||
|
||||
class TestCountMessageTokens:
|
||||
def test_returns_int(self):
|
||||
msgs = _msgs([("user", "hello")])
|
||||
assert isinstance(router._count_message_tokens(msgs), int)
|
||||
|
||||
def test_empty_list(self):
|
||||
assert router._count_message_tokens([]) >= 0
|
||||
|
||||
def test_longer_content_more_tokens(self):
|
||||
short = _msgs([("user", "hi")])
|
||||
long_ = _msgs([("user", "a " * 500)])
|
||||
assert router._count_message_tokens(long_) > router._count_message_tokens(short)
|
||||
|
||||
def test_list_content(self):
|
||||
msgs = [{"role": "user", "content": [
|
||||
{"type": "text", "text": "what do you see?"},
|
||||
]}]
|
||||
tokens = router._count_message_tokens(msgs)
|
||||
assert tokens > 0
|
||||
|
||||
def test_multiple_messages(self):
|
||||
msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")])
|
||||
assert router._count_message_tokens(msgs) > 10
|
||||
|
||||
|
||||
class TestTrimMessagesForContext:
|
||||
def test_short_history_unchanged(self):
|
||||
msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=4096)
|
||||
assert result == msgs
|
||||
|
||||
def test_system_messages_always_kept(self):
|
||||
msgs = (
|
||||
_msgs([("system", "you are helpful")])
|
||||
+ _msgs([("user", f"msg {i}") for i in range(50)])
|
||||
+ _msgs([("user", "final question")])
|
||||
)
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=512)
|
||||
system_msgs = [m for m in result if m["role"] == "system"]
|
||||
assert len(system_msgs) == 1
|
||||
assert system_msgs[0]["content"] == "you are helpful"
|
||||
|
||||
def test_last_user_message_always_kept(self):
|
||||
msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")])
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=256)
|
||||
assert result[-1]["content"] == "very important last question"
|
||||
|
||||
def test_oldest_dropped_first(self):
|
||||
msgs = _msgs([
|
||||
("user", "oldest msg"),
|
||||
("assistant", "oldest reply"),
|
||||
("user", "newer msg"),
|
||||
("assistant", "newer reply"),
|
||||
("user", "newest"),
|
||||
])
|
||||
# Use very small target to force trimming
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
|
||||
contents = [m["content"] for m in result]
|
||||
# "oldest msg" should be dropped before "newest"
|
||||
if "oldest msg" in contents:
|
||||
assert "newest" in contents
|
||||
else:
|
||||
assert "newest" in contents
|
||||
|
||||
def test_result_starts_with_user(self):
|
||||
msgs = _msgs([
|
||||
("assistant", "leftover assistant"),
|
||||
("user", "question"),
|
||||
])
|
||||
result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
|
||||
if result:
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_target_tokens_overrides_safety_margin(self):
|
||||
msgs = _msgs([("user", "a " * 200)])
|
||||
result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
|
||||
result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
|
||||
# Both should return at least the last message
|
||||
assert len(result_small) >= 1
|
||||
assert len(result_large) >= 1
|
||||
|
||||
|
||||
class TestCalibratedTrimTarget:
|
||||
def test_returns_positive_int(self):
|
||||
msgs = [{"role": "user", "content": "hello " * 100}]
|
||||
result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000)
|
||||
assert isinstance(result, int)
|
||||
assert result >= 1
|
||||
|
||||
def test_over_limit_reduces_target(self):
|
||||
msgs = [{"role": "user", "content": "a " * 500}]
|
||||
# actual_tokens > n_ctx means we need to shed more
|
||||
target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500)
|
||||
assert target < router._count_message_tokens(msgs)
|
||||
|
||||
def test_well_within_limit_returns_current(self):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
# actual_tokens << n_ctx means nothing to shed
|
||||
target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50)
|
||||
# Should return cur_tiktoken since to_shed == 0
|
||||
assert target == max(1, router._count_message_tokens(msgs))
|
||||
|
||||
def test_minimum_is_one(self):
|
||||
# Even if we need to shed everything, result is at least 1
|
||||
msgs = [{"role": "user", "content": "hello"}]
|
||||
target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999)
|
||||
assert target >= 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue