Merge branch 'truncate_msg' into 'mgx_ops'

长文本退化策略:几种简单截断

See merge request pub/MetaGPT!268
This commit is contained in:
王金淋 2024-08-02 17:54:40 +00:00
commit f6be0014b7
7 changed files with 257 additions and 2 deletions

View file

@ -0,0 +1,32 @@
from enum import Enum
class CompressType(Enum):
    """Compression type for messages. Used to compress messages under the token limit.

    - "": No compression. Default value.
    - "post_cut_by_msg": Keep as many latest messages as possible.
    - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
    - "pre_cut_by_msg": Keep as many earliest messages as possible.
    - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.

    Any value that does not match a member resolves to ``NO_COMPRESS``, both through
    ``CompressType(value)`` and through ``CompressType.get_type(value)``.
    """

    NO_COMPRESS = ""
    POST_CUT_BY_MSG = "post_cut_by_msg"
    POST_CUT_BY_TOKEN = "post_cut_by_token"
    PRE_CUT_BY_MSG = "pre_cut_by_msg"
    PRE_CUT_BY_TOKEN = "pre_cut_by_token"

    @classmethod
    def _missing_(cls, value):
        """Return ``NO_COMPRESS`` when ``value`` matches no member.

        This is the hook ``Enum`` consults after a failed value lookup; a ``__missing__``
        method (the dict protocol) is never called by ``Enum``.
        """
        return cls.NO_COMPRESS

    @classmethod
    def get_type(cls, type_name):
        """Resolve ``type_name`` to a member, falling back to ``NO_COMPRESS``.

        Args:
            type_name: Either a member value (e.g. ``"post_cut_by_msg"``) or a member itself.

        Returns:
            The matching member, or ``NO_COMPRESS`` if nothing matches.
        """
        # A member must map to itself; comparing its value against the member would never match.
        if isinstance(type_name, cls):
            return type_name
        return cls(type_name)

    @classmethod
    def cut_types(cls):
        """Return every member that actually cuts messages (its value contains "cut")."""
        return [member for member in cls if "cut" in member.value]

View file

@ -10,6 +10,7 @@ from typing import Optional
from pydantic import field_validator
from metagpt.configs.compress_msg_config import CompressType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.utils.yaml_model import YamlModel
@ -86,6 +87,9 @@ class LLMConfig(YamlModel):
# Cost Control
calc_usage: bool = True
# Strategy for compressing request messages so they stay under the token limit (no compression by default)
compress_type: CompressType = CompressType.NO_COMPRESS
@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):
@ -97,3 +101,8 @@ class LLMConfig(YamlModel):
@classmethod
def check_timeout(cls, v):
return v or LLM_API_TIMEOUT
@field_validator("compress_type", mode="before")
@classmethod
def check_compress_type(cls, v):
    """Coerce the configured ``compress_type`` into a ``CompressType`` member.

    Runs in "before" mode so the raw configured value is what gets resolved. In the
    default "after" mode the value has already been turned into a ``CompressType``
    member, and looking a member up by value never matches, which would reset every
    configured strategy to ``NO_COMPRESS``.

    Args:
        v: The raw configured value, or an existing ``CompressType`` member.

    Returns:
        The matching ``CompressType`` member; ``NO_COMPRESS`` for unknown values.
    """
    if isinstance(v, CompressType):
        return v
    return CompressType.get_type(v)

View file

@ -22,12 +22,14 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import TOKEN_MAX
class BaseLLM(ABC):
@ -147,7 +149,9 @@ class BaseLLM(ABC):
else:
message.extend(msg)
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
compressed_message = self.compress_messages(message, compress_type=self.config.compress_type)
rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
# rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
return rsp
def _extract_assistant_rsp(self, context):
@ -264,3 +268,86 @@ class BaseLLM(ABC):
def get_timeout(self, timeout: int) -> int:
    """Resolve the timeout to use for a request.

    The first truthy value wins, in this order: the ``timeout`` argument,
    ``self.config.timeout``, then ``LLM_API_TIMEOUT``.
    """
    if timeout:
        return timeout
    configured = self.config.timeout
    if configured:
        return configured
    return LLM_API_TIMEOUT
