compress type to enum

This commit is contained in:
garylin2099 2024-08-02 18:43:55 +08:00
parent 987d90f6ff
commit 25e67db1af
5 changed files with 55 additions and 25 deletions

View file

@ -0,0 +1,32 @@
from enum import Enum
class CompressType(Enum):
    """Compression strategy used to keep request messages under the token limit.

    Members, by value:
    - "": No compression. Default value.
    - "post_cut_by_msg": Keep as many latest messages as possible.
    - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
    - "pre_cut_by_msg": Keep as many earliest messages as possible.
    - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.

    Unknown values resolve to ``NO_COMPRESS`` both through ``CompressType(value)``
    and through ``CompressType.get_type(value)``.
    """

    NO_COMPRESS = ""
    POST_CUT_BY_MSG = "post_cut_by_msg"
    POST_CUT_BY_TOKEN = "post_cut_by_token"
    PRE_CUT_BY_MSG = "pre_cut_by_msg"
    PRE_CUT_BY_TOKEN = "pre_cut_by_token"

    @classmethod
    def _missing_(cls, value):
        """Return ``NO_COMPRESS`` when ``CompressType(value)`` matches no member.

        ``_missing_`` is the hook ``Enum`` consults after a failed value lookup;
        without it an unknown value raises ``ValueError``.
        """
        return cls.NO_COMPRESS

    def __missing__(self, key):
        # Retained only so existing attribute access keeps working. ``Enum`` never
        # calls ``__missing__`` (it is a dict-subclass hook); the fallback that
        # actually takes effect is ``_missing_`` above.
        return type(self).NO_COMPRESS

    @classmethod
    def get_type(cls, type_name) -> "CompressType":
        """Resolve ``type_name`` to a member, defaulting to ``NO_COMPRESS``.

        Args:
            type_name: Either a member value such as ``"post_cut_by_msg"`` or an
                existing ``CompressType`` member.

        Returns:
            The matching member; ``NO_COMPRESS`` when nothing matches.
        """
        # A member never compares equal to its own string value on a plain Enum,
        # so pass members through instead of silently mapping them to NO_COMPRESS.
        if isinstance(type_name, cls):
            return type_name
        for member in cls:
            if member.value == type_name:
                return member
        return cls.NO_COMPRESS

    @classmethod
    def cut_types(cls) -> "list[CompressType]":
        """Return every member that performs cutting, i.e. whose value contains "cut"."""
        return [member for member in cls if "cut" in member.value]

View file

@ -10,6 +10,7 @@ from typing import Optional
from pydantic import field_validator
from metagpt.configs.compress_msg_config import CompressType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.utils.yaml_model import YamlModel
@ -87,7 +88,7 @@ class LLMConfig(YamlModel):
calc_usage: bool = True
# Compress request messages under token limit
compress_type: str = ""
compress_type: CompressType = CompressType.NO_COMPRESS
@field_validator("api_key")
@classmethod
@ -100,3 +101,8 @@ class LLMConfig(YamlModel):
@classmethod
def check_timeout(cls, v):
return v or LLM_API_TIMEOUT
@field_validator("compress_type", mode="before")
@classmethod
def check_compress_type(cls, v):
    """Normalize the configured ``compress_type`` to a ``CompressType`` member.

    Runs in ``mode="before"`` so it receives the raw configured value (typically
    a string from YAML). In the default "after" mode the value would already be
    a ``CompressType`` member, which ``CompressType.get_type`` matches against
    member *values* and would therefore not recognize.

    Args:
        v: Raw configured value, or an existing ``CompressType`` member.

    Returns:
        The corresponding ``CompressType`` member, as resolved by
        ``CompressType.get_type`` for non-member input.
    """
    # Values that are already members need no lookup; pass them through untouched.
    if isinstance(v, CompressType):
        return v
    return CompressType.get_type(v)

View file

