diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 48130eedc..657c1a70b 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -86,6 +86,9 @@ class LLMConfig(YamlModel): # Cost Control calc_usage: bool = True + # Compress request messages under token limit + compress_type: str = "" + @field_validator("api_key") @classmethod def check_llm_key(cls, v): diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 68b1f78a6..6b544375e 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -148,7 +148,7 @@ class BaseLLM(ABC): else: message.extend(msg) logger.debug(message) - compressed_message = self.compress_messages(message, compress_type="post_cut_by_token") + compressed_message = self.compress_messages(message, compress_type=self.config.compress_type) rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout)) # rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout)) return rsp diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 250c3b5c4..87e4bdbbb 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -147,6 +147,13 @@ def test_compress_messages_long(compress_type): assert compressed[2]["role"] != "system" +def test_long_messages_no_compress(): + base_llm = MockBaseLLM() + messages = [{"role": "user", "content": "1" * 10000}] * 10000 + compressed = base_llm.compress_messages(messages) + assert len(compressed) == len(messages) + + @pytest.mark.parametrize( "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] ) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 5f99e2b52..dde98f7a3 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -178,7 +178,7 @@ def test_count_tokens(): llm._assistant_msg("assistant 2"), ] cnt = llm.count_tokens(messages) - print(cnt) + assert cnt == 47 def test_count_tokens_long(): @@ -190,11 +190,11 @@ def test_count_tokens_long(): llm._user_msg(test_msg_content + " what's the first number you see?"), ] cnt = llm.count_tokens(messages) # 299023, ~300k - print(cnt) + assert 290000 <= cnt <= 300000 llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference - print(cnt) + assert 290000 <= cnt <= 300000 @pytest.mark.skip @@ -202,9 +202,25 @@ def test_count_tokens_long(): async def test_aask_long(): llm = LLM() llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + llm.config.compress_type = "post_cut_by_token" test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens messages = [ llm._system_msg("You are a helpful assistant"), llm._user_msg(test_msg_content + " what's the first number you see?"), ] await llm.aask(messages) # should not fail with context truncated + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_aask_long_no_compress(): + llm = LLM() + llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + # Not specifying llm.config.compress_type will use default "", no compress + test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens + messages = [ + llm._system_msg("You are a helpful assistant"), + llm._user_msg(test_msg_content + " what's the first number you see?"), + ] + with pytest.raises(Exception): + await llm.aask(messages) # should fail