diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index fae386952..8bb3fc823 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from metagpt.actions import ( UserRequirement, WriteDesign, @@ -6,12 +8,12 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT +from metagpt.const import AGENT, IMAGES from metagpt.environment.base_env import Environment from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role from metagpt.schema import Message, SerializationMixin -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, extract_and_encode_images class MGXEnv(Environment, SerializationMixin): @@ -27,6 +29,8 @@ class MGXEnv(Environment, SerializationMixin): def publish_message(self, message: Message, user_defined_recipient: str = "", publicer: str = "") -> bool: """let the team leader take over message publishing""" + message = self.attach_images(message) # for multi-modal message + tl = self.get_role("Mike") # TeamLeader's name is Mike if user_defined_recipient: @@ -119,9 +123,16 @@ class MGXEnv(Environment, SerializationMixin): converted_msg.role = "assistant" sent_from = converted_msg.metadata[AGENT] if AGENT in converted_msg.metadata else converted_msg.sent_from converted_msg.content = ( - f"[Message] from {sent_from if sent_from else 'User'} to {converted_msg.send_to}: {converted_msg.content}" + f"[Message] from {sent_from or 'User'} to {converted_msg.send_to}: {converted_msg.content}" ) return converted_msg + def attach_images(self, message: Message) -> Message: + if message.role == "user": + images = extract_and_encode_images(message.content) + if images: + message.add_metadata(IMAGES, images) + return message + def __repr__(self): return 
"MGXEnv()" diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index ac09c19f7..813e77d95 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -24,8 +24,9 @@ from tenacity import ( from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig -from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT +from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger +from metagpt.provider.constant import MULTI_MODAL_MODELS from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs @@ -50,7 +51,7 @@ class BaseLLM(ABC): pass def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]: - if images: + if images and self.support_image_input(): # as gpt-4v, chat with image return self._user_msg_with_imgs(msg, images) else: @@ -76,6 +77,9 @@ class BaseLLM(ABC): def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "system", "content": msg} + def support_image_input(self) -> bool: + return any([m in self.config.model for m in MULTI_MODAL_MODELS]) + def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message @@ -91,7 +95,9 @@ class BaseLLM(ABC): assert set(msg.keys()) == set(["role", "content"]) processed_messages.append(msg) elif isinstance(msg, Message): - processed_messages.append(msg.to_dict()) + images = msg.metadata.get(IMAGES) + processed_msg = self._user_msg(msg=msg.content, images=images) if images else msg.to_dict() + processed_messages.append(processed_msg) else: raise ValueError( f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!" 
diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index f1339ef32..cc9d1d1aa 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -2,7 +2,6 @@ from __future__ import annotations import inspect import json -import os import re import traceback from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple @@ -13,6 +12,7 @@ from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions from metagpt.actions.di.run_command import RunCommand from metagpt.actions.search_enhanced_qa import SearchEnhancedQA +from metagpt.const import IMAGES from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer @@ -35,13 +35,7 @@ from metagpt.tools.libs.browser import Browser from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool -from metagpt.utils.common import ( - CodeParser, - any_to_str, - encode_image, - extract_image_paths, - is_support_image_input, -) +from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images from metagpt.utils.repair_llm_raw_output import ( RepairType, repair_escape_error, @@ -219,15 +213,14 @@ class RoleZero(Role): return memory def parse_images(self, memory: list[Message]) -> list[Message]: - if not is_support_image_input(self.llm.model): + if not self.llm.support_image_input(): return memory - for i, msg in enumerate(memory): - if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content): - images = [] - for path in extract_image_paths(msg.content): - if os.path.exists(path): - images.append(encode_image(path)) - memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images) + for msg in memory: + if IMAGES in msg.metadata or msg.role 
!= "user": + continue + images = extract_and_encode_images(msg.content) + if images: + msg.add_metadata(IMAGES, images) return memory async def _act(self) -> Message: diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 8f55df8ba..0d8c03a02 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -840,12 +840,6 @@ def decode_image(img_url_or_b64: str) -> Image: return img -def is_support_image_input(model_name: str) -> bool: - # model name can be gpt-4o-2024-08-06 - support_models = ["gpt-4o", "gpt-4o-mini"] # FIXME: hard code for now - return any([m in model_name for m in support_models]) - - def extract_image_paths(content: str) -> bool: # We require that the path must have a space preceding it, like "xxx /an/absolute/path.jpg xxx" pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff)" @@ -853,6 +847,14 @@ def extract_image_paths(content: str) -> bool: return image_paths +def extract_and_encode_images(content: str) -> list[str]: + images = [] + for path in extract_image_paths(content): + if os.path.exists(path): + images.append(encode_image(path)) + return images + + def log_and_reraise(retry_state: RetryCallState): logger.error(f"Retry attempts exhausted. 
Last exception: {retry_state.outcome.exception()}") logger.warning( diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index d34ed62f1..62083a769 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -10,8 +10,9 @@ import pytest from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig +from metagpt.const import IMAGES from metagpt.provider.base_llm import BaseLLM -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message, UserMessage from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( default_resp_cont, @@ -163,3 +164,41 @@ def test_compress_messages_long_no_sys_msg(compress_type): print(compressed) assert compressed assert len(compressed[0]["content"]) < len(messages[0]["content"]) + + +def test_format_msg(mocker): + base_llm = MockBaseLLM() + messages = [UserMessage(content="req"), AIMessage(content="rsp")] + formatted_msgs = base_llm.format_msg(messages) + assert formatted_msgs == [{"role": "user", "content": "req"}, {"role": "assistant", "content": "rsp"}] + + +def test_format_msg_w_images(mocker): + base_llm = MockBaseLLM() + base_llm.config.model = "gpt-4o" + msg_w_images = UserMessage(content="req1") + msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"]) + msg_w_empty_images = UserMessage(content="req2") + msg_w_empty_images.add_metadata(IMAGES, []) + messages = [ + msg_w_images, # should be converted + AIMessage(content="rsp"), + msg_w_empty_images, # should not be converted + ] + formatted_msgs = base_llm.format_msg(messages) + assert formatted_msgs == [ + { + "role": "user", + "content": [ + {"type": "text", "text": "req1"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 1"}}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 2"}}, + 
], }, {"role": "assistant", "content": "rsp"}, {"role": "user", "content": "req2"}, ] + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 06838b7c7..b85fe229b 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -29,9 +29,9 @@ from metagpt.utils.common import ( awrite, check_cmd_exists, concat_namespace, + extract_and_encode_images, extract_image_paths, import_class_inst, - is_support_image_input, parse_recipient, print_members, read_file_block, @@ -231,9 +231,8 @@ def test_extract_image_paths(): assert not extract_image_paths(content) -def test_is_support_image_input(): - assert is_support_image_input("gpt-4o-2024-08-06") - assert not is_support_image_input("deepseek-coder") +def test_extract_and_encode_images(): + assert not extract_and_encode_images("a non-existing.jpg") if __name__ == "__main__":