From a6104f39310dffdec04b042028d2aad7ad740113 Mon Sep 17 00:00:00 2001
From: garylin2099
Date: Wed, 31 Jul 2024 22:07:48 +0800
Subject: [PATCH 1/4] first version of long context compression

---
 metagpt/provider/base_llm.py            | 97 ++++++++++++++++++++++++-
 metagpt/provider/openai_api.py          |  6 ++
 metagpt/utils/token_counter.py          |  3 +-
 tests/metagpt/provider/test_base_llm.py | 43 +++++++++++
 tests/metagpt/provider/test_openai.py   | 44 +++++++++++
 5 files changed, 190 insertions(+), 3 deletions(-)

diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index 4489c56c5..4c7bb4cbc 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -10,7 +10,7 @@ from __future__ import annotations
 
 import json
 from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 from openai import AsyncOpenAI
 from pydantic import BaseModel
@@ -28,6 +28,7 @@ from metagpt.logs import logger
 from metagpt.schema import Message
 from metagpt.utils.common import log_and_reraise
 from metagpt.utils.cost_manager import CostManager, Costs
+from metagpt.utils.token_counter import TOKEN_MAX
 
 
 class BaseLLM(ABC):
@@ -147,7 +148,9 @@ class BaseLLM(ABC):
         else:
             message.extend(msg)
         logger.debug(message)
-        rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
+        compressed_message = self.compress_messages(message, compress_type="post_cut_by_token")
+        rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
+        # rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
         return rsp
 
     def _extract_assistant_rsp(self, context):
@@ -264,3 +267,93 @@ class BaseLLM(ABC):
 
     def get_timeout(self, timeout: int) -> int:
         return timeout or self.config.timeout or LLM_API_TIMEOUT
+
+    def count_tokens(self, messages: list[dict]) -> int:
+        # A very rough heuristic for counting tokens, with reference to:
+        # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+        # https://platform.deepseek.com/api-docs/#token--token-usage
+        # The heuristic is a large overestimate for English text and should be overridden with an accurate token count function in inherited classes
+        # logger.warning("Base count_tokens is not accurate and should be overwritten.")
+        return sum([int(len(msg["content"]) * 0.5) for msg in messages])
+
+    def compress_messages(
+        self,
+        messages: list[dict],
+        compress_type: Literal["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token", ""] = "",
+        max_token: int = 128000,
+        threshold: float = 0.8,
+    ) -> list[dict]:
+        """Compress messages to fit within the token limit.
+        Args:
+            messages (list[dict]): List of messages to compress.
+            compress_type (str, optional): Compression strategy.
+                - "post_cut_by_msg": Keep as many latest messages as possible.
+                - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
+                - "pre_cut_by_msg": Keep as many earliest messages as possible.
+                - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.
+                - "": No compression. Default value.
+            max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be found in TOKEN_MAX.
+            threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion message.
+ """ + if not compress_type: + return messages + + current_token_count = 0 + max_token = TOKEN_MAX.get(self.config.model, max_token) + keep_token = int(max_token * threshold) + compressed = [] + + # Always keep system messages and assume they do not exceed token limit + system_msg_val = self._system_msg("")["role"] + system_msgs = [] + for i, msg in enumerate(messages): + if msg["role"] == system_msg_val: + system_msgs.append(msg) + else: + user_assistant_msgs = messages[i:] + break + # system_msgs = [msg for msg in messages if msg["role"] == system_msg_val] + # user_assistant_msgs = [msg for msg in messages if msg["role"] != system_msg_val] + compressed.extend(system_msgs) + current_token_count += self.count_tokens(system_msgs) + + if compress_type in ["post_cut_by_msg", "post_cut_by_token"]: + # Under keep_token constraint, keep as many latest messages as possible + for msg in reversed(user_assistant_msgs): + token_count = self.count_tokens([msg]) + if current_token_count + token_count <= keep_token: + compressed.insert(len(system_msgs), msg) + current_token_count += token_count + else: + truncated_msg_idx = len(system_msgs) - 1 + if compress_type == "post_cut_by_token": + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg + truncated_content = msg["content"][-(keep_token - current_token_count) :] + compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content}) + # after post truncation, the first message after the system message is the one truncated + truncated_msg_idx = len(system_msgs) + logger.warning( + f"Truncated messages with {compress_type} to fit within the token limit. " + f"The first user or assistant message after truncation is: {compressed[truncated_msg_idx]}." + ) + break + + elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]: + # Under keep_token constraint, keep as many earliest messages as possible + for msg in user_assistant_msgs: + token_count = self.count_tokens([msg]) + if current_token_count + token_count <= keep_token: + compressed.append(msg) + current_token_count += token_count + else: + if compress_type == "pre_cut_by_token": + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg + truncated_content = msg["content"][: keep_token - current_token_count] + compressed.append({"role": msg["role"], "content": truncated_content}) + logger.warning( + f"Truncated messages with {compress_type} to fit within the token limit. " + f"The last message after truncation is: {compressed[-1]}." 
+ ) + break + + return compressed diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index fa689d54f..11b1d38cb 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -300,3 +300,9 @@ class OpenAILLM(BaseLLM): img_url_or_b64 = item.url if resp_format == "url" else item.b64_json imgs.append(decode_image(img_url_or_b64)) return imgs + + def count_tokens(self, messages: list[dict]) -> int: + try: + return count_message_tokens(messages, self.config.model) + except: + return super().count_tokens(messages) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 5be6b5f61..b235ceb7b 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -215,6 +215,7 @@ TOKEN_MAX = { "deepseek/deepseek-chat": 128000, # end, for openrouter "deepseek-chat": 128000, "deepseek-coder": 128000, + "deepseek-ai/DeepSeek-Coder-V2-Instruct": 32000, # siliconflow } @@ -319,4 +320,4 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> """ if model not in TOKEN_MAX: return default - return TOKEN_MAX[model] - count_message_tokens(messages) - 1 + return TOKEN_MAX[model] - count_message_tokens(messages, model) - 1 diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 40a9fda92..c9c13650a 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -41,6 +41,7 @@ class MockBaseLLM(BaseLLM): return default_resp_cont +@pytest.mark.skip def test_base_llm(): message = Message(role="user", content="hello") assert "role" in message.to_dict() @@ -92,6 +93,7 @@ def test_base_llm(): # assert resp == default_resp_cont +@pytest.mark.skip @pytest.mark.asyncio async def test_async_base_llm(): base_llm = MockBaseLLM() @@ -104,3 +106,44 @@ async def test_async_base_llm(): # resp = await base_llm.aask_code([prompt]) # assert resp == default_resp_cont + + +@pytest.mark.parametrize( + "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] +) +def test_compress_messages_no_effect(compress_type): + base_llm = MockBaseLLM() + messages = [ + {"role": "system", "content": "first system msg"}, + {"role": "system", "content": "second system msg"}, + ] + for i in range(5): + messages.append({"role": "user", "content": f"u{i}"}) + messages.append({"role": "assistant", "content": f"a{i}"}) + compressed = base_llm.compress_messages(messages, compress_type=compress_type) + # should take no effect for short context + assert compressed == messages + + +@pytest.mark.parametrize( + "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] +) +def test_compress_messages_long(compress_type): + base_llm = MockBaseLLM() + base_llm.config.model = "test_llm" + max_token_limit = 100 + + messages = [ + {"role": "system", "content": "first system msg"}, + {"role": "system", "content": "second system msg"}, + ] + for i in range(100): + messages.append({"role": "user", "content": f"u{i}" * 10}) # ~2x10x0.5 = 10 tokens + messages.append({"role": "assistant", "content": f"a{i}" * 10}) + compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit) + + print(compressed) + print(len(compressed)) + assert len(compressed) < len(messages) + assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system" + assert compressed[2]["role"] != "system" diff --git a/tests/metagpt/provider/test_openai.py 
b/tests/metagpt/provider/test_openai.py index 3ce38d2a5..5f99e2b52 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -164,3 +164,47 @@ async def test_openai_acompletion(mocker): assert resp.usage == usage await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont) + + +def test_count_tokens(): + llm = LLM() + llm.config.model = "gpt-4o" + messages = [ + llm._system_msg("some system msg"), + llm._system_msg("some system message 2"), + llm._user_msg("user 1"), + llm._assistant_msg("assistant 1"), + llm._user_msg("user 1"), + llm._assistant_msg("assistant 2"), + ] + cnt = llm.count_tokens(messages) + print(cnt) + + +def test_count_tokens_long(): + llm = LLM() + llm.config.model = "gpt-4-0613" + test_msg_content = " ".join([str(i) for i in range(100000)]) + messages = [ + llm._system_msg("You are a helpful assistant"), + llm._user_msg(test_msg_content + " what's the first number you see?"), + ] + cnt = llm.count_tokens(messages) # 299023, ~300k + print(cnt) + + llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens + cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference + print(cnt) + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_aask_long(): + llm = LLM() + llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens + messages = [ + llm._system_msg("You are a helpful assistant"), + llm._user_msg(test_msg_content + " what's the first number you see?"), + ] + await llm.aask(messages) # should not fail with context truncated From 1f73051571e03cdb3d0b62f24cf4023c6375e492 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Wed, 31 Jul 2024 23:22:49 +0800 Subject: [PATCH 2/4] modify detail --- metagpt/provider/base_llm.py | 22 ++++++++++------------ tests/metagpt/provider/test_base_llm.py | 18 ++++++++++++++++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 4c7bb4cbc..68b1f78a6 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -303,7 +303,8 @@ class BaseLLM(ABC): keep_token = int(max_token * threshold) compressed = [] - # Always keep system messages and assume they do not exceed token limit + # Always keep system messages + # NOTE: Assume they do not exceed token limit system_msg_val = self._system_msg("")["role"] system_msgs = [] for i, msg in enumerate(messages): @@ -319,40 +320,37 @@ class BaseLLM(ABC): if compress_type in ["post_cut_by_msg", "post_cut_by_token"]: # Under keep_token constraint, keep as many latest messages as possible - for msg in reversed(user_assistant_msgs): + for i, msg in enumerate(reversed(user_assistant_msgs)): token_count = self.count_tokens([msg]) if current_token_count + token_count <= keep_token: compressed.insert(len(system_msgs), msg) current_token_count += token_count else: - truncated_msg_idx = len(system_msgs) - 1 - if compress_type == "post_cut_by_token": - # Truncate the message to fit within the remaining token count; Otherwise, discard the msg + if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs): + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. 
If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][-(keep_token - current_token_count) :] compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content}) - # after post truncation, the first message after the system message is the one truncated - truncated_msg_idx = len(system_msgs) logger.warning( f"Truncated messages with {compress_type} to fit within the token limit. " - f"The first user or assistant message after truncation is: {compressed[truncated_msg_idx]}." + f"The first user or assistant message after truncation (originally the {i}-th message from last): {compressed[len(system_msgs)]}." ) break elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]: # Under keep_token constraint, keep as many earliest messages as possible - for msg in user_assistant_msgs: + for i, msg in enumerate(user_assistant_msgs): token_count = self.count_tokens([msg]) if current_token_count + token_count <= keep_token: compressed.append(msg) current_token_count += token_count else: - if compress_type == "pre_cut_by_token": - # Truncate the message to fit within the remaining token count; Otherwise, discard the msg + if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs): + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][: keep_token - current_token_count] compressed.append({"role": msg["role"], "content": truncated_content}) logger.warning( f"Truncated messages with {compress_type} to fit within the token limit. " - f"The last message after truncation is: {compressed[-1]}." + f"The last user or assistant message after truncation (originally the {i}-th message): {compressed[-1]}." 
) break diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index c9c13650a..250c3b5c4 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -41,7 +41,6 @@ class MockBaseLLM(BaseLLM): return default_resp_cont -@pytest.mark.skip def test_base_llm(): message = Message(role="user", content="hello") assert "role" in message.to_dict() @@ -93,7 +92,6 @@ def test_base_llm(): # assert resp == default_resp_cont -@pytest.mark.skip @pytest.mark.asyncio async def test_async_base_llm(): base_llm = MockBaseLLM() @@ -147,3 +145,19 @@ def test_compress_messages_long(compress_type): assert len(compressed) < len(messages) assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system" assert compressed[2]["role"] != "system" + + +@pytest.mark.parametrize( + "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] +) +def test_compress_messages_long_no_sys_msg(compress_type): + base_llm = MockBaseLLM() + base_llm.config.model = "test_llm" + max_token_limit = 100 + + messages = [{"role": "user", "content": "1" * 10000}] + compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit) + + print(compressed) + assert compressed + assert len(compressed[0]["content"]) < len(messages[0]["content"]) From 987d90f6ff8e6872b6954f976dd2c1f41b13c377 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Fri, 2 Aug 2024 16:43:27 +0800 Subject: [PATCH 3/4] add compress_type to config --- metagpt/configs/llm_config.py | 3 +++ metagpt/provider/base_llm.py | 2 +- tests/metagpt/provider/test_base_llm.py | 7 +++++++ tests/metagpt/provider/test_openai.py | 22 +++++++++++++++++++--- 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 48130eedc..657c1a70b 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -86,6 +86,9 @@ class LLMConfig(YamlModel): # Cost Control calc_usage: bool = True + # Compress request messages under token limit + compress_type: str = "" + @field_validator("api_key") @classmethod def check_llm_key(cls, v): diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 68b1f78a6..6b544375e 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -148,7 +148,7 @@ class BaseLLM(ABC): else: message.extend(msg) logger.debug(message) - compressed_message = self.compress_messages(message, compress_type="post_cut_by_token") + compressed_message = self.compress_messages(message, compress_type=self.config.compress_type) rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout)) # rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout)) return rsp diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 250c3b5c4..87e4bdbbb 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -147,6 +147,13 @@ def test_compress_messages_long(compress_type): assert compressed[2]["role"] != "system" +def test_long_messages_no_compress(): + base_llm = MockBaseLLM() + messages = [{"role": "user", "content": "1" * 10000}] * 10000 + compressed = base_llm.compress_messages(messages) + assert len(compressed) == len(messages) + + @pytest.mark.parametrize( "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] ) 
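The parametrized tests above cover the four cut strategies. For intuition, an illustrative sketch follows (not part of the patch; it assumes the base count_tokens heuristic of 0.5 tokens per character and a token budget with room for only two user/assistant messages after the system message):

    messages = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "A" * 40},       # oldest, ~20 tokens
        {"role": "assistant", "content": "B" * 40},  # ~20 tokens
        {"role": "user", "content": "C" * 40},       # newest, ~20 tokens
    ]
    # post_cut_by_msg   -> [sys, B, C]              keep the latest whole messages
    # post_cut_by_token -> [sys, tail of A, B, C]   also keep the tail of the earliest message that no longer fits
    # pre_cut_by_msg    -> [sys, A, B]               keep the earliest whole messages
    # pre_cut_by_token  -> [sys, A, B, head of C]   also keep the head of the latest message that no longer fits
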
diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 5f99e2b52..dde98f7a3 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -178,7 +178,7 @@ def test_count_tokens(): llm._assistant_msg("assistant 2"), ] cnt = llm.count_tokens(messages) - print(cnt) + assert cnt == 47 def test_count_tokens_long(): @@ -190,11 +190,11 @@ def test_count_tokens_long(): llm._user_msg(test_msg_content + " what's the first number you see?"), ] cnt = llm.count_tokens(messages) # 299023, ~300k - print(cnt) + assert 290000 <= cnt <= 300000 llm.config.model = "test_llm" # a non-openai model, will use heuristics base count_tokens cnt = llm.count_tokens(messages) # 294474, ~300k, ~2% difference - print(cnt) + assert 290000 <= cnt <= 300000 @pytest.mark.skip @@ -202,9 +202,25 @@ def test_count_tokens_long(): async def test_aask_long(): llm = LLM() llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + llm.config.compress_type = "post_cut_by_token" test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens messages = [ llm._system_msg("You are a helpful assistant"), llm._user_msg(test_msg_content + " what's the first number you see?"), ] await llm.aask(messages) # should not fail with context truncated + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_aask_long_no_compress(): + llm = LLM() + llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k + # Not specifying llm.config.compress_type will use default "", no compress + test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens + messages = [ + llm._system_msg("You are a helpful assistant"), + llm._user_msg(test_msg_content + " what's the first number you see?"), + ] + with pytest.raises(Exception): + await llm.aask(messages) # should fail From 25e67db1afa77eeb661006948d657a73c11778a9 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Fri, 2 Aug 2024 18:43:55 +0800 Subject: [PATCH 4/4] compress type to enum --- metagpt/configs/compress_msg_config.py | 32 +++++++++++++++++++++++++ metagpt/configs/llm_config.py | 8 ++++++- metagpt/provider/base_llm.py | 22 +++++++---------- tests/metagpt/provider/test_base_llm.py | 15 ++++-------- tests/metagpt/provider/test_openai.py | 3 ++- 5 files changed, 55 insertions(+), 25 deletions(-) create mode 100644 metagpt/configs/compress_msg_config.py diff --git a/metagpt/configs/compress_msg_config.py b/metagpt/configs/compress_msg_config.py new file mode 100644 index 000000000..c46334c12 --- /dev/null +++ b/metagpt/configs/compress_msg_config.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class CompressType(Enum): + """ + Compression Type for messages. Used to compress messages under token limit. + - "": No compression. Default value. + - "post_cut_by_msg": Keep as many latest messages as possible. + - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message. + - "pre_cut_by_msg": Keep as many earliest messages as possible. + - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message. 
+ """ + + NO_COMPRESS = "" + POST_CUT_BY_MSG = "post_cut_by_msg" + POST_CUT_BY_TOKEN = "post_cut_by_token" + PRE_CUT_BY_MSG = "pre_cut_by_msg" + PRE_CUT_BY_TOKEN = "pre_cut_by_token" + + def __missing__(self, key): + return self.NO_COMPRESS + + @classmethod + def get_type(cls, type_name): + for member in cls: + if member.value == type_name: + return member + return cls.NO_COMPRESS + + @classmethod + def cut_types(cls): + return [member for member in cls if "cut" in member.value] diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 657c1a70b..c5605be21 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -10,6 +10,7 @@ from typing import Optional from pydantic import field_validator +from metagpt.configs.compress_msg_config import CompressType from metagpt.const import LLM_API_TIMEOUT from metagpt.utils.yaml_model import YamlModel @@ -87,7 +88,7 @@ class LLMConfig(YamlModel): calc_usage: bool = True # Compress request messages under token limit - compress_type: str = "" + compress_type: CompressType = CompressType.NO_COMPRESS @field_validator("api_key") @classmethod @@ -100,3 +101,8 @@ class LLMConfig(YamlModel): @classmethod def check_timeout(cls, v): return v or LLM_API_TIMEOUT + + @field_validator("compress_type") + @classmethod + def check_compress_type(cls, v): + return CompressType.get_type(v) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 6b544375e..ac09c19f7 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -10,7 +10,7 @@ from __future__ import annotations import json from abc import ABC, abstractmethod -from typing import Literal, Optional, Union +from typing import Optional, Union from openai import AsyncOpenAI from pydantic import BaseModel @@ -22,6 +22,7 @@ from tenacity import ( wait_random_exponential, ) +from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger @@ -279,23 +280,18 @@ class BaseLLM(ABC): def compress_messages( self, messages: list[dict], - compress_type: Literal["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token", ""] = "", + compress_type: CompressType = CompressType.NO_COMPRESS, max_token: int = 128000, threshold: float = 0.8, ) -> list[dict]: """Compress messages to fit within the token limit. Args: messages (list[dict]): List of messages to compress. - compress_type (str, optional): Compression strategy. - - "post_cut_by_msg": Keep as many latest messages as possible. - - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message. - - "pre_cut_by_msg": Keep as many earliest messages as possible. - - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message. - - "": No compression. Default value. + compress_type (CompressType, optional): Compression strategy. Defaults to CompressType.NO_COMPRESS. max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be found in TOKEN_MAX. threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion message. 
""" - if not compress_type: + if compress_type == CompressType.NO_COMPRESS: return messages current_token_count = 0 @@ -318,7 +314,7 @@ class BaseLLM(ABC): compressed.extend(system_msgs) current_token_count += self.count_tokens(system_msgs) - if compress_type in ["post_cut_by_msg", "post_cut_by_token"]: + if compress_type in [CompressType.POST_CUT_BY_TOKEN, CompressType.POST_CUT_BY_MSG]: # Under keep_token constraint, keep as many latest messages as possible for i, msg in enumerate(reversed(user_assistant_msgs)): token_count = self.count_tokens([msg]) @@ -326,7 +322,7 @@ class BaseLLM(ABC): compressed.insert(len(system_msgs), msg) current_token_count += token_count else: - if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs): + if compress_type == CompressType.POST_CUT_BY_TOKEN or len(compressed) == len(system_msgs): # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][-(keep_token - current_token_count) :] compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content}) @@ -336,7 +332,7 @@ class BaseLLM(ABC): ) break - elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]: + elif compress_type in [CompressType.PRE_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_MSG]: # Under keep_token constraint, keep as many earliest messages as possible for i, msg in enumerate(user_assistant_msgs): token_count = self.count_tokens([msg]) @@ -344,7 +340,7 @@ class BaseLLM(ABC): compressed.append(msg) current_token_count += token_count else: - if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs): + if compress_type == CompressType.PRE_CUT_BY_TOKEN or len(compressed) == len(system_msgs): # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. 
If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][: keep_token - current_token_count] compressed.append({"role": msg["role"], "content": truncated_content}) diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 87e4bdbbb..d34ed62f1 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -8,6 +8,7 @@ import pytest +from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message @@ -106,9 +107,7 @@ async def test_async_base_llm(): # assert resp == default_resp_cont -@pytest.mark.parametrize( - "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] -) +@pytest.mark.parametrize("compress_type", list(CompressType)) def test_compress_messages_no_effect(compress_type): base_llm = MockBaseLLM() messages = [ @@ -123,9 +122,7 @@ def test_compress_messages_no_effect(compress_type): assert compressed == messages -@pytest.mark.parametrize( - "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] -) +@pytest.mark.parametrize("compress_type", CompressType.cut_types()) def test_compress_messages_long(compress_type): base_llm = MockBaseLLM() base_llm.config.model = "test_llm" @@ -142,7 +139,7 @@ def test_compress_messages_long(compress_type): print(compressed) print(len(compressed)) - assert len(compressed) < len(messages) + assert 3 <= len(compressed) < len(messages) assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system" assert compressed[2]["role"] != "system" @@ -154,9 +151,7 @@ def test_long_messages_no_compress(): assert len(compressed) == len(messages) -@pytest.mark.parametrize( - "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] -) +@pytest.mark.parametrize("compress_type", CompressType.cut_types()) def test_compress_messages_long_no_sys_msg(compress_type): base_llm = MockBaseLLM() base_llm.config.model = "test_llm" diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index dde98f7a3..d292a8286 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -9,6 +9,7 @@ from openai.types.chat.chat_completion import Choice, CompletionUsage from openai.types.chat.chat_completion_message_tool_call import Function from PIL import Image +from metagpt.configs.compress_msg_config import CompressType from metagpt.const import TEST_DATA_PATH from metagpt.llm import LLM from metagpt.logs import logger @@ -202,7 +203,7 @@ def test_count_tokens_long(): async def test_aask_long(): llm = LLM() llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k - llm.config.compress_type = "post_cut_by_token" + llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens messages = [ llm._system_msg("You are a helpful assistant"),
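For completeness, a minimal usage sketch of the feature added by this series (illustrative only, not part of the patch; it assumes a normally configured MetaGPT environment so that LLM() resolves to a working provider):

    import asyncio

    from metagpt.configs.compress_msg_config import CompressType
    from metagpt.llm import LLM

    async def main():
        llm = LLM()  # model and API key come from the usual MetaGPT config
        # Opt in to compression; the default CompressType.NO_COMPRESS ("") leaves requests untouched.
        llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN
        # aask() now runs the conversation through compress_messages() before sending,
        # so a history above ~80% of the model's TOKEN_MAX is cut down instead of overflowing the context window.
        print(await llm.aask("Summarize our discussion so far."))

    asyncio.run(main())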