From 78db50b4327e6f453a26b19e6b08477f2b5f1fd2 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Tue, 13 Aug 2024 19:30:49 +0800 Subject: [PATCH 1/3] allow image input for rolezero --- metagpt/roles/di/role_zero.py | 26 ++++++++++++++---- metagpt/utils/common.py | 13 +++++++++ .../environment/mgx_env/run_mgx_env.py | 27 ++++++++++++++++++- tests/metagpt/utils/test_common.py | 21 +++++++++++++++ 4 files changed, 81 insertions(+), 6 deletions(-) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 98f6be62d..18e408d2e 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -34,7 +34,13 @@ from metagpt.tools.libs.browser import Browser from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool -from metagpt.utils.common import CodeParser, any_to_str +from metagpt.utils.common import ( + CodeParser, + any_to_str, + encode_image, + extract_image_paths, + is_support_image_input, +) from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output from metagpt.utils.report import ThoughtReporter @@ -129,7 +135,7 @@ class RoleZero(Role): def _update_tool_execution(self): pass - + async def _think(self) -> bool: """Useful in 'react' mode. Use LLM to decide whether and what to do next.""" # Compatibility @@ -171,6 +177,7 @@ class RoleZero(Role): ### Recent Observation ### memory = self.rc.memory.get(self.memory_k) memory = await self.parse_browser_actions(memory) + memory = self.parse_images(memory) req = self.llm.format_msg(memory + [UserMessage(content=prompt)]) async with ThoughtReporter(enable_llm_stream=True) as reporter: @@ -195,8 +202,8 @@ class RoleZero(Role): The `RoleZeroSerializer` extracts essential parts of `req` for the experience pool, trimming lengthy entries to retain only necessary parts. """ return await self.llm.aask(req, system_msgs=system_msgs) - - async def parse_browser_actions(self, memory: List[Message]) -> List[Message]: + + async def parse_browser_actions(self, memory: list[Message]) -> list[Message]: if not self.browser.is_empty_page: pattern = re.compile(r"Command Browser\.(\w+) executed") for index, msg in zip(range(len(memory), 0, -1), memory[::-1]): @@ -205,6 +212,15 @@ class RoleZero(Role): break return memory + def parse_images(self, memory: list[Message]) -> list[Message]: + if not is_support_image_input(self.llm.model): + return memory + for i, msg in enumerate(memory): + if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content): + images = [encode_image(path) for path in extract_image_paths(msg.content)] + memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images) + return memory + async def _act(self) -> Message: if self.use_fixed_sop: return await super()._act() @@ -261,7 +277,7 @@ class RoleZero(Role): context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) intent_result = await self.llm.aask(context) - if "QUICK" in intent_result or "AMBIGUOUS " in intent_result: # llm call with the original context + if "QUICK" in intent_result or "AMBIGUOUS " in intent_result: # llm call with the original context async with ThoughtReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "quick"}) answer = await self.llm.aask(self.llm.format_msg(memory)) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 42905c649..10c79ffc3 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -820,6 +820,19 @@ def decode_image(img_url_or_b64: str) -> Image: return img +def is_support_image_input(model_name: str) -> bool: + # model name can be gpt-4o-2024-08-06 + support_models = ["gpt-4o", "gpt-4o-mini"] # FIXME: hard code for now + return any([m in model_name for m in support_models]) + + +def extract_image_paths(content: str) -> bool: + # We require that the path must have a space preceding it, like "xxx /an/absolute/path.jpg xxx" + pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff)" + image_paths = re.findall(pattern, content) + return image_paths + + def log_and_reraise(retry_state: RetryCallState): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( diff --git a/tests/metagpt/environment/mgx_env/run_mgx_env.py b/tests/metagpt/environment/mgx_env/run_mgx_env.py index b6d5341de..69efb32da 100644 --- a/tests/metagpt/environment/mgx_env/run_mgx_env.py +++ b/tests/metagpt/environment/mgx_env/run_mgx_env.py @@ -82,9 +82,12 @@ def send_human_input(env, stop_event): GAME_REQ = "create a 2048 game" +GAME_REQ_ZH = "写一个贪吃蛇游戏" WEB_GAME_REQ = "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard." WEB_GAME_REQ_DEPLOY = "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard. When finished, deploy the game to public at port 8090." -SIMPLE_REQ = "print statistic summary of sklearn iris dataset" +TODO_APP_REQ = "Create a website widget for TODO list management. Users should be able to add, mark as complete, and delete tasks. Include features like prioritization, due dates, and categories. Make it visually appealing, responsive, and user-friendly. Use HTML, CSS, and JavaScript. Consider additional features like notifications or task export. Keep it simple and enjoyable for users.dont use vue or react.dont use third party library, use localstorage to save data." +FLAPPY_BIRD_REQ = "write a flappy bird game in pygame, code only" +SIMPLE_DATA_REQ = "load sklearn iris dataset and print a statistic summary" WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, and train a model to predict wine class (20% as validation), and show validation accuracy." PAPER_LIST_REQ = """ Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, @@ -95,10 +98,26 @@ Get products data from website https://scrapeme.live/shop/ and save it as a csv **Notice: Firstly parse the web page encoding and the text HTML structure; The first page product name, price, product URL, and image URL must be saved in the csv;** """ +NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**; +下面是一个大致流程, 你会根据每一步的运行结果对当前计划中的任务做出适当调整: +1. 爬取并本地保存html结构; +2. 直接打印第7个*`快讯`*关键词后2000个字符的html内容, 作为*快讯的html内容示例*; +3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间; +4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。 +5. 将全部结果存在本地csv中 +**Notice: view the page element before writing scraping code** +""" data_path = "data/titanic" train_path = f"{data_path}/split_train.csv" eval_path = f"{data_path}/split_eval.csv" TITANIC_REQ = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{train_path}', eval data path: '{eval_path}'." +CALIFORNIA_HOUSING_REQ = """ +Analyze the 'Canifornia-housing-dataset' using https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_california_housing.html#sklearn.datasets.fetch_california_housing to predict the median house value. you need to perfrom data preprocessing, feature engineering and finally modeling to predict the target. Use machine learning techniques such as linear regression (including ridge regression and lasso regression), random forest, XGBoost. You also need to report the MSE on the test dataset +""" +STOCK_REQ = """Import NVIDIA Corporation (NVDA) stock price data from Yahoo Finance, focusing on historical closing prices from the past 5 years. +Summary statistics (mean, median, standard deviation, etc.) to understand the central tendency and dispersion of closingprices. Analyze the data for any noticeable trends, patterns, or anomalies over time, potentially using rolling averages or percentage changes. +Create a pot to visualize all the data analysis. Reserve 20% of the dataset for validaation. Train a predictive model on the training set. Report the modeel's validation accuracy, and visualize the result of prediction result. +""" FIX_ISSUE1 = """ Write a fix for this issue: https://github.com/langchain-ai/langchain/issues/20453, you can fix it on this repo https://github.com/garylin2099/langchain, @@ -123,6 +142,7 @@ PUSH_PR_REQ = """ clone https://github.com/garylin2099/simple_calculator, checkout a new branch named test-branch, add an empty file test_file.py to the repo. Commit your changes and push, finally, create a PR to the master branch of https://github.com/mannaandpoem/simple_calculator. """ +IMAGE2CODE_REQ = "Please write a frontend web page similar to this image /Users/gary/Files/temp/workspace/temp_img.png, I want the same title and color. code only" TL_CHAT1 = """Summarize the paper for me""" # expecting clarification TL_CHAT2 = """Solve the issue at this link""" # expecting clarification @@ -134,6 +154,11 @@ TL_CHAT7 = """Jean has 30 lollipops. Jean eats 2 of the lollipops. With the rema TL_CHAT9 = """What's your name?""" TL_CHAT10 = "Hi" TL_CHAT11 = "Tell me about your team" +TL_CHAT12 = "What can you do" +CODING_REQ1 = "写一个java的hello world程序" +CODING_REQ2 = "python里的装饰器是什么" +CODING_REQ3 = "python里的装饰器是怎么用的,给我个例子" + if __name__ == "__main__": # NOTE: Add access_token to test github issue fixing diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 75e8ef4ad..06838b7c7 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -29,7 +29,9 @@ from metagpt.utils.common import ( awrite, check_cmd_exists, concat_namespace, + extract_image_paths, import_class_inst, + is_support_image_input, parse_recipient, print_members, read_file_block, @@ -215,5 +217,24 @@ class TestGetProjectRoot: assert data == content +def test_extract_image_paths(): + content = """ + Here are some image paths /home/user/images/photo1.jpg /home/user/images/photo2.png + # /absolute/path/to/image.gif""" + assert extract_image_paths(content) == [ + "/home/user/images/photo1.jpg", + "/home/user/images/photo2.png", + "/absolute/path/to/image.gif", + ] + + content = "no image path" + assert not extract_image_paths(content) + + +def test_is_support_image_input(): + assert is_support_image_input("gpt-4o-2024-08-06") + assert not is_support_image_input("deepseek-coder") + + if __name__ == "__main__": pytest.main([__file__, "-s"]) From dce5502c07b79147b0870507b0a8e6fea50e2496 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Tue, 13 Aug 2024 20:03:46 +0800 Subject: [PATCH 2/3] check image path before encoding --- metagpt/roles/di/role_zero.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 424c7bffb..f1339ef32 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -2,6 +2,7 @@ from __future__ import annotations import inspect import json +import os import re import traceback from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple @@ -222,7 +223,10 @@ class RoleZero(Role): return memory for i, msg in enumerate(memory): if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content): - images = [encode_image(path) for path in extract_image_paths(msg.content)] + images = [] + for path in extract_image_paths(msg.content): + if os.path.exists(path): + images.append(encode_image(path)) memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images) return memory From e9984f2bf82d9cd86d0056d142b44471da9078e6 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Wed, 14 Aug 2024 20:12:17 +0800 Subject: [PATCH 3/3] attach images to message --- metagpt/environment/mgx/mgx_env.py | 17 ++++++++-- metagpt/provider/base_llm.py | 12 ++++++-- metagpt/roles/di/role_zero.py | 25 ++++++--------- metagpt/utils/common.py | 14 +++++---- tests/metagpt/provider/test_base_llm.py | 41 ++++++++++++++++++++++++- tests/metagpt/utils/test_common.py | 7 ++--- 6 files changed, 83 insertions(+), 33 deletions(-) diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index fae386952..8bb3fc823 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from metagpt.actions import ( UserRequirement, WriteDesign, @@ -6,12 +8,12 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT +from metagpt.const import AGENT, IMAGES from metagpt.environment.base_env import Environment from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role from metagpt.schema import Message, SerializationMixin -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, extract_and_encode_images class MGXEnv(Environment, SerializationMixin): @@ -27,6 +29,8 @@ class MGXEnv(Environment, SerializationMixin): def publish_message(self, message: Message, user_defined_recipient: str = "", publicer: str = "") -> bool: """let the team leader take over message publishing""" + message = self.attach_images(message) # for multi-modal message + tl = self.get_role("Mike") # TeamLeader's name is Mike if user_defined_recipient: @@ -119,9 +123,16 @@ class MGXEnv(Environment, SerializationMixin): converted_msg.role = "assistant" sent_from = converted_msg.metadata[AGENT] if AGENT in converted_msg.metadata else converted_msg.sent_from converted_msg.content = ( - f"[Message] from {sent_from if sent_from else 'User'} to {converted_msg.send_to}: {converted_msg.content}" + f"[Message] from {sent_from or 'User'} to {converted_msg.send_to}: {converted_msg.content}" ) return converted_msg + def attach_images(self, message: Message) -> Message: + if message.role == "user": + images = extract_and_encode_images(message.content) + if images: + message.add_metadata(IMAGES, images) + return message + def __repr__(self): return "MGXEnv()" diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index ac09c19f7..813e77d95 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -24,8 +24,9 @@ from tenacity import ( from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig -from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT +from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger +from metagpt.provider.constant import MULTI_MODAL_MODELS from metagpt.schema import Message from metagpt.utils.common import log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs @@ -50,7 +51,7 @@ class BaseLLM(ABC): pass def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]: - if images: + if images and self.support_image_input(): # as gpt-4v, chat with image return self._user_msg_with_imgs(msg, images) else: @@ -76,6 +77,9 @@ class BaseLLM(ABC): def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "system", "content": msg} + def support_image_input(self) -> bool: + return any([m in self.config.model for m in MULTI_MODAL_MODELS]) + def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: """convert messages to list[dict].""" from metagpt.schema import Message @@ -91,7 +95,9 @@ class BaseLLM(ABC): assert set(msg.keys()) == set(["role", "content"]) processed_messages.append(msg) elif isinstance(msg, Message): - processed_messages.append(msg.to_dict()) + images = msg.metadata.get(IMAGES) + processed_msg = self._user_msg(msg=msg.content, images=images) if images else msg.to_dict() + processed_messages.append(processed_msg) else: raise ValueError( f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index f1339ef32..cc9d1d1aa 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -2,7 +2,6 @@ from __future__ import annotations import inspect import json -import os import re import traceback from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple @@ -13,6 +12,7 @@ from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions from metagpt.actions.di.run_command import RunCommand from metagpt.actions.search_enhanced_qa import SearchEnhancedQA +from metagpt.const import IMAGES from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer @@ -35,13 +35,7 @@ from metagpt.tools.libs.browser import Browser from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool -from metagpt.utils.common import ( - CodeParser, - any_to_str, - encode_image, - extract_image_paths, - is_support_image_input, -) +from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images from metagpt.utils.repair_llm_raw_output import ( RepairType, repair_escape_error, @@ -219,15 +213,14 @@ class RoleZero(Role): return memory def parse_images(self, memory: list[Message]) -> list[Message]: - if not is_support_image_input(self.llm.model): + if not self.llm.support_image_input(): return memory - for i, msg in enumerate(memory): - if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content): - images = [] - for path in extract_image_paths(msg.content): - if os.path.exists(path): - images.append(encode_image(path)) - memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images) + for msg in memory: + if IMAGES in msg.metadata or msg.role != "user": + continue + images = extract_and_encode_images(msg.content) + if images: + msg.add_metadata(IMAGES, images) return memory async def _act(self) -> Message: diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 8f55df8ba..0d8c03a02 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -840,12 +840,6 @@ def decode_image(img_url_or_b64: str) -> Image: return img -def is_support_image_input(model_name: str) -> bool: - # model name can be gpt-4o-2024-08-06 - support_models = ["gpt-4o", "gpt-4o-mini"] # FIXME: hard code for now - return any([m in model_name for m in support_models]) - - def extract_image_paths(content: str) -> bool: # We require that the path must have a space preceding it, like "xxx /an/absolute/path.jpg xxx" pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff)" @@ -853,6 +847,14 @@ def extract_image_paths(content: str) -> bool: return image_paths +def extract_and_encode_images(content: str) -> list[str]: + images = [] + for path in extract_image_paths(content): + if os.path.exists(path): + images.append(encode_image(path)) + return images + + def log_and_reraise(retry_state: RetryCallState): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index d34ed62f1..62083a769 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -10,8 +10,9 @@ import pytest from metagpt.configs.compress_msg_config import CompressType from metagpt.configs.llm_config import LLMConfig +from metagpt.const import IMAGES from metagpt.provider.base_llm import BaseLLM -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message, UserMessage from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( default_resp_cont, @@ -163,3 +164,41 @@ def test_compress_messages_long_no_sys_msg(compress_type): print(compressed) assert compressed assert len(compressed[0]["content"]) < len(messages[0]["content"]) + + +def test_format_msg(mocker): + base_llm = MockBaseLLM() + messages = [UserMessage(content="req"), AIMessage(content="rsp")] + formatted_msgs = base_llm.format_msg(messages) + assert formatted_msgs == [{"role": "user", "content": "req"}, {"role": "assistant", "content": "rsp"}] + + +def test_format_msg_w_images(mocker): + base_llm = MockBaseLLM() + base_llm.config.model = "gpt-4o" + msg_w_images = UserMessage(content="req1") + msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"]) + msg_w_empty_images = UserMessage(content="req2") + msg_w_empty_images.add_metadata(IMAGES, []) + messages = [ + msg_w_images, # should be converted + AIMessage(content="rsp"), + msg_w_empty_images, # should not be converted + ] + formatted_msgs = base_llm.format_msg(messages) + assert formatted_msgs == [ + { + "role": "user", + "content": [ + {"type": "text", "text": "req1"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 1"}}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 2"}}, + ], + }, + {"role": "assistant", "content": "rsp"}, + {"role": "user", "content": "req2"}, + ] + + +if name == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 06838b7c7..b85fe229b 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -29,9 +29,9 @@ from metagpt.utils.common import ( awrite, check_cmd_exists, concat_namespace, + extract_and_encode_images, extract_image_paths, import_class_inst, - is_support_image_input, parse_recipient, print_members, read_file_block, @@ -231,9 +231,8 @@ def test_extract_image_paths(): assert not extract_image_paths(content) -def test_is_support_image_input(): - assert is_support_image_input("gpt-4o-2024-08-06") - assert not is_support_image_input("deepseek-coder") +def test_extract_and_encode_images(): + assert not extract_and_encode_images("a non-existing.jpg") if __name__ == "__main__":