@ -10,7 +10,7 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Literal, Optional, Union
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
@ -22,6 +22,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
@ -279,23 +280,18 @@ class BaseLLM(ABC):
def compress_messages(
self,
messages: list[dict],
compress_type: Literal["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token", ""] = "",
compress_type: CompressType = CompressType.NO_COMPRESS,
max_token: int = 128000,
threshold: float = 0.8,
) -> list[dict]:
"""Compress messages to fit within the token limit.
Args:
messages (list[dict]): List of messages to compress.
compress_type (str, optional): Compression strategy.
- "post_cut_by_msg": Keep as many latest messages as possible.
- "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
- "pre_cut_by_msg": Keep as many earliest messages as possible.
- "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.
- "": No compression. Default value.
compress_type (CompressType, optional): Compression strategy. Defaults to CompressType.NO_COMPRESS.
max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be found in TOKEN_MAX.
threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion message.
"""
if not compress_type:
if compress_type == CompressType.NO_COMPRESS:
return messages
current_token_count = 0
@ -318,7 +314,7 @@ class BaseLLM(ABC):
compressed.extend(system_msgs)
current_token_count += self.count_tokens(system_msgs)
if compress_type in ["post_cut_by_msg", "post_cut_by_token"]:
if compress_type in [CompressType.POST_CUT_BY_TOKEN, CompressType.POST_CUT_BY_MSG]:
# Under keep_token constraint, keep as many latest messages as possible
for i, msg in enumerate(reversed(user_assistant_msgs)):
token_count = self.count_tokens([msg])
@ -326,7 +322,7 @@ class BaseLLM(ABC):
compressed.insert(len(system_msgs), msg)
current_token_count += token_count
else:
if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs):
if compress_type == CompressType.POST_CUT_BY_TOKEN or len(compressed) == len(system_msgs):
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token
truncated_content = msg["content"][-(keep_token - current_token_count) :]
compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content})
@ -336,7 +332,7 @@ class BaseLLM(ABC):
)
break
elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]:
elif compress_type in [CompressType.PRE_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_MSG]:
# Under keep_token constraint, keep as many earliest messages as possible
for i, msg in enumerate(user_assistant_msgs):
token_count = self.count_tokens([msg])
@ -344,7 +340,7 @@ class BaseLLM(ABC):
compressed.append(msg)
current_token_count += token_count
else:
if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs):
if compress_type == CompressType.PRE_CUT_BY_TOKEN or len(compressed) == len(system_msgs):
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token
truncated_content = msg["content"][: keep_token - current_token_count]
compressed.append({"role": msg["role"], "content": truncated_content})

View file

@ -8,6 +8,7 @@
import pytest
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
@ -106,9 +107,7 @@ async def test_async_base_llm():
# assert resp == default_resp_cont
@pytest.mark.parametrize(
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
)
@pytest.mark.parametrize("compress_type", list(CompressType))
def test_compress_messages_no_effect(compress_type):
base_llm = MockBaseLLM()
messages = [
@ -123,9 +122,7 @@ def test_compress_messages_no_effect(compress_type):
assert compressed == messages
@pytest.mark.parametrize(
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
)
@pytest.mark.parametrize("compress_type", CompressType.cut_types())
def test_compress_messages_long(compress_type):
base_llm = MockBaseLLM()
base_llm.config.model = "test_llm"
@ -142,7 +139,7 @@ def test_compress_messages_long(compress_type):
print(compressed)
print(len(compressed))
assert len(compressed) < len(messages)
assert 3 <= len(compressed) < len(messages)
assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system"
assert compressed[2]["role"] != "system"
@ -154,9 +151,7 @@ def test_long_messages_no_compress():
assert len(compressed) == len(messages)
@pytest.mark.parametrize(
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
)
@pytest.mark.parametrize("compress_type", CompressType.cut_types())
def test_compress_messages_long_no_sys_msg(compress_type):
base_llm = MockBaseLLM()
base_llm.config.model = "test_llm"

View file

@ -9,6 +9,7 @@ from openai.types.chat.chat_completion import Choice, CompletionUsage
from openai.types.chat.chat_completion_message_tool_call import Function
from PIL import Image
from metagpt.configs.compress_msg_config import CompressType
from metagpt.const import TEST_DATA_PATH
from metagpt.llm import LLM
from metagpt.logs import logger
@ -202,7 +203,7 @@ def test_count_tokens_long():
async def test_aask_long():
llm = LLM()
llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct" # deepseek-coder on siliconflow, limit 32k
llm.config.compress_type = "post_cut_by_token"
llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN
test_msg_content = " ".join([str(i) for i in range(100000)]) # corresponds to ~300k tokens
messages = [
llm._system_msg("You are a helpful assistant"),