diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 4c7bb4cbc..68b1f78a6 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -303,7 +303,8 @@ class BaseLLM(ABC): keep_token = int(max_token * threshold) compressed = [] - # Always keep system messages and assume they do not exceed token limit + # Always keep system messages + # NOTE: Assume they do not exceed token limit system_msg_val = self._system_msg("")["role"] system_msgs = [] for i, msg in enumerate(messages): @@ -319,40 +320,37 @@ class BaseLLM(ABC): if compress_type in ["post_cut_by_msg", "post_cut_by_token"]: # Under keep_token constraint, keep as many latest messages as possible - for msg in reversed(user_assistant_msgs): + for i, msg in enumerate(reversed(user_assistant_msgs)): token_count = self.count_tokens([msg]) if current_token_count + token_count <= keep_token: compressed.insert(len(system_msgs), msg) current_token_count += token_count else: - truncated_msg_idx = len(system_msgs) - 1 - if compress_type == "post_cut_by_token": - # Truncate the message to fit within the remaining token count; Otherwise, discard the msg + if compress_type == "post_cut_by_token" or len(compressed) == len(system_msgs): + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][-(keep_token - current_token_count) :] compressed.insert(len(system_msgs), {"role": msg["role"], "content": truncated_content}) - # after post truncation, the first message after the system message is the one truncated - truncated_msg_idx = len(system_msgs) logger.warning( f"Truncated messages with {compress_type} to fit within the token limit. " - f"The first user or assistant message after truncation is: {compressed[truncated_msg_idx]}." + f"The first user or assistant message after truncation (originally the {i}-th message from last): {compressed[len(system_msgs)]}." ) break elif compress_type in ["pre_cut_by_msg", "pre_cut_by_token"]: # Under keep_token constraint, keep as many earliest messages as possible - for msg in user_assistant_msgs: + for i, msg in enumerate(user_assistant_msgs): token_count = self.count_tokens([msg]) if current_token_count + token_count <= keep_token: compressed.append(msg) current_token_count += token_count else: - if compress_type == "pre_cut_by_token": - # Truncate the message to fit within the remaining token count; Otherwise, discard the msg + if compress_type == "pre_cut_by_token" or len(compressed) == len(system_msgs): + # Truncate the message to fit within the remaining token count; Otherwise, discard the msg. If compressed has no user or assistant message, enforce cutting by token truncated_content = msg["content"][: keep_token - current_token_count] compressed.append({"role": msg["role"], "content": truncated_content}) logger.warning( f"Truncated messages with {compress_type} to fit within the token limit. " - f"The last message after truncation is: {compressed[-1]}." + f"The last user or assistant message after truncation (originally the {i}-th message): {compressed[-1]}." ) break diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index c9c13650a..250c3b5c4 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -41,7 +41,6 @@ class MockBaseLLM(BaseLLM): return default_resp_cont -@pytest.mark.skip def test_base_llm(): message = Message(role="user", content="hello") assert "role" in message.to_dict() @@ -93,7 +92,6 @@ def test_base_llm(): # assert resp == default_resp_cont -@pytest.mark.skip @pytest.mark.asyncio async def test_async_base_llm(): base_llm = MockBaseLLM() @@ -147,3 +145,19 @@ def test_compress_messages_long(compress_type): assert len(compressed) < len(messages) assert compressed[0]["role"] == "system" and compressed[1]["role"] == "system" assert compressed[2]["role"] != "system" + + +@pytest.mark.parametrize( + "compress_type", ["post_cut_by_msg", "post_cut_by_token", "pre_cut_by_msg", "pre_cut_by_token"] +) +def test_compress_messages_long_no_sys_msg(compress_type): + base_llm = MockBaseLLM() + base_llm.config.model = "test_llm" + max_token_limit = 100 + + messages = [{"role": "user", "content": "1" * 10000}] + compressed = base_llm.compress_messages(messages, compress_type=compress_type, max_token=max_token_limit) + + print(compressed) + assert compressed + assert len(compressed[0]["content"]) < len(messages[0]["content"])