diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 98f6be62d..18e408d2e 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -34,7 +34,13 @@ from metagpt.tools.libs.browser import Browser from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool -from metagpt.utils.common import CodeParser, any_to_str +from metagpt.utils.common import ( + CodeParser, + any_to_str, + encode_image, + extract_image_paths, + is_support_image_input, +) from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output from metagpt.utils.report import ThoughtReporter @@ -129,7 +135,7 @@ class RoleZero(Role): def _update_tool_execution(self): pass - + async def _think(self) -> bool: """Useful in 'react' mode. Use LLM to decide whether and what to do next.""" # Compatibility @@ -171,6 +177,7 @@ class RoleZero(Role): ### Recent Observation ### memory = self.rc.memory.get(self.memory_k) memory = await self.parse_browser_actions(memory) + memory = self.parse_images(memory) req = self.llm.format_msg(memory + [UserMessage(content=prompt)]) async with ThoughtReporter(enable_llm_stream=True) as reporter: @@ -195,8 +202,8 @@ class RoleZero(Role): The `RoleZeroSerializer` extracts essential parts of `req` for the experience pool, trimming lengthy entries to retain only necessary parts. 
""" return await self.llm.aask(req, system_msgs=system_msgs) - - async def parse_browser_actions(self, memory: List[Message]) -> List[Message]: + + async def parse_browser_actions(self, memory: list[Message]) -> list[Message]: if not self.browser.is_empty_page: pattern = re.compile(r"Command Browser\.(\w+) executed") for index, msg in zip(range(len(memory), 0, -1), memory[::-1]): @@ -205,6 +212,15 @@ class RoleZero(Role): break return memory + def parse_images(self, memory: list[Message]) -> list[Message]: + if not is_support_image_input(self.llm.model): + return memory + for i, msg in enumerate(memory): + if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content): + images = [encode_image(path) for path in extract_image_paths(msg.content)] + memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images) + return memory + async def _act(self) -> Message: if self.use_fixed_sop: return await super()._act() @@ -261,7 +277,7 @@ class RoleZero(Role): context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) intent_result = await self.llm.aask(context) - if "QUICK" in intent_result or "AMBIGUOUS " in intent_result: # llm call with the original context + if "QUICK" in intent_result or "AMBIGUOUS " in intent_result: # llm call with the original context async with ThoughtReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "quick"}) answer = await self.llm.aask(self.llm.format_msg(memory)) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 42905c649..10c79ffc3 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -820,6 +820,19 @@ def decode_image(img_url_or_b64: str) -> Image: return img +def is_support_image_input(model_name: str) -> bool: + # model name can be gpt-4o-2024-08-06 + support_models = ["gpt-4o", "gpt-4o-mini"] # FIXME: hard code for now + return any([m in model_name for m in support_models]) + + +def extract_image_paths(content: str) 
-> list[str]: + # We require that the path must have a space preceding it, like "xxx /an/absolute/path.jpg xxx" + pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff)" + image_paths = re.findall(pattern, content) + return image_paths + + def log_and_reraise(retry_state: RetryCallState): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( diff --git a/tests/metagpt/environment/mgx_env/run_mgx_env.py b/tests/metagpt/environment/mgx_env/run_mgx_env.py index b6d5341de..69efb32da 100644 --- a/tests/metagpt/environment/mgx_env/run_mgx_env.py +++ b/tests/metagpt/environment/mgx_env/run_mgx_env.py @@ -82,9 +82,12 @@ def send_human_input(env, stop_event): GAME_REQ = "create a 2048 game" +GAME_REQ_ZH = "写一个贪吃蛇游戏" WEB_GAME_REQ = "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard." WEB_GAME_REQ_DEPLOY = "Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard. When finished, deploy the game to public at port 8090." -SIMPLE_REQ = "print statistic summary of sklearn iris dataset" +TODO_APP_REQ = "Create a website widget for TODO list management. Users should be able to add, mark as complete, and delete tasks. Include features like prioritization, due dates, and categories. Make it visually appealing, responsive, and user-friendly. Use HTML, CSS, and JavaScript. Consider additional features like notifications or task export. Keep it simple and enjoyable for users.dont use vue or react.dont use third party library, use localstorage to save data." +FLAPPY_BIRD_REQ = "write a flappy bird game in pygame, code only" +SIMPLE_DATA_REQ = "load sklearn iris dataset and print a statistic summary" WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, and train a model to predict wine class (20% as validation), and show validation accuracy." 
PAPER_LIST_REQ = """ Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, @@ -95,10 +98,26 @@ Get products data from website https://scrapeme.live/shop/ and save it as a csv **Notice: Firstly parse the web page encoding and the text HTML structure; The first page product name, price, product URL, and image URL must be saved in the csv;** """ +NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**; +下面是一个大致流程, 你会根据每一步的运行结果对当前计划中的任务做出适当调整: +1. 爬取并本地保存html结构; +2. 直接打印第7个*`快讯`*关键词后2000个字符的html内容, 作为*快讯的html内容示例*; +3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间; +4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。 +5. 将全部结果存在本地csv中 +**Notice: view the page element before writing scraping code** +""" data_path = "data/titanic" train_path = f"{data_path}/split_train.csv" eval_path = f"{data_path}/split_eval.csv" TITANIC_REQ = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{train_path}', eval data path: '{eval_path}'." +CALIFORNIA_HOUSING_REQ = """ +Analyze the 'California-housing-dataset' using https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_california_housing.html#sklearn.datasets.fetch_california_housing to predict the median house value. you need to perform data preprocessing, feature engineering and finally modeling to predict the target. Use machine learning techniques such as linear regression (including ridge regression and lasso regression), random forest, XGBoost. You also need to report the MSE on the test dataset +""" +STOCK_REQ = """Import NVIDIA Corporation (NVDA) stock price data from Yahoo Finance, focusing on historical closing prices from the past 5 years. 
+Summary statistics (mean, median, standard deviation, etc.) to understand the central tendency and dispersion of closing prices. Analyze the data for any noticeable trends, patterns, or anomalies over time, potentially using rolling averages or percentage changes. +Create a plot to visualize all the data analysis. Reserve 20% of the dataset for validation. Train a predictive model on the training set. Report the model's validation accuracy, and visualize the prediction results. +""" FIX_ISSUE1 = """ Write a fix for this issue: https://github.com/langchain-ai/langchain/issues/20453, you can fix it on this repo https://github.com/garylin2099/langchain, @@ -123,6 +142,7 @@ PUSH_PR_REQ = """ clone https://github.com/garylin2099/simple_calculator, checkout a new branch named test-branch, add an empty file test_file.py to the repo. Commit your changes and push, finally, create a PR to the master branch of https://github.com/mannaandpoem/simple_calculator. """ +IMAGE2CODE_REQ = "Please write a frontend web page similar to this image /Users/gary/Files/temp/workspace/temp_img.png, I want the same title and color. code only" TL_CHAT1 = """Summarize the paper for me""" # expecting clarification TL_CHAT2 = """Solve the issue at this link""" # expecting clarification @@ -134,6 +154,11 @@ TL_CHAT7 = """Jean has 30 lollipops. Jean eats 2 of the lollipops. 
With the rema TL_CHAT9 = """What's your name?""" TL_CHAT10 = "Hi" TL_CHAT11 = "Tell me about your team" +TL_CHAT12 = "What can you do" +CODING_REQ1 = "写一个java的hello world程序" +CODING_REQ2 = "python里的装饰器是什么" +CODING_REQ3 = "python里的装饰器是怎么用的,给我个例子" + if __name__ == "__main__": # NOTE: Add access_token to test github issue fixing diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 75e8ef4ad..06838b7c7 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -29,7 +29,9 @@ from metagpt.utils.common import ( awrite, check_cmd_exists, concat_namespace, + extract_image_paths, import_class_inst, + is_support_image_input, parse_recipient, print_members, read_file_block, @@ -215,5 +217,24 @@ class TestGetProjectRoot: assert data == content +def test_extract_image_paths(): + content = """ + Here are some image paths /home/user/images/photo1.jpg /home/user/images/photo2.png + # /absolute/path/to/image.gif""" + assert extract_image_paths(content) == [ + "/home/user/images/photo1.jpg", + "/home/user/images/photo2.png", + "/absolute/path/to/image.gif", + ] + + content = "no image path" + assert not extract_image_paths(content) + + +def test_is_support_image_input(): + assert is_support_image_input("gpt-4o-2024-08-06") + assert not is_support_image_input("deepseek-coder") + + if __name__ == "__main__": pytest.main([__file__, "-s"])