"""Unit tests for context-window trimming logic."""

import pytest

import router


def _msgs(roles_contents):
    """Build a list of chat messages from ``(role, content)`` pairs."""
    return [{"role": r, "content": c} for r, c in roles_contents]


class TestCountMessageTokens:
    """Tests for ``router._count_message_tokens``."""

    def test_returns_int(self):
        msgs = _msgs([("user", "hello")])
        assert isinstance(router._count_message_tokens(msgs), int)

    def test_empty_list(self):
        assert router._count_message_tokens([]) >= 0

    def test_longer_content_more_tokens(self):
        short = _msgs([("user", "hi")])
        long_ = _msgs([("user", "a " * 500)])
        assert router._count_message_tokens(long_) > router._count_message_tokens(short)

    def test_list_content(self):
        # Content supplied as a list of typed parts instead of a plain string.
        msgs = [{"role": "user", "content": [
            {"type": "text", "text": "what do you see?"},
        ]}]
        tokens = router._count_message_tokens(msgs)
        assert tokens > 0

    def test_multiple_messages(self):
        # Three short messages exceed 10 tokens only if per-message overhead
        # is counted in addition to the content itself.
        msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")])
        assert router._count_message_tokens(msgs) > 10


class TestTrimMessagesForContext:
    """Tests for ``router._trim_messages_for_context``."""

    def test_short_history_unchanged(self):
        msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
        result = router._trim_messages_for_context(msgs, n_ctx=4096)
        assert result == msgs

    def test_system_messages_always_kept(self):
        msgs = (
            _msgs([("system", "you are helpful")])
            + _msgs([("user", f"msg {i}") for i in range(50)])
            + _msgs([("user", "final question")])
        )
        result = router._trim_messages_for_context(msgs, n_ctx=512)
        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert system_msgs[0]["content"] == "you are helpful"

    def test_last_user_message_always_kept(self):
        msgs = _msgs(
            [("user", f"old msg {i}") for i in range(100)]
            + [("user", "very important last question")]
        )
        result = router._trim_messages_for_context(msgs, n_ctx=256)
        assert result[-1]["content"] == "very important last question"

    def test_oldest_dropped_first(self):
        msgs = _msgs([
            ("user", "oldest msg"),
            ("assistant", "oldest reply"),
            ("user", "newer msg"),
            ("assistant", "newer reply"),
            ("user", "newest"),
        ])
        # Use a very small target to force trimming: five messages plus
        # per-message overhead cannot fit in 10 tokens.
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
        contents = [m["content"] for m in result]
        original = [m["content"] for m in msgs]

        assert contents, "the last user message must always survive trimming"
        assert contents[-1] == "newest"
        assert len(contents) < len(original), "a 10-token target must force trimming"
        # Dropping oldest-first means the survivors form a contiguous suffix
        # of the original history: no gaps and no reordering.
        assert contents == original[len(original) - len(contents):]

    def test_result_starts_with_user(self):
        msgs = _msgs([
            ("assistant", "leftover assistant"),
            ("user", "question"),
        ])
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
        # The last user message is always kept, so an empty result would itself
        # be a bug; assert it rather than letting the check pass vacuously.
        assert result
        assert result[0]["role"] == "user"

    def test_target_tokens_overrides_safety_margin(self):
        # Alternating user/assistant turns ending in a user message. The whole
        # history is far below n_ctx=8192, so any trimming observed here is
        # caused by target_tokens, not by the context-size safety margin.
        msgs = _msgs(
            [
                ("user" if i % 2 == 0 else "assistant", f"turn {i} " + "a " * 20)
                for i in range(20)
            ]
            + [("user", "last")]
        )
        result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
        result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)

        # Both should return at least the last message.
        assert len(result_small) >= 1
        assert len(result_large) >= 1
        assert result_small[-1]["content"] == "last"
        assert result_large[-1]["content"] == "last"
        # A generous target keeps everything; a tiny one must trim.
        assert result_large == msgs
        assert len(result_small) < len(result_large)


class TestCalibratedTrimTarget:
    """Tests for ``router._calibrated_trim_target``."""

    def test_returns_positive_int(self):
        msgs = [{"role": "user", "content": "hello " * 100}]
        result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000)
        assert isinstance(result, int)
        assert result >= 1

    def test_over_limit_reduces_target(self):
        msgs = [{"role": "user", "content": "a " * 500}]
        # actual_tokens > n_ctx means we need to shed more.
        target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500)
        assert target < router._count_message_tokens(msgs)

    def test_well_within_limit_returns_current(self):
        msgs = [{"role": "user", "content": "hi"}]
        # actual_tokens << n_ctx means there is nothing to shed, so the target
        # should equal the locally counted token total (floored at 1).
        target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50)
        assert target == max(1, router._count_message_tokens(msgs))

    def test_minimum_is_one(self):
        # Even if we need to shed everything, the result is at least 1.
        msgs = [{"role": "user", "content": "hello"}]
        target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999)
        assert target >= 1