diff --git a/metagpt/configs/compress_msg_config.py b/metagpt/configs/compress_msg_config.py
new file mode 100644
index 000000000..c46334c12
--- /dev/null
+++ b/metagpt/configs/compress_msg_config.py
@@ -0,0 +1,34 @@
+from enum import Enum
+
+
+class CompressType(Enum):
+    """
+    Compression type for messages. Used to compress messages under the token limit.
+    - "": No compression. Default value.
+    - "post_cut_by_msg": Keep as many latest messages as possible.
+    - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
+    - "pre_cut_by_msg": Keep as many earliest messages as possible.
+    - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.
+    """
+
+    NO_COMPRESS = ""
+    POST_CUT_BY_MSG = "post_cut_by_msg"
+    POST_CUT_BY_TOKEN = "post_cut_by_token"
+    PRE_CUT_BY_MSG = "pre_cut_by_msg"
+    PRE_CUT_BY_TOKEN = "pre_cut_by_token"
+
+    @classmethod
+    def _missing_(cls, value):
+        # Called by Enum when a value lookup fails; fall back to no compression
+        return cls.NO_COMPRESS
+
+    @classmethod
+    def get_type(cls, type_name):
+        for member in cls:
+            if member.value == type_name:
+                return member
+        return cls.NO_COMPRESS
+
+    @classmethod
+    def cut_types(cls):
+        return [member for member in cls if "cut" in member.value]
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 657c1a70b..c5605be21 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -10,6 +10,7 @@ from typing import Optional
 
 from pydantic import field_validator
 
+from metagpt.configs.compress_msg_config import CompressType
 from metagpt.const import LLM_API_TIMEOUT
 from metagpt.utils.yaml_model import YamlModel
 
@@ -87,7 +88,7 @@ class LLMConfig(YamlModel):
     calc_usage: bool = True
 
     # Compress request messages under token limit
-    compress_type: str = ""
+    compress_type: CompressType = CompressType.NO_COMPRESS
 
     @field_validator("api_key")
     @classmethod
@@ -100,3 +101,8 @@ class LLMConfig(YamlModel):
     @classmethod
     def check_timeout(cls, v):
         return v or LLM_API_TIMEOUT
+
+    @field_validator("compress_type")
+    @classmethod
+    def check_compress_type(cls, v):
+        return CompressType.get_type(v)
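A quick sketch of how the new helpers resolve raw config values (plain Python against the module added above; each assertion follows directly from the definitions):

from metagpt.configs.compress_msg_config import CompressType

# Raw strings from a YAML config resolve to enum members
assert CompressType.get_type("post_cut_by_msg") is CompressType.POST_CUT_BY_MSG

# Unknown or empty values fall back to no compression
assert CompressType.get_type("unknown") is CompressType.NO_COMPRESS
assert CompressType.get_type("") is CompressType.NO_COMPRESS

# cut_types() lists every strategy that actually truncates history
assert CompressType.NO_COMPRESS not in CompressType.cut_types()
assert len(CompressType.cut_types()) == 4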
- - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message. - - "": No compression. Default value. + compress_type (CompressType, optional): Compression strategy. Defaults to CompressType.NO_COMPRESS. max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be found in TOKEN_MAX. threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion message. """ - if not compress_type: + if compress_type == CompressType.NO_COMPRESS: return messages current_token_count = 0 @@ -318,7 +314,7 @@ class BaseLLM(ABC): compressed.extend(system_msgs) current_token_count += self.count_tokens(system_msgs) - if compress_type in ["post_cut_by_msg", "post_cut_by_token"]: + if compress_type in [CompressType.POST_CUT_BY_TOKEN, CompressType.POST_CUT_BY_MSG]: # Under keep_token constraint, keep as many latest messages as possible for i, msg in enumerate(reversed(user_assistant_msgs)): token_count = self.count_tokens([msg]) @@ -326,7 +322,7 @@ class BaseLLM(ABC): compressed.insert(len(system_msgs), msg) current_token_count += token_count else: - if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs): + if compress_type == CompressType.POST_CUT_BY_TOKEN or len(compressed) == len(system_msgs): # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][-(keep_token - current_token_count) :] compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content}) @@ -336,7 +332,7 @@ class BaseLLM(ABC): ) break - elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]: + elif compress_type in [CompressType.PRE_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_MSG]: # Under keep_token constraint, keep as many earliest messages as possible for i, msg in enumerate(user_assistant_msgs): token_count = self.count_tokens([msg]) @@ -344,7 +340,7 @@ class BaseLLM(ABC): compressed.append(msg) current_token_count += token_count else: - if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs): + if compress_type == CompressType.PRE_CUT_BY_TOKEN or len(compressed) == len(system_msgs): # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. 
diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py
index 87e4bdbbb..d34ed62f1 100644
--- a/tests/metagpt/provider/test_base_llm.py
+++ b/tests/metagpt/provider/test_base_llm.py
@@ -8,6 +8,7 @@
 
 import pytest
 
+from metagpt.configs.compress_msg_config import CompressType
 from metagpt.configs.llm_config import LLMConfig
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.schema import Message
@@ -106,9 +107,7 @@ async def test_async_base_llm():
     # assert resp == default_resp_cont
 
 
-@pytest.mark.parametrize(
-    "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
-)
+@pytest.mark.parametrize("compress_type", list(CompressType))
 def test_compress_messages_no_effect(compress_type):
     base_llm = MockBaseLLM()
     messages = [
@@ -123,9 +122,7 @@ def test_compress_messages_no_effect(compress_type):
     assert compressed == messages
 
 
-@pytest.mark.parametrize(
-    "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
-)
+@pytest.mark.parametrize("compress_type", CompressType.cut_types())
 def test_compress_messages_long(compress_type):
     base_llm = MockBaseLLM()
     base_llm.config.model = "test_llm"
@@ -142,7 +139,7 @@ def test_compress_messages_long(compress_type):
 
     print(compressed)
     print(len(compressed))
-    assert len(compressed) < len(messages)
+    assert 3 <= len(compressed) < len(messages)
     assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system"
     assert compressed[2]["role"] != "system"
 
@@ -154,9 +151,7 @@ def test_long_messages_no_compress():
     assert len(compressed) == len(messages)
 
 
-@pytest.mark.parametrize(
-    "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
-)
+@pytest.mark.parametrize("compress_type", CompressType.cut_types())
 def test_compress_messages_long_no_sys_msg(compress_type):
     base_llm = MockBaseLLM()
     base_llm.config.model = "test_llm"
diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py
index dde98f7a3..d292a8286 100644
--- a/tests/metagpt/provider/test_openai.py
+++ b/tests/metagpt/provider/test_openai.py
@@ -9,6 +9,7 @@ from openai.types.chat.chat_completion import Choice, CompletionUsage
 from openai.types.chat.chat_completion_message_tool_call import Function
 from PIL import Image
 
+from metagpt.configs.compress_msg_config import CompressType
 from metagpt.const import TEST_DATA_PATH
 from metagpt.llm import LLM
 from metagpt.logs import logger
@@ -202,7 +203,7 @@ def test_count_tokens_long():
 async def test_aask_long():
     llm = LLM()
     llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"  # deepseek-coder on siliconflow, limit 32k
-    llm.config.compress_type = "post_cut_by_token"
+    llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN
     test_msg_content = " ".join([str(i) for i in range(100000)])  # corresponds to ~300k tokens
     messages = [
         llm._system_msg("You are a helpful assistant"),
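Beyond the tests, typical enablement looks like this (a sketch assuming a locally configured provider; helper names and message dicts are taken from the diffs above):

from metagpt.configs.compress_msg_config import CompressType
from metagpt.llm import LLM

llm = LLM()
llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN  # drop oldest, keep latest

# An oversized history is cut down before being sent to the provider
messages = [
    llm._system_msg("You are a helpful assistant"),
    {"role": "user", "content": " ".join(str(i) for i in range(100000))},
]
compressed = llm.compress_messages(messages, compress_type=llm.config.compress_type)
assert len(compressed) <= len(messages)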