add compress_type to config

This commit is contained in:
garylin2099 2024-08-02 16:43:27 +08:00
parent 1f73051571
commit 987d90f6ff
4 changed files with 30 additions and 4 deletions

View file

@ -86,6 +86,9 @@ class LLMConfig(YamlModel):
# Cost Control
calc_usage: bool = True
# Strategy for compressing request messages to fit under the token limit, e.g. "post_cut_by_token"; the default "" means no compression
compress_type: str = ""
@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):

View file

@ -148,7 +148,7 @@ class BaseLLM(ABC):
else:
message.extend(msg)
logger.debug(message)
compressed_message = self.compress_messages(message, compress_type="post_cut_by_token")
compressed_message = self.compress_messages(message, compress_type=self.config.compress_type)
rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
# rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
return rsp

View file

@ -147,6 +147,13 @@ def test_compress_messages_long(compress_type):
assert compressed[2]["role"] != "system"
def test_long_messages_no_compress():
    """Check that ``compress_messages`` called without a ``compress_type`` keeps every message.

    The input is a very long conversation (10000 messages of 10000 characters
    each); the returned list must have the same length as the input.
    """
    llm = MockBaseLLM()
    # A single message dict repeated by reference, exactly as many times as needed.
    long_message = {"role": "user", "content": "1" * 10000}
    oversized_history = [long_message] * 10000
    result = llm.compress_messages(oversized_history)
    assert len(result) == len(oversized_history)
@pytest.mark.parametrize(
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
)

View file

@ -178,7 +178,7 @@ def test_count_tokens():
llm._assistant_msg("assistant 2"),
]
cnt = llm.count_tokens(messages)
print(cnt)
assert cnt == 47
def test_count_tokens_long():
@ -190,11 +190,11 @@ def test_count_tokens_long():
llm._user_msg(test_msg_content + " what's the first number you see?"),
]
cnt = llm.count_tokens(messages) # 299023, ~300k
print(cnt)
assert 290000 <= cnt <= 300000
llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens
cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference
print(cnt)
assert 290000 <= cnt <= 300000
@pytest.mark.skip
@ -202,9 +202,25 @@ def test_count_tokens_long():
async def test_aask_long():
    """Ask the LLM with an oversized prompt while ``post_cut_by_token`` compression is enabled.

    The prompt joins the numbers 0..99999 with spaces (roughly 300k tokens
    according to the sibling token-count tests), and the configured model is
    noted as having a 32k limit. The call is expected to complete without
    raising.
    """
    client = LLM()
    # deepseek-coder on siliconflow, limit 32k
    client.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"
    client.config.compress_type = "post_cut_by_token"

    # corresponds to ~300k tokens
    number_stream = " ".join(map(str, range(100000)))
    system_prompt = client._system_msg("You are a helpful assistant")
    user_prompt = client._user_msg(number_stream + " what's the first number you see?")

    # should not fail with context truncated
    await client.aask([system_prompt, user_prompt])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_aask_long_no_compress():
    """Ask the LLM with an oversized prompt and no compression configured.

    ``compress_type`` is left untouched, so the config default ("" — no
    compression) applies. With a prompt of roughly 300k tokens against a
    model noted as having a 32k limit, the call is expected to raise.
    """
    client = LLM()
    # deepseek-coder on siliconflow, limit 32k
    client.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"

    # corresponds to ~300k tokens
    number_stream = " ".join(map(str, range(100000)))
    system_prompt = client._system_msg("You are a helpful assistant")
    user_prompt = client._user_msg(number_stream + " what's the first number you see?")
    conversation = [system_prompt, user_prompt]

    # should fail: nothing trims the request before it is sent
    with pytest.raises(Exception):
        await client.aask(conversation)