def count_tokens(self, messages: list[dict]) -> int:
    """Roughly estimate the number of tokens in ``messages``.

    A very raw heuristic (0.5 token per unit of ``content`` length), taking reference from:
    https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
    https://platform.deepseek.com/api-docs/#token--token-usage
    The heuristic is a huge overestimate for English text, for example, and should be
    overwritten with an accurate token count function in inherited classes.

    Args:
        messages: Messages to measure; each is a dict that may carry a "content" entry.

    Returns:
        The estimated token count. Messages whose content is missing or ``None`` count as zero.
    """
    # `or ""` keeps content-less messages from raising instead of counting zero.
    return sum(int(len(msg.get("content") or "") * 0.5) for msg in messages)
def compress_messages(
    self,
    messages: list[dict],
    compress_type: CompressType = CompressType.NO_COMPRESS,
    max_token: int = 128000,
    threshold: float = 0.8,
) -> list[dict]:
    """Compress messages to fit within the token limit.

    The leading run of system messages is always kept. The remaining messages are then
    added one by one (latest first for the "post" strategies, earliest first for the
    "pre" strategies) until the next one would exceed the budget. That message is
    truncated to the remaining budget when cutting by token, or when nothing else has
    been kept yet; otherwise it is dropped. Everything after it is dropped.

    Args:
        messages (list[dict]): List of messages to compress.
        compress_type (CompressType, optional): Compression strategy. Defaults to CompressType.NO_COMPRESS.
        max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be
            found in TOKEN_MAX.
        threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion
            message.

    Returns:
        list[dict]: ``messages`` itself when no compression applies (NO_COMPRESS or an unrecognised
        strategy); otherwise a new list with the kept messages in their original order.
    """
    if compress_type == CompressType.NO_COMPRESS:
        return messages

    post_cut = compress_type in (CompressType.POST_CUT_BY_TOKEN, CompressType.POST_CUT_BY_MSG)
    pre_cut = compress_type in (CompressType.PRE_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_MSG)
    if not (post_cut or pre_cut):
        # Unrecognised strategy: leave the conversation alone rather than silently dropping it.
        return messages
    by_token = compress_type in (CompressType.POST_CUT_BY_TOKEN, CompressType.PRE_CUT_BY_TOKEN)

    max_token = TOKEN_MAX.get(self.config.model, max_token)
    keep_token = int(max_token * threshold)

    # Always keep the leading system messages
    # NOTE: Assume they do not exceed token limit
    system_msg_val = self._system_msg("")["role"]
    split_at = len(messages)  # stays here when messages is empty or contains only system messages
    for i, msg in enumerate(messages):
        if msg["role"] != system_msg_val:
            split_at = i
            break
    system_msgs = messages[:split_at]
    user_assistant_msgs = messages[split_at:]

    current_token_count = self.count_tokens(system_msgs)

    # Visit latest-first for "post" strategies and earliest-first for "pre" strategies
    candidates = reversed(user_assistant_msgs) if post_cut else user_assistant_msgs
    kept = []  # kept messages, in visiting order
    for i, msg in enumerate(candidates):
        token_count = self.count_tokens([msg])
        if current_token_count + token_count <= keep_token:
            kept.append(msg)
            current_token_count += token_count
            continue

        # First message that does not fit. Truncate it when cutting by token; if nothing has
        # been kept yet, enforce cutting by token so at least one message survives.
        if by_token or not kept:
            # Clamp at zero: a slice such as content[-0:] would keep the whole message, and a
            # negative budget would slice from the wrong end.
            remaining = max(keep_token - current_token_count, 0)
            content = msg.get("content") or ""
            if remaining == 0:
                truncated_content = content[:0]
            elif post_cut:
                truncated_content = content[-remaining:]
            else:
                truncated_content = content[:remaining]
            # Skip a truncation that leaves nothing, unless it would be the only message kept
            if truncated_content or not kept:
                truncated_msg = {"role": msg["role"], "content": truncated_content}
                kept.append(truncated_msg)
                if post_cut:
                    logger.warning(
                        f"Truncated messages with {compress_type} to fit within the token limit. "
                        f"The first user or assistant message after truncation (originally the {i}-th message from last): {truncated_msg}."
                    )
                else:
                    logger.warning(
                        f"Truncated messages with {compress_type} to fit within the token limit. "
                        f"The last user or assistant message after truncation (originally the {i}-th message): {truncated_msg}."
                    )
        break

    if post_cut:
        # Collected latest-first; restore chronological order in one pass
        kept.reverse()
    return [*system_msgs, *kept]

