Convert compress_type from string literals to the CompressType enum

This commit is contained in:
garylin2099 2024-08-02 18:43:55 +08:00
parent 987d90f6ff
commit 25e67db1af
5 changed files with 55 additions and 25 deletions

View file

@ -10,7 +10,7 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Literal, Optional, Union
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
@ -22,6 +22,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
@ -279,23 +280,18 @@ class BaseLLM(ABC):
def compress_messages(
self,
messages: list[dict],
compress_type: Literal["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token", ""] = "",
compress_type: CompressType = CompressType.NO_COMPRESS,
max_token: int = 128000,
threshold: float = 0.8,
) -> list[dict]:
"""Compress messages to fit within the token limit.
Args:
messages (list[dict]): List of messages to compress.
compress_type (str, optional): Compression strategy.
- "post_cut_by_msg": Keep as many latest messages as possible.
- "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
- "pre_cut_by_msg": Keep as many earliest messages as possible.
- "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.
- "": No compression. Default value.
compress_type (CompressType, optional): Compression strategy. Defaults to CompressType.NO_COMPRESS.
max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be found in TOKEN_MAX.
threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion message.
"""
if not compress_type:
if compress_type == CompressType.NO_COMPRESS:
return messages
current_token_count = 0
@ -318,7 +314,7 @@ class BaseLLM(ABC):
compressed.extend(system_msgs)
current_token_count += self.count_tokens(system_msgs)
if compress_type in ["post_cut_by_msg", "post_cut_by_token"]:
if compress_type in [CompressType.POST_CUT_BY_TOKEN, CompressType.POST_CUT_BY_MSG]:
# Under keep_token constraint, keep as many latest messages as possible
for i, msg in enumerate(reversed(user_assistant_msgs)):
token_count = self.count_tokens([msg])
@ -326,7 +322,7 @@ class BaseLLM(ABC):
compressed.insert(len(system_msgs), msg)
current_token_count += token_count
else:
if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs):
if compress_type == CompressType.POST_CUT_BY_TOKEN or len(compressed) == len(system_msgs):
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token
truncated_content = msg["content"][-(keep_token - current_token_count) :]
compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content})
@ -336,7 +332,7 @@ class BaseLLM(ABC):
)
break
elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]:
elif compress_type in [CompressType.PRE_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_MSG]:
# Under keep_token constraint, keep as many earliest messages as possible
for i, msg in enumerate(user_assistant_msgs):
token_count = self.count_tokens([msg])
@ -344,7 +340,7 @@ class BaseLLM(ABC):
compressed.append(msg)
current_token_count += token_count
else:
if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs):
if compress_type == CompressType.PRE_CUT_BY_TOKEN or len(compressed) == len(system_msgs):
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token
truncated_content = msg["content"][: keep_token - current_token_count]
compressed.append({"role": msg["role"], "content": truncated_content})