add compress_type to config

This commit is contained in:
garylin2099 2024-08-02 16:43:27 +08:00
parent 1f73051571
commit 987d90f6ff
4 changed files with 30 additions and 4 deletions

View file

@ -147,6 +147,13 @@ def test_compress_messages_long(compress_type):
assert compressed[2]["role"] != "system"
def test_long_messages_no_compress():
base_llm = MockBaseLLM()
messages = [{"role": "user", "content": "1" * 10000}] * 10000
compressed = base_llm.compress_messages(messages)
assert len(compressed) == len(messages)
@pytest.mark.parametrize(
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
)

View file

@ -178,7 +178,7 @@ def test_count_tokens():
llm._assistant_msg("assistant 2"),
]
cnt = llm.count_tokens(messages)
print(cnt)
assert cnt == 47
def test_count_tokens_long():
@ -190,11 +190,11 @@ def test_count_tokens_long():
llm._user_msg(test_msg_content + " what's the first number you see?"),
]
cnt = llm.count_tokens(messages) # 299023, ~300k
print(cnt)
assert 290000 <= cnt <= 300000
llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens
cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference
print(cnt)
assert 290000 <= cnt <= 300000
@pytest.mark.skip
@ -202,9 +202,25 @@ def test_count_tokens_long():
async def test_aask_long():
llm = LLM()
llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
llm.config.compress_type = "post_cut_by_token"
test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens
messages = [
llm._system_msg("You are a helpful assistant"),
llm._user_msg(test_msg_content + " what's the first number you see?"),
]
await llm.aask(messages) # should not fail with context truncated
@pytest.mark.skip
@pytest.mark.asyncio
async def test_aask_long_no_compress():
llm = LLM()
llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
# Not specifying llm.config.compress_type will use default "", no compress
test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens
messages = [
llm._system_msg("You are a helpful assistant"),
llm._user_msg(test_msg_content + " what's the first number you see?"),
]
with pytest.raises(Exception):
await llm.aask(messages) # should fail