mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
modify detail
This commit is contained in:
parent
a6104f3931
commit
1f73051571
2 changed files with 26 additions and 14 deletions
|
|
@ -303,7 +303,8 @@ class BaseLLM(ABC):
|
|||
keep_token = int(max_token * threshold)
|
||||
compressed = []
|
||||
|
||||
# Always keep system messages and assume they do not exceed token limit
|
||||
# Always keep system messages
|
||||
# NOTE: Assume they do not exceed token limit
|
||||
system_msg_val = self._system_msg("")["role"]
|
||||
system_msgs = []
|
||||
for i, msg in enumerate(messages):
|
||||
|
|
@ -319,40 +320,37 @@ class BaseLLM(ABC):
|
|||
|
||||
if compress_type in ["post_cut_by_msg", "post_cut_by_token"]:
|
||||
# Under keep_token constraint, keep as many latest messages as possible
|
||||
for msg in reversed(user_assistant_msgs):
|
||||
for i, msg in enumerate(reversed(user_assistant_msgs)):
|
||||
token_count = self.count_tokens([msg])
|
||||
if current_token_count + token_count <= keep_token:
|
||||
compressed.insert(len(system_msgs), msg)
|
||||
current_token_count += token_count
|
||||
else:
|
||||
truncated_msg_idx = len(system_msgs) - 1
|
||||
if compress_type == "post_cut_by_token":
|
||||
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg
|
||||
if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs):
|
||||
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token
|
||||
truncated_content = msg["content"][-(keep_token - current_token_count) :]
|
||||
compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content})
|
||||
# after post truncation, the first message after the system message is the one truncated
|
||||
truncated_msg_idx = len(system_msgs)
|
||||
logger.warning(
|
||||
f"Truncated messages with {compress_type} to fit within the token limit. "
|
||||
f"The first user or assistant message after truncation is: {compressed[truncated_msg_idx]}."
|
||||
f"The first user or assistant message after truncation (originally the {i}-th message from last): {compressed[len(system_msgs)]}."
|
||||
)
|
||||
break
|
||||
|
||||
elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]:
|
||||
# Under keep_token constraint, keep as many earliest messages as possible
|
||||
for msg in user_assistant_msgs:
|
||||
for i, msg in enumerate(user_assistant_msgs):
|
||||
token_count = self.count_tokens([msg])
|
||||
if current_token_count + token_count <= keep_token:
|
||||
compressed.append(msg)
|
||||
current_token_count += token_count
|
||||
else:
|
||||
if compress_type == "pre_cut_by_token":
|
||||
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg
|
||||
if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs):
|
||||
# Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token
|
||||
truncated_content = msg["content"][: keep_token - current_token_count]
|
||||
compressed.append({"role": msg["role"], "content": truncated_content})
|
||||
logger.warning(
|
||||
f"Truncated messages with {compress_type} to fit within the token limit. "
|
||||
f"The last message after truncation is: {compressed[-1]}."
|
||||
f"The last user or assistant message after truncation (originally the {i}-th message): {compressed[-1]}."
|
||||
)
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ class MockBaseLLM(BaseLLM):
|
|||
return default_resp_cont
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_base_llm():
|
||||
message = Message(role="user", content="hello")
|
||||
assert "role" in message.to_dict()
|
||||
|
|
@ -93,7 +92,6 @@ def test_base_llm():
|
|||
# assert resp == default_resp_cont
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_llm():
|
||||
base_llm = MockBaseLLM()
|
||||
|
|
@ -147,3 +145,19 @@ def test_compress_messages_long(compress_type):
|
|||
assert len(compressed) < len(messages)
|
||||
assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system"
|
||||
assert compressed[2]["role"] != "system"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"]
|
||||
)
|
||||
def test_compress_messages_long_no_sys_msg(compress_type):
|
||||
base_llm = MockBaseLLM()
|
||||
base_llm.config.model = "test_llm"
|
||||
max_token_limit = 100
|
||||
|
||||
messages = [{"role": "user", "content": "1" * 10000}]
|
||||
compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit)
|
||||
|
||||
print(compressed)
|
||||
assert compressed
|
||||
assert len(compressed[0]["content"]) < len(messages[0]["content"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue