compress type to enum

This commit is contained in:
garylin2099 2024-08-02 18:43:55 +08:00
parent 987d90f6ff
commit 25e67db1af
5 changed files with 55 additions and 25 deletions

View file

@ -0,0 +1,32 @@
from enum import Enum
class CompressType(Enum):
"""
Compression Type for messages. Used to compress messages under token limit.
- "": No compression. Default value.
- "post_cut_by_msg": Keep as many latest messages as possible.
- "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
- "pre_cut_by_msg": Keep as many earliest messages as possible.
- "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.
"""
NO_COMPRESS = ""
POST_CUT_BY_MSG = "post_cut_by_msg"
POST_CUT_BY_TOKEN = "post_cut_by_token"
PRE_CUT_BY_MSG = "pre_cut_by_msg"
PRE_CUT_BY_TOKEN = "pre_cut_by_token"
def __missing__(self, key):
return self.NO_COMPRESS
@classmethod
def get_type(cls, type_name):
for member in cls:
if member.value == type_name:
return member
return cls.NO_COMPRESS
@classmethod
def cut_types(cls):
return [member for member in cls if "cut" in member.value]

View file

@ -10,6 +10,7 @@ from typing import Optional
from pydantic import field_validator
from metagpt.configs.compress_msg_config import CompressType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.utils.yaml_model import YamlModel
@ -87,7 +88,7 @@ class LLMConfig(YamlModel):
calc_usage: bool = True
# Compress request messages under token limit
compress_type: str = ""
compress_type: CompressType = CompressType.NO_COMPRESS
@field_validator("api_key")
@classmethod
@ -100,3 +101,8 @@ class LLMConfig(YamlModel):
@classmethod
def check_timeout(cls, v):
return v or LLM_API_TIMEOUT
@field_validator("compress_type")
@classmethod
def check_compress_type(cls, v):
return CompressType.get_type(v)