View file

@ -300,3 +300,9 @@ class OpenAILLM(BaseLLM):
img_url_or_b64 = item.url if resp_format == "url" else item.b64_json
imgs.append(decode_image(img_url_or_b64))
return imgs
def count_tokens(self, messages: list[dict]) -> int:
    """Count the tokens of ``messages`` for the configured model.

    Uses ``count_message_tokens`` with ``self.config.model``; if that raises, falls back
    to the parent class's ``count_tokens``.

    Args:
        messages: Messages to measure.

    Returns:
        The token count from ``count_message_tokens``, or the parent class's count on failure.
    """
    try:
        return count_message_tokens(messages, self.config.model)
    except Exception:
        # Best-effort fallback. `Exception` rather than a bare `except` so that
        # KeyboardInterrupt and SystemExit still propagate.
        return super().count_tokens(messages)

View file

@ -215,6 +215,7 @@ TOKEN_MAX = {
"deepseek/deepseek-chat": 128000, # end, for openrouter
"deepseek-chat": 128000,
"deepseek-coder": 128000,
"deepseek-ai/DeepSeek-Coder-V2-Instruct": 32000, # siliconflow
}
@ -319,4 +320,4 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) ->
"""
if model not in TOKEN_MAX:
return default
return TOKEN_MAX[model] - count_message_tokens(messages) - 1
return TOKEN_MAX[model] - count_message_tokens(messages, model) - 1

View file

@ -8,6 +8,7 @@
import pytest
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
@ -104,3 +105,61 @@ async def test_async_base_llm():
# resp = await base_llm.aask_code([prompt])
# assert resp == default_resp_cont
@pytest.mark.parametrize("compress_type", list(CompressType))
def test_compress_messages_no_effect(compress_type):
    """A short conversation must come back unchanged for every compression type."""
    llm = MockBaseLLM()
    conversation = [
        {"role": "system", "content": "first system msg"},
        {"role": "system", "content": "second system msg"},
    ]
    for turn in range(5):
        conversation.extend(
            (
                {"role": "user", "content": f"u{turn}"},
                {"role": "assistant", "content": f"a{turn}"},
            )
        )

    result = llm.compress_messages(conversation, compress_type=compress_type)

    # should take no effect for short context
    assert result == conversation
@pytest.mark.parametrize("compress_type", CompressType.cut_types())
def test_compress_messages_long(compress_type):
    """Every cut strategy shortens an oversized conversation but keeps the system messages."""
    base_llm = MockBaseLLM()
    base_llm.config.model = "test_llm"
    max_token_limit = 100

    messages = [
        {"role": "system", "content": "first system msg"},
        {"role": "system", "content": "second system msg"},
    ]
    for i in range(100):
        # 20-30 characters each, i.e. about 10-15 tokens at 0.5 token per character
        # (assuming MockBaseLLM uses the base count_tokens heuristic)
        messages.append({"role": "user", "content": f"u{i}" * 10})
        messages.append({"role": "assistant", "content": f"a{i}" * 10})

    compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)

    assert 3 <= len(compressed) < len(messages)
    assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system"
    # System messages must survive untouched, not merely keep their role
    assert compressed[:2] == messages[:2]
    assert compressed[2]["role"] != "system"
def test_long_messages_no_compress():
    """With the default compression type, even a huge conversation keeps every message."""
    llm = MockBaseLLM()
    huge_msg = {"role": "user", "content": "1" * 10000}
    conversation = [huge_msg] * 10000

    result = llm.compress_messages(conversation)

    assert len(result) == len(conversation)
@pytest.mark.parametrize("compress_type", CompressType.cut_types())
def test_compress_messages_long_no_sys_msg(compress_type):
    """A single oversized message with no system message is truncated, never dropped."""
    base_llm = MockBaseLLM()
    base_llm.config.model = "test_llm"
    max_token_limit = 100
    messages = [{"role": "user", "content": "1" * 10000}]

    compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)

    assert compressed
    # The truncated message keeps its role and loses content
    assert compressed[0]["role"] == "user"
    assert len(compressed[0]["content"]) < len(messages[0]["content"])

View file

@ -9,6 +9,7 @@ from openai.types.chat.chat_completion import Choice, CompletionUsage
from openai.types.chat.chat_completion_message_tool_call import Function
from PIL import Image
from metagpt.configs.compress_msg_config import CompressType
from metagpt.const import TEST_DATA_PATH
from metagpt.llm import LLM
from metagpt.logs import logger
@ -164,3 +165,63 @@ async def test_openai_acompletion(mocker):
assert resp.usage == usage
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
def test_count_tokens():
    """The token count of a small gpt-4o conversation equals the expected value."""
    llm = LLM()
    llm.config.model = "gpt-4o"

    conversation = [llm._system_msg(text) for text in ("some system msg", "some system message 2")]
    for user_text, assistant_text in (("user 1", "assistant 1"), ("user 1", "assistant 2")):
        conversation.append(llm._user_msg(user_text))
        conversation.append(llm._assistant_msg(assistant_text))

    assert llm.count_tokens(conversation) == 47
def test_count_tokens_long():
    """A ~300k-token prompt is counted within the same range under two different model names."""
    llm = LLM()
    llm.config.model = "gpt-4-0613"
    numbers = " ".join(str(i) for i in range(100000))
    messages = [
        llm._system_msg("You are a helpful assistant"),
        llm._user_msg(numbers + " what's the first number you see?"),
    ]
    lower, upper = 290000, 300000

    openai_cnt = llm.count_tokens(messages)  # 299023, ~300k
    assert lower <= openai_cnt <= upper

    llm.config.model = "test_llm"  # a non-openai model, will use heuristics base count_tokens
    heuristic_cnt = llm.count_tokens(messages)  # 294474, ~300k, ~2% difference
    assert lower <= heuristic_cnt <= upper
@pytest.mark.skip
@pytest.mark.asyncio
async def test_aask_long():
    """With POST_CUT_BY_TOKEN enabled, an oversized prompt should still be answered."""
    llm = LLM()
    # deepseek-coder on siliconflow, limit 32k
    llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"
    llm.config.compress_type = CompressType.POST_CUT_BY_TOKEN

    # corresponds to ~300k tokens
    numbers = " ".join(str(i) for i in range(100000))
    system_msg = llm._system_msg("You are a helpful assistant")
    user_msg = llm._user_msg(numbers + " what's the first number you see?")

    # should not fail with context truncated
    await llm.aask([system_msg, user_msg])
@pytest.mark.skip
@pytest.mark.asyncio
async def test_aask_long_no_compress():
    """Without compression, an oversized prompt is expected to make the request fail."""
    llm = LLM()
    # deepseek-coder on siliconflow, limit 32k
    llm.config.model = "deepseek-ai/DeepSeek-Coder-V2-Instruct"
    # Not specifying llm.config.compress_type will use default "", no compress

    # corresponds to ~300k tokens
    numbers = " ".join(str(i) for i in range(100000))
    system_msg = llm._system_msg("You are a helpful assistant")
    user_msg = llm._user_msg(numbers + " what's the first number you see?")

    with pytest.raises(Exception):
        # should fail
        await llm.aask([system_msg, user_msg])