first version of long context compression

This commit is contained in:
garylin2099 2024-07-31 22:07:48 +08:00
parent 17247b5518
commit a6104f3931
5 changed files with 190 additions and 3 deletions

View file

@ -10,7 +10,7 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Literal, Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
@ -28,6 +28,7 @@ from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import TOKEN_MAX
class BaseLLM(ABC):
@ -147,7 +148,9 @@ class BaseLLM(ABC):
else:
message.extend(msg)
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
compressed_message = self.compress_messages(message, compress_type="post_cut_by_token")
rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
# rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
return rsp
def _extract_assistant_rsp(self, context):
@ -264,3 +267,93 @@ class BaseLLM(ABC):
def get_timeout(self, timeout: int) -> int:
    """Resolve the timeout to use for a request.

    Returns the first truthy value among the explicit ``timeout`` argument,
    ``self.config.timeout`` and the module-level ``LLM_API_TIMEOUT``.
    """
    if timeout:
        return timeout
    configured = self.config.timeout
    if configured:
        return configured
    return LLM_API_TIMEOUT
def count_tokens(self, messages: list[dict]) -> int:
    """Roughly estimate the number of tokens in ``messages``.

    A very raw heuristic that charges 0.5 token per element of each message's
    ``content`` (per character for string content), rounding down per message.
    Reference material:
    https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
    https://platform.deepseek.com/api-docs/#token--token-usage

    The heuristic is a large overestimate for English text and should be
    overridden with an accurate token counting function in subclasses.

    Args:
        messages: Chat messages; each is a dict whose ``content`` entry is measured.

    Returns:
        The estimated total token count. Messages whose ``content`` is missing
        or ``None`` contribute zero instead of raising.
    """
    # logger.warning("Base count_tokens is not accurate and should be overwritten.")
    # `or ""` maps a missing or None content to an empty string so len() is always valid.
    return sum(int(len(msg.get("content") or "") * 0.5) for msg in messages)
def compress_messages(
    self,
    messages: list[dict],
    compress_type: Literal["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token", ""] = "",
    max_token: int = 128000,
    threshold: float = 0.8,
) -> list[dict]:
    """Compress messages to fit within the token limit.

    The leading run of system messages is always kept (it is assumed not to exceed
    the limit on its own). The remaining messages are then kept or dropped according
    to ``compress_type`` until the budget ``int(max_token * threshold)`` is used up.

    Args:
        messages (list[dict]): List of messages to compress.
        compress_type (str, optional): Compression strategy.
            - "post_cut_by_msg": Keep as many latest messages as possible.
            - "post_cut_by_token": Keep as many latest messages as possible and truncate the earliest fit-in message.
            - "pre_cut_by_msg": Keep as many earliest messages as possible.
            - "pre_cut_by_token": Keep as many earliest messages as possible and truncate the latest fit-in message.
            - "": No compression. Default value.
            Any other non-empty value keeps only the leading system messages.
        max_token (int, optional): Maximum token limit. Defaults to 128000. Not effective if token limit can be found in TOKEN_MAX.
        threshold (float): Token limit threshold. Defaults to 0.8. Reserve 20% of the token limit for completion message.

    Returns:
        list[dict]: ``messages`` itself when no compression is requested, otherwise a
        new list. A truncated message is a new dict holding only ``role`` and the
        sliced ``content``; the remaining token budget is used as the slice length.
    """
    if not compress_type:
        return messages

    max_token = TOKEN_MAX.get(self.config.model, max_token)
    keep_token = int(max_token * threshold)

    # Always keep the leading system messages and assume they do not exceed the token limit.
    # The split index defaults to len(messages) so that an empty list, or a list holding
    # only system messages, yields an empty remainder instead of an unbound variable.
    system_msg_val = self._system_msg("")["role"]
    split_at = len(messages)
    for i, msg in enumerate(messages):
        if msg["role"] != system_msg_val:
            split_at = i
            break
    system_msgs = messages[:split_at]
    user_assistant_msgs = messages[split_at:]

    compressed = list(system_msgs)
    current_token_count = self.count_tokens(system_msgs)

    if compress_type in ("post_cut_by_msg", "post_cut_by_token"):
        # Under keep_token constraint, keep as many latest messages as possible.
        # Collect newest-first and reverse once at the end instead of inserting
        # into the middle of the result on every iteration.
        newest_first = []
        for msg in reversed(user_assistant_msgs):
            token_count = self.count_tokens([msg])
            if current_token_count + token_count <= keep_token:
                newest_first.append(msg)
                current_token_count += token_count
                continue
            remaining = keep_token - current_token_count
            # Only truncate when budget is left: content[-0:] would keep the WHOLE
            # content, and a negative budget would slice from the wrong end.
            if compress_type == "post_cut_by_token" and remaining > 0:
                newest_first.append({"role": msg["role"], "content": msg["content"][-remaining:]})
            # The oldest kept entry is the first user/assistant message of the result.
            first_kept = newest_first[-1] if newest_first else None
            logger.warning(
                f"Truncated messages with {compress_type} to fit within the token limit. "
                f"The first user or assistant message after truncation is: {first_kept}."
            )
            break
        compressed.extend(reversed(newest_first))

    elif compress_type in ("pre_cut_by_msg", "pre_cut_by_token"):
        # Under keep_token constraint, keep as many earliest messages as possible.
        for msg in user_assistant_msgs:
            token_count = self.count_tokens([msg])
            if current_token_count + token_count <= keep_token:
                compressed.append(msg)
                current_token_count += token_count
                continue
            remaining = keep_token - current_token_count
            # Skip truncation when no budget is left (a non-positive slice bound
            # would produce an empty or wrongly-cut message).
            if compress_type == "pre_cut_by_token" and remaining > 0:
                compressed.append({"role": msg["role"], "content": msg["content"][:remaining]})
            # Guard against an empty result so logging can never raise IndexError.
            last_kept = compressed[-1] if compressed else None
            logger.warning(
                f"Truncated messages with {compress_type} to fit within the token limit. "
                f"The last message after truncation is: {last_kept}."
            )
            break

    return compressed

View file

@ -300,3 +300,9 @@ class OpenAILLM(BaseLLM):
img_url_or_b64 = item.url if resp_format == "url" else item.b64_json
imgs.append(decode_image(img_url_or_b64))
return imgs
def count_tokens(self, messages: list[dict]) -> int:
    """Count the tokens in ``messages`` for the configured model.

    Delegates to ``count_message_tokens`` with ``self.config.model``. If that call
    raises, falls back to the parent class's ``count_tokens``.
    """
    try:
        return count_message_tokens(messages, self.config.model)
    except Exception:
        # Deliberate best-effort fallback. Catch Exception rather than using a bare
        # except so KeyboardInterrupt and SystemExit still propagate.
        return super().count_tokens(messages)

View file

@ -215,6 +215,7 @@ TOKEN_MAX = {
"deepseek/deepseek-chat": 128000, # end, for openrouter
"deepseek-chat": 128000,
"deepseek-coder": 128000,
"deepseek-ai/DeepSeek-Coder-V2-Instruct": 32000, # siliconflow
}
@ -319,4 +320,4 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) ->
"""
if model not in TOKEN_MAX:
return default
return TOKEN_MAX[model] - count_message_tokens(messages) - 1
return TOKEN_MAX[model] - count_message_tokens(messages, model) - 1