mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add compress_type to config
This commit is contained in:
parent
1f73051571
commit
987d90f6ff
4 changed files with 30 additions and 4 deletions
|
|
@ -86,6 +86,9 @@ class LLMConfig(YamlModel):
|
|||
# Cost Control
|
||||
calc_usage: bool = True
|
||||
|
||||
# Strategy used to compress request messages to fit the token limit (e.g. "post_cut_by_token"); the default "" means no compression
|
||||
compress_type: str = ""
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ class BaseLLM(ABC):
|
|||
else:
|
||||
message.extend(msg)
|
||||
logger.debug(message)
|
||||
compressed_message = self.compress_messages(message, compress_type="post_cut_by_token")
|
||||
compressed_message = self.compress_messages(message, compress_type=self.config.compress_type)
|
||||
rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
|
||||
# rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -147,6 +147,13 @@ def test_compress_messages_long(compress_type):
|
|||
assert compressed[2]["role"] != "system"
|
||||
|
||||
|
||||
def test_long_messages_no_compress():
    """Check that compress_messages, called without a compress_type, keeps the message count."""
    llm = MockBaseLLM()
    # One large user message, repeated many times to build an oversized history.
    oversized_message = {"role": "user", "content": "1" * 10000}
    history = [oversized_message] * 10000
    result = llm.compress_messages(history)
    assert len(result) == len(history)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ def test_count_tokens():
|
|||
llm._assistant_msg("assistant 2"),
|
||||
]
|
||||
cnt = llm.count_tokens(messages)
|
||||
print(cnt)
|
||||
assert cnt == 47
|
||||
|
||||
|
||||
def test_count_tokens_long():
|
||||
|
|
@ -190,11 +190,11 @@ def test_count_tokens_long():
|
|||
llm._user_msg(test_msg_content + " what's the first number you see?"),
|
||||
]
|
||||
cnt = llm.count_tokens(messages) # 299023, ~300k
|
||||
print(cnt)
|
||||
assert 290000 <= cnt <= 300000
|
||||
|
||||
llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens
|
||||
cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference
|
||||
print(cnt)
|
||||
assert 290000 <= cnt <= 300000
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
|
|
@ -202,9 +202,25 @@ def test_count_tokens_long():
|
|||
async def test_aask_long():
    """Send a ~300k-token prompt through aask with "post_cut_by_token" compression configured."""
    client = LLM()
    # deepseek-coder on siliconflow, limit 32k
    client.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"
    client.config.compress_type = "post_cut_by_token"
    # corresponds to ~300k tokens
    numbers_text = " ".join(map(str, range(100000)))
    conversation = [
        client._system_msg("You are a helpful assistant"),
        client._user_msg(numbers_text + " what's the first number you see?"),
    ]
    await client.aask(conversation)  # should not fail with context truncated
|
||||
|
||||
|
||||
@pytest.mark.skip
@pytest.mark.asyncio
async def test_aask_long_no_compress():
    """Send a ~300k-token prompt through aask with compress_type left unset and expect an error."""
    client = LLM()
    # deepseek-coder on siliconflow, limit 32k
    client.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"
    # Not specifying llm.config.compress_type will use default "", no compress
    # corresponds to ~300k tokens
    numbers_text = " ".join(map(str, range(100000)))
    conversation = [
        client._system_msg("You are a helpful assistant"),
        client._user_msg(numbers_text + " what's the first number you see?"),
    ]
    with pytest.raises(Exception):
        await client.aask(conversation)  # should fail
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue