diff --git a/.gitignore b/.gitignore
index 922116d12..aa5edd74a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
### Python template
# Byte-compiled / optimized / DLL files
-__pycache__/
+__pycache__
*.py[cod]
*$py.class
diff --git a/README.md b/README.md
index 72ce03412..edb2066a3 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ # MetaGPT: The Multi-Agent Framework
## News
-🚀 March. 01, 2024: Our Data Interpreter paper is on arxiv. Find all design and benchmark details [here](https://arxiv.org/abs/2402.18679)!
+🚀 Mar. 14, 2024: Our **Data Interpreter** paper is on [arxiv](https://arxiv.org/abs/2402.18679). Check the [example](https://docs.deepwisdom.ai/main/en/DataInterpreter/) and [code](https://github.com/geekan/MetaGPT/tree/main/examples/di)!
🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), a powerful agent capable of solving a wide range of real-world problems.
@@ -55,21 +55,30 @@ ## Software Company as Multi-Agent System
Software Company Multi-Agent Schematic (Gradually Implementing)
-## Install
+## Get Started
-### Pip installation
+### Installation
> Ensure that Python 3.9+ is installed on your system. You can check this by using: `python --version`.
> You can use conda like this: `conda create -n metagpt python=3.9 && conda activate metagpt`
```bash
-pip install metagpt
-# https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
-metagpt --init-config # it will create ~/.metagpt/config2.yaml, just modify it to your needs
+pip install --upgrade metagpt
+# or `pip install --upgrade git+https://github.com/geekan/MetaGPT.git`
+# or `git clone https://github.com/geekan/MetaGPT && cd MetaGPT && pip install --upgrade -e .`
```
+For detailed installation guidance, please refer to [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version)
+ or [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker)
+
### Configuration
+You can init the config of MetaGPT by running the following command, or manually create `~/.metagpt/config2.yaml` file:
+```bash
+# Check https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html for more details
+metagpt --init-config # it will create ~/.metagpt/config2.yaml, just modify it to your needs
+```
+
You can configure `~/.metagpt/config2.yaml` according to the [example](https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml) and [doc](https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html):
```yaml
@@ -82,13 +91,13 @@ ### Configuration
### Usage
-After installation, you can use it as CLI
+After installation, you can use MetaGPT at CLI
```bash
metagpt "Create a 2048 game" # this will create a repo in ./workspace
```
-or you can use it as library
+or use it as library
```python
from metagpt.software_company import generate_repo, ProjectRepo
@@ -96,47 +105,19 @@ ### Usage
print(repo) # it will print the repo structure with files
```
-detail installation please refer to [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version)
- or [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker)
+You can also use its [Data Interpreter](https://github.com/geekan/MetaGPT/tree/main/examples/di)
-### Docker installation
-⏬ Step 1: Download metagpt image and prepare config2.yaml :: click to expand ::
-
+```python
+import asyncio
+from metagpt.roles.di.data_interpreter import DataInterpreter
-```bash
-docker pull metagpt/metagpt:latest
-mkdir -p /opt/metagpt/{config,workspace}
-docker run --rm metagpt/metagpt:latest cat /app/metagpt/config/config2.yaml > /opt/metagpt/config/config2.yaml
-vim /opt/metagpt/config/config2.yaml # Change the config
+async def main():
+ di = DataInterpreter()
+ await di.run("Run data analysis on sklearn Iris dataset, include a plot")
+
+asyncio.run(main()) # or await main() in a jupyter notebook setting
```
-
-
-
-⏬ Step 2: Run metagpt container :: click to expand ::
-
-
-```bash
-docker run --name metagpt -d \
- --privileged \
- -v /opt/metagpt/config/config2.yaml:/app/metagpt/config/config2.yaml \
- -v /opt/metagpt/workspace:/app/metagpt/workspace \
- metagpt/metagpt:latest
-```
-
-
-
-
-⏬ Step 3: Use metagpt :: click to expand ::
-
-
-```bash
-docker exec -it metagpt /bin/bash
-$ metagpt "Create a 2048 game" # this will create a repo in ./workspace
-```
-
-
-
### QuickStart & Demo Video
- Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT)
@@ -156,6 +137,7 @@ ## Tutorial
- 🧑💻 Contribution
- [Develop Roadmap](docs/ROADMAP.md)
- 🔖 Use Cases
+ - [Data Interpreter](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html)
- [Debate](https://docs.deepwisdom.ai/main/en/guide/use_cases/multi_agent/debate.html)
- [Researcher](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/researcher.html)
- [Recepit Assistant](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/receipt_assistant.html)
@@ -179,7 +161,9 @@ ### Contact Information
## Citation
-For now, cite the [arXiv paper](https://arxiv.org/abs/2308.00352):
+To stay updated with the latest research and development, follow [@MetaGPT_](https://twitter.com/MetaGPT_) on Twitter.
+
+To cite [MetaGPT](https://arxiv.org/abs/2308.00352) or [Data Interpreter](https://arxiv.org/abs/2402.18679) in publications, please use the following BibTeX entries.
```bibtex
@misc{hong2023metagpt,
@@ -190,4 +174,14 @@ ## Citation
archivePrefix={arXiv},
primaryClass={cs.AI}
}
+@misc{hong2024data,
+ title={Data Interpreter: An LLM Agent For Data Science},
+ author={Sirui Hong and Yizhang Lin and Bang Liu and Bangbang Liu and Binhao Wu and Danyang Li and Jiaqi Chen and Jiayi Zhang and Jinlin Wang and Li Zhang and Lingyao Zhang and Min Yang and Mingchen Zhuge and Taicheng Guo and Tuo Zhou and Wei Tao and Wenyi Wang and Xiangru Tang and Xiangtao Lu and Xiawu Zheng and Xinbing Liang and Yaying Fei and Yuheng Cheng and Zongze Xu and Chenglin Wu},
+ year={2024},
+ eprint={2402.18679},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI}
+}
+
```
+
diff --git a/SECURITY.md b/SECURITY.md
index 623e6e556..924ce5015 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -4,9 +4,9 @@ ## Supported Versions
| Version | Supported |
|---------|--------------------|
- | 7.x | :x: |
- | 6.x | :x: |
-| < 6.x | :x: |
+ | 0.7.x | :x: |
+ | 0.6.x | :x: |
+| < 0.6.x | :x: |
## Reporting a Vulnerability
diff --git a/config/config2.example.yaml b/config/config2.example.yaml
index 3a5cc3585..c5454ec32 100644
--- a/config/config2.example.yaml
+++ b/config/config2.example.yaml
@@ -4,6 +4,7 @@ llm:
api_key: "YOUR_API_KEY"
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
proxy: "YOUR_PROXY" # for LLM API requests
+ # timeout: 600 # Optional. If set to 0, default value is 300.
pricing_plan: "" # Optional. If invalid, it will be automatically filled in with the value of the `model`.
# Azure-exclusive pricing plan mappings:
# - gpt-3.5-turbo 4k: "gpt-3.5-turbo-1106"
diff --git a/docs/README_CN.md b/docs/README_CN.md
index 7a0db4974..8aea5e4cb 100644
--- a/docs/README_CN.md
+++ b/docs/README_CN.md
@@ -116,7 +116,7 @@ ### 联系信息
## 引用
-引用 [arXiv paper](https://arxiv.org/abs/2308.00352):
+如果您在研究论文中使用 MetaGPT 或 Data Interpreter,请引用我们的工作:
```bibtex
@misc{hong2023metagpt,
@@ -127,4 +127,12 @@ ## 引用
archivePrefix={arXiv},
primaryClass={cs.AI}
}
+@misc{hong2024data,
+ title={Data Interpreter: An LLM Agent For Data Science},
+ author={Sirui Hong and Yizhang Lin and Bang Liu and Bangbang Liu and Binhao Wu and Danyang Li and Jiaqi Chen and Jiayi Zhang and Jinlin Wang and Li Zhang and Lingyao Zhang and Min Yang and Mingchen Zhuge and Taicheng Guo and Tuo Zhou and Wei Tao and Wenyi Wang and Xiangru Tang and Xiangtao Lu and Xiawu Zheng and Xinbing Liang and Yaying Fei and Yuheng Cheng and Zongze Xu and Chenglin Wu},
+ year={2024},
+ eprint={2402.18679},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI}
+}
```
diff --git a/docs/README_JA.md b/docs/README_JA.md
index c6b99461c..91155532b 100644
--- a/docs/README_JA.md
+++ b/docs/README_JA.md
@@ -295,7 +295,7 @@ ## クイックスタート
## 引用
-現時点では、[arXiv 論文](https://arxiv.org/abs/2308.00352)を引用してください:
+研究論文でMetaGPTやData Interpreterを使用する場合は、以下のように当社の作業を引用してください:
```bibtex
@misc{hong2023metagpt,
@@ -306,6 +306,14 @@ ## 引用
archivePrefix={arXiv},
primaryClass={cs.AI}
}
+@misc{hong2024data,
+ title={Data Interpreter: An LLM Agent For Data Science},
+ author={Sirui Hong and Yizhang Lin and Bang Liu and Bangbang Liu and Binhao Wu and Danyang Li and Jiaqi Chen and Jiayi Zhang and Jinlin Wang and Li Zhang and Lingyao Zhang and Min Yang and Mingchen Zhuge and Taicheng Guo and Tuo Zhou and Wei Tao and Wenyi Wang and Xiangru Tang and Xiangtao Lu and Xiawu Zheng and Xinbing Liang and Yaying Fei and Yuheng Cheng and Zongze Xu and Chenglin Wu},
+ year={2024},
+ eprint={2402.18679},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI}
+}
```
## お問い合わせ先
diff --git a/examples/debate.py b/examples/debate.py
index 72ab8796d..56df16b4f 100644
--- a/examples/debate.py
+++ b/examples/debate.py
@@ -5,6 +5,7 @@ Author: garylin2099
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `send_to`
value of the `Message` object; modify the argument type of `get_by_actions`.
"""
+
import asyncio
import platform
from typing import Any
@@ -105,4 +106,4 @@ def main(idea: str, investment: float = 3.0, n_round: int = 10):
if __name__ == "__main__":
- fire.Fire(main)
+ fire.Fire(main) # run as python debate.py --idea="TOPIC" --investment=3.0 --n_round=5
diff --git a/examples/di/README.md b/examples/di/README.md
index fb4e8044b..f63795f13 100644
--- a/examples/di/README.md
+++ b/examples/di/README.md
@@ -1,7 +1,7 @@
# Data Interpreter (DI)
## What is Data Interpreter
-Data Interpreter is an agent who solves problems through codes. It understands user requirements, makes plans, writes codes for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios, please check out the examples below.
+Data Interpreter is an agent who solves data-related problems through codes. It understands user requirements, makes plans, writes codes for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios, please check out the examples below. For overall design and technical details, please see our [paper](https://arxiv.org/abs/2402.18679).
## Example List
- Data visualization
@@ -12,7 +12,9 @@ ## Example List
- Tool usage: web page imitation
- Tool usage: web crawling
- Tool usage: text2image
-- Tool usage: email summarization and response
+- Tool usage: email summarization and response\
- More on the way!
-Please see [here](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html) for detailed explanation.
\ No newline at end of file
+Please see the [docs](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html) for more explanation.
+
+We are continuously releasing codes, stay tuned!
diff --git a/examples/di/arxiv_reader.py b/examples/di/arxiv_reader.py
new file mode 100644
index 000000000..6e1939b81
--- /dev/null
+++ b/examples/di/arxiv_reader.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def main():
+ template = "https://arxiv.org/list/{tag}/pastweek?skip=0&show=300"
+ tags = ["cs.ai", "cs.cl", "cs.lg", "cs.se"]
+ urls = [template.format(tag=tag) for tag in tags]
+ prompt = f"""This is a collection of arxiv urls: '{urls}' .
+Record each article, remove duplicates by title (they may have multiple tags), filter out papers related to
+large language model / agent / llm, print top 100 and visualize the word count of the titles"""
+ di = DataInterpreter(react_mode="react", tools=["scrape_web_playwright"])
+
+ await di.run(prompt)
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ asyncio.run(main())
diff --git a/examples/di/crawl_webpage.py b/examples/di/crawl_webpage.py
index f06b85d9b..b8226f4f4 100644
--- a/examples/di/crawl_webpage.py
+++ b/examples/di/crawl_webpage.py
@@ -7,13 +7,31 @@
from metagpt.roles.di.data_interpreter import DataInterpreter
+PAPER_LIST_REQ = """"
+Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
+and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables*
+"""
+
+ECOMMERCE_REQ = """
+Get products data from website https://scrapeme.live/shop/ and save it as a csv file.
+**Notice: Firstly parse the web page encoding and the text HTML structure;
+The first page product name, price, product URL, and image URL must be saved in the csv;**
+"""
+
+NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**;
+下面是一个大致流程, 你会根据每一步的运行结果对当前计划中的任务做出适当调整:
+1. 爬取并本地保存html结构;
+2. 直接打印第7个*`快讯`*关键词后2000个字符的html内容, 作为*快讯的html内容示例*;
+3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间;
+4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。
+5. 将全部结果存在本地csv中
+"""
+
async def main():
- prompt = """Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
- and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables*"""
- di = DataInterpreter(use_tools=True)
+ di = DataInterpreter(tools=["scrape_web_playwright"])
- await di.run(prompt)
+ await di.run(ECOMMERCE_REQ)
if __name__ == "__main__":
diff --git a/examples/di/custom_tool.py b/examples/di/custom_tool.py
new file mode 100644
index 000000000..cbe7380c7
--- /dev/null
+++ b/examples/di/custom_tool.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/3/22 10:54
+@Author : alexanderwu
+@File : custom_tool.py
+"""
+
+from metagpt.roles.di.data_interpreter import DataInterpreter
+from metagpt.tools.tool_registry import register_tool
+
+
+@register_tool()
+def magic_function(arg1: str, arg2: int) -> dict:
+ """
+ The magic function that does something.
+
+ Args:
+ arg1 (str): ...
+ arg2 (int): ...
+
+ Returns:
+ dict: ...
+ """
+ return {"arg1": arg1 * 3, "arg2": arg2 * 5}
+
+
+async def main():
+ di = DataInterpreter(tools=["magic_function"])
+ await di.run("Just call the magic function with arg1 'A' and arg2 2. Tell me the result.")
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ asyncio.run(main())
diff --git a/examples/di/data_visualization.py b/examples/di/data_visualization.py
index 9af72dc42..184e04f26 100644
--- a/examples/di/data_visualization.py
+++ b/examples/di/data_visualization.py
@@ -1,14 +1,17 @@
import asyncio
+from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
+from metagpt.utils.recovery_util import save_history
async def main(requirement: str = ""):
- di = DataInterpreter(use_tools=False)
- await di.run(requirement)
+ di = DataInterpreter()
+ rsp = await di.run(requirement)
+ logger.info(rsp)
+ save_history(role=di)
if __name__ == "__main__":
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
-
asyncio.run(main(requirement))
diff --git a/examples/di/email_summary.py b/examples/di/email_summary.py
index af081fee2..7c112767c 100644
--- a/examples/di/email_summary.py
+++ b/examples/di/email_summary.py
@@ -22,7 +22,7 @@ async def main():
Firstly, Please help me fetch the latest 5 senders and full letter contents.
Then, summarize each of the 5 emails into one sentence (you can do this by yourself, no need to import other models to do this) and output them in a markdown format."""
- di = DataInterpreter(use_tools=True)
+ di = DataInterpreter()
await di.run(prompt)
diff --git a/examples/di/imitate_webpage.py b/examples/di/imitate_webpage.py
index e2c99b874..60ebab389 100644
--- a/examples/di/imitate_webpage.py
+++ b/examples/di/imitate_webpage.py
@@ -12,10 +12,9 @@ async def main():
web_url = "https://pytorch.org/"
prompt = f"""This is a URL of webpage: '{web_url}' .
Firstly, utilize Selenium and WebDriver for rendering.
-Secondly, convert image to a webpage including HTML, CSS and JS in one go.
-Finally, save webpage in a text file.
+Secondly, convert image to a webpage including HTML, CSS and JS in one go.
Note: All required dependencies and environments have been fully installed and configured."""
- di = DataInterpreter(use_tools=True)
+ di = DataInterpreter(tools=["GPTvGenerator"])
await di.run(prompt)
diff --git a/examples/di/machine_learning.py b/examples/di/machine_learning.py
index a58735831..c674e66e8 100644
--- a/examples/di/machine_learning.py
+++ b/examples/di/machine_learning.py
@@ -2,11 +2,21 @@ import fire
from metagpt.roles.di.data_interpreter import DataInterpreter
+WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
-async def main(auto_run: bool = True):
- requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
- di = DataInterpreter(auto_run=auto_run)
- await di.run(requirement)
+DATA_DIR = "path/to/your/data"
+# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data
+SALES_FORECAST_REQ = f"""Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset, the others is train dataset), include plot total sales trends, print metric and plot scatter plots of
+groud truth and predictions on validation data. Dataset is {DATA_DIR}/train.csv, the metric is weighted mean absolute error (WMAE) for test data. Notice: *print* key variables to get more information for next task step.
+"""
+
+REQUIREMENTS = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ}
+
+
+async def main(use_case: str = "wine"):
+ mi = DataInterpreter()
+ requirement = REQUIREMENTS[use_case]
+ await mi.run(requirement)
if __name__ == "__main__":
diff --git a/examples/di/machine_learning_with_tools.py b/examples/di/machine_learning_with_tools.py
new file mode 100644
index 000000000..291e734c8
--- /dev/null
+++ b/examples/di/machine_learning_with_tools.py
@@ -0,0 +1,16 @@
+import asyncio
+
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def main(requirement: str):
+ role = DataInterpreter(use_reflection=True, tools=[""])
+ await role.run(requirement)
+
+
+if __name__ == "__main__":
+ data_path = "your/path/to/titanic"
+ train_path = f"{data_path}/split_train.csv"
+ eval_path = f"{data_path}/split_eval.csv"
+ requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{train_path}', eval data path: '{eval_path}'."
+ asyncio.run(main(requirement))
diff --git a/examples/di/ml_engineer_with_tools.py b/examples/di/ml_engineer_with_tools.py
deleted file mode 100644
index 6119ad843..000000000
--- a/examples/di/ml_engineer_with_tools.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import asyncio
-
-from metagpt.roles.di.ml_engineer import MLEngineer
-
-
-async def main(requirement: str):
- role = MLEngineer(auto_run=True, use_tools=True)
- await role.run(requirement)
-
-
-if __name__ == "__main__":
- data_path = "your_path_to_icr/icr-identify-age-related-conditions"
- train_path = f"{data_path}/your_train_data.csv"
- eval_path = f"{data_path}/your_eval_data.csv"
- requirement = f"This is a medical dataset with over fifty anonymized health characteristics linked to three age-related conditions. Your goal is to predict whether a subject has or has not been diagnosed with one of these conditions.The target column is Class. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report f1 score on the eval data. Train data path: {train_path}, eval data path:{eval_path}."
- asyncio.run(main(requirement))
diff --git a/examples/di/rm_image_background.py b/examples/di/rm_image_background.py
index b74a79eeb..cb7900a0a 100644
--- a/examples/di/rm_image_background.py
+++ b/examples/di/rm_image_background.py
@@ -4,7 +4,7 @@ from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
- di = DataInterpreter(use_tools=False)
+ di = DataInterpreter()
await di.run(requirement)
diff --git a/examples/di/sd_tool_usage.py b/examples/di/sd_tool_usage.py
index 69c7df5bd..b373a6251 100644
--- a/examples/di/sd_tool_usage.py
+++ b/examples/di/sd_tool_usage.py
@@ -8,7 +8,7 @@ from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
- di = DataInterpreter(use_tools=True, goal=requirement)
+ di = DataInterpreter(tools=["SDEngine"])
await di.run(requirement)
diff --git a/examples/di/solve_math_problems.py b/examples/di/solve_math_problems.py
index 35a69c953..f7fd3d4e3 100644
--- a/examples/di/solve_math_problems.py
+++ b/examples/di/solve_math_problems.py
@@ -4,10 +4,11 @@ from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
- di = DataInterpreter(use_tools=False)
+ di = DataInterpreter()
await di.run(requirement)
if __name__ == "__main__":
requirement = "Solve this math problem: The greatest common divisor of positive integers m and n is 6. The least common multiple of m and n is 126. What is the least possible value of m + n?"
+ # answer: 60 (m = 18, n = 42)
asyncio.run(main(requirement))
diff --git a/examples/reverse_engineering.py b/examples/reverse_engineering.py
new file mode 100644
index 000000000..f80fc09e6
--- /dev/null
+++ b/examples/reverse_engineering.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import asyncio
+import shutil
+from pathlib import Path
+
+import typer
+
+from metagpt.actions.rebuild_class_view import RebuildClassView
+from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
+from metagpt.context import Context
+from metagpt.llm import LLM
+from metagpt.logs import logger
+from metagpt.utils.git_repository import GitRepository
+from metagpt.utils.project_repo import ProjectRepo
+
+app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
+
+
+@app.command("", help="Python project reverse engineering.")
+def startup(
+ project_root: str = typer.Argument(
+ default="",
+ help="Specify the root directory of the existing project for reverse engineering.",
+ ),
+ output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."),
+):
+ package_root = Path(project_root)
+ if not package_root.exists():
+ raise FileNotFoundError(f"{project_root} not exists")
+ if not _is_python_package_root(package_root):
+ raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".')
+ init_file = package_root / "__init__.py" # used by pyreverse
+ init_file_exists = init_file.exists()
+ if not init_file_exists:
+ init_file.touch()
+
+ if not output_dir:
+ output_dir = package_root / "../reverse_engineering_output"
+ logger.info(f"output dir:{output_dir}")
+ try:
+ asyncio.run(reverse_engineering(package_root, Path(output_dir)))
+ finally:
+ if not init_file_exists:
+ init_file.unlink(missing_ok=True)
+ tmp_dir = package_root / "__dot__"
+ if tmp_dir.exists():
+ shutil.rmtree(tmp_dir, ignore_errors=True)
+
+
+def _is_python_package_root(package_root: Path) -> bool:
+ for file_path in package_root.iterdir():
+ if file_path.is_file():
+ if file_path.suffix == ".py":
+ return True
+ return False
+
+
+async def reverse_engineering(package_root: Path, output_dir: Path):
+ ctx = Context()
+ ctx.git_repo = GitRepository(output_dir)
+ ctx.repo = ProjectRepo(ctx.git_repo)
+ action = RebuildClassView(name="ReverseEngineering", i_context=str(package_root), llm=LLM(), context=ctx)
+ await action.run()
+
+ action = RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=ctx)
+ await action.run()
+
+
+if __name__ == "__main__":
+ app()
diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py
index 97b1378ee..276431ed8 100644
--- a/examples/search_with_specific_engine.py
+++ b/examples/search_with_specific_engine.py
@@ -4,21 +4,17 @@
"""
import asyncio
+from metagpt.config2 import Config
from metagpt.roles import Searcher
-from metagpt.tools.search_engine import SearchEngine, SearchEngineType
+from metagpt.tools.search_engine import SearchEngine
async def main():
question = "What are the most interesting human facts?"
- kwargs = {"api_key": "", "cse_id": "", "proxy": None}
- # Serper API
- # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question)
- # SerpAPI
- # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question)
- # Google API
- # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question)
- # DDG API
- await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question)
+
+ search = Config.default().search
+ kwargs = search.model_dump()
+ await Searcher(search_engine=SearchEngine(engine=search.api_type, **kwargs)).run(question)
if __name__ == "__main__":
diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py
index 29af1dad1..495ed4031 100644
--- a/metagpt/actions/__init__.py
+++ b/metagpt/actions/__init__.py
@@ -23,7 +23,7 @@ from metagpt.actions.write_prd import WritePRD
from metagpt.actions.write_prd_review import WritePRDReview
from metagpt.actions.write_test import WriteTest
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
-from metagpt.actions.di.write_analysis_code import WriteCodeWithoutTools, WriteCodeWithTools
+from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.actions.di.write_plan import WritePlan
@@ -46,8 +46,7 @@ class ActionType(Enum):
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
CONDUCT_RESEARCH = ConductResearch
EXECUTE_NB_CODE = ExecuteNbCode
- WRITE_CODE_WITHOUT_TOOLS = WriteCodeWithoutTools
- WRITE_CODE_WITH_TOOLS = WriteCodeWithTools
+ WRITE_ANALYSIS_CODE = WriteAnalysisCode
WRITE_PLAN = WritePlan
diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py
index 09da4a988..31e4cc0fc 100644
--- a/metagpt/actions/action_node.py
+++ b/metagpt/actions/action_node.py
@@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@@ -330,7 +331,7 @@ class ActionNode:
def compile_to(self, i: Dict, schema, kv_sep) -> str:
if schema == "json":
- return json.dumps(i, indent=4)
+ return json.dumps(i, indent=4, ensure_ascii=False)
elif schema == "markdown":
return dict_to_markdown(i, kv_sep=kv_sep)
else:
@@ -339,10 +340,7 @@ class ActionNode:
def tagging(self, text, schema, tag="") -> str:
if not tag:
return text
- if schema == "json":
- return f"[{tag}]\n" + text + f"\n[/{tag}]"
- else: # markdown
- return f"[{tag}]\n" + text + f"\n[/{tag}]"
+ return f"[{tag}]\n{text}\n[/{tag}]"
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
@@ -374,7 +372,7 @@ class ActionNode:
schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action
"""
if schema == "raw":
- return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
+ return f"{context}\n\n## Actions\n{LANGUAGE_CONSTRAINT}\n{self.instruction}"
### 直接使用 pydantic BaseModel 生成 instruction 与 example,仅限 JSON
# child_class = self._create_children_class()
@@ -416,7 +414,7 @@ class ActionNode:
images: Optional[Union[str, list[str]]] = None,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
- timeout=3,
+ timeout=USE_CONFIG_TIMEOUT,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
@@ -448,7 +446,9 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
- async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None):
+ async def simple_fill(
+ self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=None
+ ):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
@@ -473,7 +473,7 @@ class ActionNode:
mode="auto",
strgy="simple",
images: Optional[Union[str, list[str]]] = None,
- timeout=3,
+ timeout=USE_CONFIG_TIMEOUT,
exclude=[],
):
"""Fill the node(s) with mode.
diff --git a/metagpt/actions/di/debug_code.py b/metagpt/actions/di/debug_code.py
deleted file mode 100644
index f589436f5..000000000
--- a/metagpt/actions/di/debug_code.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from __future__ import annotations
-
-from metagpt.actions.di.write_analysis_code import BaseWriteAnalysisCode
-from metagpt.logs import logger
-from metagpt.schema import Message
-from metagpt.utils.common import create_func_call_config
-
-DEBUG_REFLECTION_EXAMPLE = '''
-Example 1:
-[previous impl]:
-```python
-def add(a: int, b: int) -> int:
- """
- Given integers a and b, return the total value of a and b.
- """
- return a - b
-```
-
-[runtime Error]:
-Tested passed:
-
-Tests failed:
-assert add(1, 2) == 3 # output: -1
-assert add(1, 2) == 4 # output: -1
-
-[reflection on previous impl]:
-The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
-
-[improved impl]:
-```python
-def add(a: int, b: int) -> int:
- """
- Given integers a and b, return the total value of a and b.
- """
- return a + b
-```
-'''
-
-REFLECTION_PROMPT = """
-Here is an example for you.
-{debug_example}
-[context]
-{context}
-
-[previous impl]
-{code}
-[runtime Error]
-{runtime_result}
-
-Analysis the error step by step, provide me improve method and code. Remember to follow [context] requirement. Don't forget write code for steps behind the error step.
-[reflection on previous impl]:
-xxx
-"""
-
-CODE_REFLECTION = {
- "name": "execute_reflection_code",
- "description": "Execute reflection code.",
- "parameters": {
- "type": "object",
- "properties": {
- "reflection": {
- "type": "string",
- "description": "Reflection on previous impl.",
- },
- "improved_impl": {
- "type": "string",
- "description": "Refined code after reflection.",
- },
- },
- "required": ["reflection", "improved_impl"],
- },
-}
-
-
-class DebugCode(BaseWriteAnalysisCode):
- async def run(
- self,
- context: list[Message] = None,
- code: str = "",
- runtime_result: str = "",
- ) -> str:
- """
- Execute the debugging process based on the provided context, code, and runtime_result.
-
- Args:
- context (list[Message]): A list of Message objects representing the context.
- code (str): The code to be debugged.
- runtime_result (str): The result of the code execution.
-
- Returns:
- str: The improved implementation based on the debugging process.
- """
-
- info = []
- reflection_prompt = REFLECTION_PROMPT.format(
- debug_example=DEBUG_REFLECTION_EXAMPLE,
- context=context,
- code=code,
- runtime_result=runtime_result,
- )
- system_prompt = "You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation "
- info.append(Message(role="system", content=system_prompt))
- info.append(Message(role="user", content=reflection_prompt))
-
- tool_config = create_func_call_config(CODE_REFLECTION)
- reflection = await self.llm.aask_code(messages=info, **tool_config)
- logger.info(f"reflection is {reflection}")
-
- return {"code": reflection["improved_impl"]}
diff --git a/metagpt/actions/di/execute_nb_code.py b/metagpt/actions/di/execute_nb_code.py
index 0371d556c..0cf16b70f 100644
--- a/metagpt/actions/di/execute_nb_code.py
+++ b/metagpt/actions/di/execute_nb_code.py
@@ -9,7 +9,6 @@ from __future__ import annotations
import asyncio
import base64
import re
-import traceback
from typing import Literal, Tuple
import nbformat
@@ -58,7 +57,23 @@ class ExecuteNbCode(Action):
async def terminate(self):
"""kill NotebookClient"""
- await self.nb_client._async_cleanup_kernel()
+ if self.nb_client.km is not None and await self.nb_client.km.is_alive():
+ await self.nb_client.km.shutdown_kernel(now=True)
+ await self.nb_client.km.cleanup_resources()
+
+ channels = [
+ self.nb_client.kc.stdin_channel, # The channel for handling standard input to the kernel.
+ self.nb_client.kc.hb_channel, # The channel for heartbeat communication between the kernel and client.
+ self.nb_client.kc.control_channel, # The channel for controlling the kernel.
+ ]
+
+ # Stops all the running channels for this kernel
+ for channel in channels:
+ if channel.is_alive():
+ channel.stop()
+
+ self.nb_client.kc = None
+ self.nb_client.km = None
async def reset(self):
"""reset NotebookClient"""
@@ -91,17 +106,17 @@ class ExecuteNbCode(Action):
else:
cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))
- def parse_outputs(self, outputs: list[str]) -> str:
+ def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, str]:
"""Parses the outputs received from notebook execution."""
assert isinstance(outputs, list)
- parsed_output = ""
-
+ parsed_output, is_success = [], True
for i, output in enumerate(outputs):
+ output_text = ""
if output["output_type"] == "stream" and not any(
tag in output["text"]
for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt", "DEBUG"]
):
- parsed_output += output["text"]
+ output_text = output["text"]
elif output["output_type"] == "display_data":
if "image/png" in output["data"]:
self.show_bytes_figure(output["data"]["image/png"], self.interaction)
@@ -110,8 +125,22 @@ class ExecuteNbCode(Action):
f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..."
)
elif output["output_type"] == "execute_result":
- parsed_output += output["data"]["text/plain"]
- return parsed_output
+ output_text = output["data"]["text/plain"]
+ elif output["output_type"] == "error":
+ output_text, is_success = "\n".join(output["traceback"]), False
+
+ # handle coroutines that are not executed asynchronously
+ if output_text.strip().startswith(" Tuple[str, bool]:
"""
@@ -173,14 +202,7 @@ class ExecuteNbCode(Action):
# run code
cell_index = len(self.nb.cells) - 1
- success, error_message = await self.run_cell(self.nb.cells[-1], cell_index)
-
- if not success:
- return truncate(remove_escape_and_color_codes(error_message), is_success=success)
-
- # code success
- outputs = self.parse_outputs(self.nb.cells[-1].outputs)
- outputs, success = truncate(remove_escape_and_color_codes(outputs), is_success=success)
+ success, outputs = await self.run_cell(self.nb.cells[-1], cell_index)
if "!pip" in code:
success = False
@@ -196,54 +218,39 @@ class ExecuteNbCode(Action):
raise ValueError(f"Only support for language: python, markdown, but got {language}, ")
-def truncate(result: str, keep_len: int = 2000, is_success: bool = True):
- """对于超出keep_len个字符的result: 执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。"""
- if is_success:
- desc = f"Executed code successfully. Truncated to show only first {keep_len} characters\n"
- else:
- desc = f"Executed code failed, please reflect the cause of bug and then debug. Truncated to show only last {keep_len} characters\n"
-
- if result.strip().startswith(" keep_len:
- result = result[-keep_len:] if not is_success else result[:keep_len]
- return desc + result, is_success
-
- return result, is_success
-
-
def remove_escape_and_color_codes(input_str: str):
- # 使用正则表达式去除转义字符和颜色代码
+ # 使用正则表达式去除jupyter notebook输出结果中的转义字符和颜色代码
+ # Use regular expressions to get rid of escape characters and color codes in jupyter notebook output.
pattern = re.compile(r"\x1b\[[0-9;]*[mK]")
result = pattern.sub("", input_str)
return result
def display_markdown(content: str):
- # 使用正则表达式逐个匹配代码块
+ # Use regular expressions to match blocks of code one by one.
matches = re.finditer(r"```(.+?)```", content, re.DOTALL)
start_index = 0
content_panels = []
- # 逐个打印匹配到的文本和代码
+ # Set the text background color and text color.
+ style = "black on white"
+ # Print the matching text and code one by one.
for match in matches:
text_content = content[start_index : match.start()].strip()
code_content = match.group(0).strip()[3:-3] # Remove triple backticks
if text_content:
- content_panels.append(Panel(Markdown(text_content), box=MINIMAL))
+ content_panels.append(Panel(Markdown(text_content), style=style, box=MINIMAL))
if code_content:
- content_panels.append(Panel(Markdown(f"```{code_content}"), box=MINIMAL))
+ content_panels.append(Panel(Markdown(f"```{code_content}"), style=style, box=MINIMAL))
start_index = match.end()
- # 打印剩余文本(如果有)
+ # Print remaining text (if any).
remaining_text = content[start_index:].strip()
if remaining_text:
- content_panels.append(Panel(Markdown(remaining_text), box=MINIMAL))
+ content_panels.append(Panel(Markdown(remaining_text), style=style, box=MINIMAL))
- # 在Live模式中显示所有Panel
+ # Display all panels in Live mode.
with Live(auto_refresh=False, console=Console(), vertical_overflow="visible") as live:
live.update(Group(*content_panels))
live.refresh()
diff --git a/metagpt/actions/di/ml_action.py b/metagpt/actions/di/ml_action.py
deleted file mode 100644
index d49b7b67d..000000000
--- a/metagpt/actions/di/ml_action.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from __future__ import annotations
-
-from typing import Tuple
-
-from metagpt.actions import Action
-from metagpt.actions.di.write_analysis_code import WriteCodeWithTools
-from metagpt.prompts.di.ml_action import (
- ML_GENERATE_CODE_PROMPT,
- ML_TOOL_USAGE_PROMPT,
- PRINT_DATA_COLUMNS,
- UPDATE_DATA_COLUMNS,
-)
-from metagpt.prompts.di.write_analysis_code import CODE_GENERATOR_WITH_TOOLS
-from metagpt.schema import Message, Plan
-from metagpt.utils.common import create_func_call_config, remove_comments
-
-
-class WriteCodeWithToolsML(WriteCodeWithTools):
- async def run(
- self,
- context: list[Message],
- plan: Plan = None,
- column_info: str = "",
- **kwargs,
- ) -> Tuple[list[Message], str]:
- # prepare tool schemas and tool-type-specific instruction
- tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
-
- # ML-specific variables to be used in prompt
- finished_tasks = plan.get_finished_tasks()
- code_context = [remove_comments(task.code) for task in finished_tasks]
- code_context = "\n\n".join(code_context)
-
- # prepare prompt depending on tool availability & LLM call
- if tool_schemas:
- prompt = ML_TOOL_USAGE_PROMPT.format(
- user_requirement=plan.goal,
- history_code=code_context,
- current_task=plan.current_task.instruction,
- column_info=column_info,
- tool_type_usage_prompt=tool_type_usage_prompt,
- tool_schemas=tool_schemas,
- )
-
- else:
- prompt = ML_GENERATE_CODE_PROMPT.format(
- user_requirement=plan.goal,
- history_code=code_context,
- current_task=plan.current_task.instruction,
- column_info=column_info,
- tool_type_usage_prompt=tool_type_usage_prompt,
- )
- tool_config = create_func_call_config(CODE_GENERATOR_WITH_TOOLS)
- rsp = await self.llm.aask_code(prompt, **tool_config)
-
- # Extra output to be used for potential debugging
- context = [Message(content=prompt, role="user")]
-
- return context, rsp
-
-
-class UpdateDataColumns(Action):
- async def run(self, plan: Plan = None) -> dict:
- finished_tasks = plan.get_finished_tasks()
- code_context = [remove_comments(task.code) for task in finished_tasks]
- code_context = "\n\n".join(code_context)
- prompt = UPDATE_DATA_COLUMNS.format(history_code=code_context)
- tool_config = create_func_call_config(PRINT_DATA_COLUMNS)
- rsp = await self.llm.aask_code(prompt, **tool_config)
- return rsp
diff --git a/metagpt/actions/di/write_analysis_code.py b/metagpt/actions/di/write_analysis_code.py
index 0c4980c2b..711e56d39 100644
--- a/metagpt/actions/di/write_analysis_code.py
+++ b/metagpt/actions/di/write_analysis_code.py
@@ -6,150 +6,68 @@
"""
from __future__ import annotations
-from typing import Tuple
+import json
from metagpt.actions import Action
-from metagpt.logs import logger
from metagpt.prompts.di.write_analysis_code import (
- CODE_GENERATOR_WITH_TOOLS,
- SELECT_FUNCTION_TOOLS,
- TOOL_RECOMMENDATION_PROMPT,
- TOOL_USAGE_PROMPT,
+ CHECK_DATA_PROMPT,
+ DEBUG_REFLECTION_EXAMPLE,
+ INTERPRETER_SYSTEM_MSG,
+ REFLECTION_PROMPT,
+ REFLECTION_SYSTEM_MSG,
+ STRUCTUAL_PROMPT,
)
-from metagpt.schema import Message, Plan, SystemMessage
-from metagpt.tools import TOOL_REGISTRY
-from metagpt.tools.tool_registry import validate_tool_names
-from metagpt.utils.common import create_func_call_config
+from metagpt.schema import Message, Plan
+from metagpt.utils.common import CodeParser, remove_comments
-class BaseWriteAnalysisCode(Action):
- DEFAULT_SYSTEM_MSG: str = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step. Must reuse variables in the lastest other code directly, dont creat it again, it is very import for you. Use !pip install in a standalone block to install missing packages.Usually the libraries you need are already installed.Dont check if packages already imported.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
- # REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous tasks in your current code block, include new codes only, DONT repeat codes!"""
-
- def insert_system_message(self, context: list[Message], system_msg: str = None):
- system_msg = system_msg or self.DEFAULT_SYSTEM_MSG
- context.insert(0, SystemMessage(content=system_msg)) if context[0].role != "system" else None
- return context
-
- async def run(self, context: list[Message], plan: Plan = None) -> dict:
- """Run of a code writing action, used in data analysis or modeling
-
- Args:
- context (list[Message]): Action output history, source action denoted by Message.cause_by
- plan (Plan, optional): Overall plan. Defaults to None.
-
- Returns:
- dict: code result in the format of {"code": "print('hello world')", "language": "python"}
- """
- raise NotImplementedError
-
-
-class WriteCodeWithoutTools(BaseWriteAnalysisCode):
- """Ask LLM to generate codes purely by itself without local user-defined tools"""
-
- async def run(self, context: list[Message], plan: Plan = None, system_msg: str = None, **kwargs) -> dict:
- messages = self.insert_system_message(context, system_msg)
- rsp = await self.llm.aask_code(messages, **kwargs)
- return rsp
-
-
-class WriteCodeWithTools(BaseWriteAnalysisCode):
- """Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
-
- # selected tools to choose from, listed by their names. An empty list means selection from all tools.
- selected_tools: list[str] = []
-
- def _get_tools_by_type(self, tool_type: str) -> dict:
- """
- Retreive tools by tool type from registry, but filtered by pre-selected tool list
-
- Args:
- tool_type (str): Tool type to retrieve from the registry
-
- Returns:
- dict: A dict of tool name to Tool object, representing available tools under the type
- """
- candidate_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
- if self.selected_tools:
- candidate_tool_names = set(self.selected_tools) & candidate_tools.keys()
- candidate_tools = {tool_name: candidate_tools[tool_name] for tool_name in candidate_tool_names}
- return candidate_tools
-
- async def _recommend_tool(
- self,
- task: str,
- available_tools: dict,
- ) -> dict:
- """
- Recommend tools for the specified task.
-
- Args:
- task (str): the task to recommend tools for
- available_tools (dict): the available tools description
-
- Returns:
- dict: schemas of recommended tools for the specified task
- """
- prompt = TOOL_RECOMMENDATION_PROMPT.format(
- current_task=task,
- available_tools=available_tools,
- )
- tool_config = create_func_call_config(SELECT_FUNCTION_TOOLS)
- rsp = await self.llm.aask_code(prompt, **tool_config)
- recommend_tools = rsp["recommend_tools"]
- logger.info(f"Recommended tools: \n{recommend_tools}")
-
- # Parses and validates the recommended tools, for LLM might hallucinate and recommend non-existing tools
- valid_tools = validate_tool_names(recommend_tools, return_tool_object=True)
-
- tool_schemas = {tool.name: tool.schemas for tool in valid_tools}
-
- return tool_schemas
-
- async def _prepare_tools(self, plan: Plan) -> Tuple[dict, str]:
- """Prepare tool schemas and usage instructions according to current task
-
- Args:
- plan (Plan): The overall plan containing task information.
-
- Returns:
- Tuple[dict, str]: A tool schemas ({tool_name: tool_schema_dict}) and a usage prompt for the type of tools selected
- """
- # find tool type from task type through exact match, can extend to retrieval in the future
- tool_type = plan.current_task.task_type
-
- # prepare tool-type-specific instruction
- tool_type_usage_prompt = (
- TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
+class WriteAnalysisCode(Action):
+ async def _debug_with_reflection(self, context: list[Message], working_memory: list[Message]):
+ reflection_prompt = REFLECTION_PROMPT.format(
+ debug_example=DEBUG_REFLECTION_EXAMPLE,
+ context=context,
+ previous_impl=working_memory,
)
- # prepare schemas of available tools
- tool_schemas = {}
- available_tools = self._get_tools_by_type(tool_type)
- if available_tools:
- available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
- tool_schemas = await self._recommend_tool(plan.current_task.instruction, available_tools)
+ rsp = await self._aask(reflection_prompt, system_msgs=[REFLECTION_SYSTEM_MSG])
+ reflection = json.loads(CodeParser.parse_code(block=None, text=rsp))
- return tool_schemas, tool_type_usage_prompt
+ return reflection["improved_impl"]
async def run(
self,
- context: list[Message],
- plan: Plan,
+ user_requirement: str,
+ plan_status: str = "",
+ tool_info: str = "",
+ working_memory: list[Message] = None,
+ use_reflection: bool = False,
**kwargs,
) -> str:
- # prepare tool schemas and tool-type-specific instruction
- tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
-
- # form a complete tool usage instruction and include it as a message in context
- tools_instruction = TOOL_USAGE_PROMPT.format(
- tool_schemas=tool_schemas, tool_type_usage_prompt=tool_type_usage_prompt
+ structual_prompt = STRUCTUAL_PROMPT.format(
+ user_requirement=user_requirement,
+ plan_status=plan_status,
+ tool_info=tool_info,
)
- context.append(Message(content=tools_instruction, role="user"))
- # prepare prompt & LLM call
- prompt = self.insert_system_message(context)
- tool_config = create_func_call_config(CODE_GENERATOR_WITH_TOOLS)
- rsp = await self.llm.aask_code(prompt, **tool_config)
+ working_memory = working_memory or []
+ context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory)
- return rsp
+ # LLM call
+ if use_reflection:
+ code = await self._debug_with_reflection(context=context, working_memory=working_memory)
+ else:
+ rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs)
+ code = CodeParser.parse_code(block=None, text=rsp)
+
+ return code
+
+
+class CheckData(Action):
+ async def run(self, plan: Plan) -> dict:
+ finished_tasks = plan.get_finished_tasks()
+ code_written = [remove_comments(task.code) for task in finished_tasks]
+ code_written = "\n\n".join(code_written)
+ prompt = CHECK_DATA_PROMPT.format(code_written=code_written)
+ rsp = await self._aask(prompt)
+ code = CodeParser.parse_code(block=None, text=rsp)
+ return code
diff --git a/metagpt/actions/di/write_plan.py b/metagpt/actions/di/write_plan.py
index 518dfb6c6..2dbe3f0e7 100644
--- a/metagpt/actions/di/write_plan.py
+++ b/metagpt/actions/di/write_plan.py
@@ -12,81 +12,49 @@ from typing import Tuple
from metagpt.actions import Action
from metagpt.logs import logger
-from metagpt.prompts.di.write_analysis_code import (
- ASSIGN_TASK_TYPE_CONFIG,
- ASSIGN_TASK_TYPE_PROMPT,
-)
from metagpt.schema import Message, Plan, Task
-from metagpt.tools import TOOL_REGISTRY
-from metagpt.utils.common import CodeParser, create_func_call_config
+from metagpt.strategy.task_type import TaskType
+from metagpt.utils.common import CodeParser
class WritePlan(Action):
PROMPT_TEMPLATE: str = """
# Context:
- __context__
+ {context}
+ # Available Task Types:
+ {task_type_desc}
# Task:
- Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks.
+ Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks.
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan.
If you encounter errors on the current task, revise and output the current single task only.
Output a list of jsons following the format:
```json
[
- {
+ {{
"task_id": str = "unique identifier for a task in plan, can be an ordinal",
"dependent_task_ids": list[str] = "ids of tasks prerequisite to this task",
"instruction": "what you should do in this task, one short phrase or sentence",
- },
+ "task_type": "type of this task, should be one of Available Task Types",
+ }},
...
]
```
"""
- async def assign_task_type(self, tasks: list[dict]) -> str:
- """Assign task type to each task in tasks
-
- Args:
- tasks (list[dict]): tasks to be assigned task type
-
- Returns:
- str: tasks with task type assigned in a json string
- """
- task_info = "\n".join([f"Task {task['task_id']}: {task['instruction']}" for task in tasks])
- task_type_desc = "\n".join(
- [f"- **{tool_type.name}**: {tool_type.desc}" for tool_type in TOOL_REGISTRY.get_tool_types().values()]
- ) # task type are binded with tool type now, should be improved in the future
- prompt = ASSIGN_TASK_TYPE_PROMPT.format(
- task_info=task_info, task_type_desc=task_type_desc
- ) # task types are set to be the same as tool types, for now
- tool_config = create_func_call_config(ASSIGN_TASK_TYPE_CONFIG)
- rsp = await self.llm.aask_code(prompt, **tool_config)
- task_type_list = rsp["task_type"]
- logger.info(f"assigned task types: {task_type_list}")
- for task, task_type in zip(tasks, task_type_list):
- task["task_type"] = task_type
- return json.dumps(tasks)
-
- async def run(self, context: list[Message], max_tasks: int = 5, use_tools: bool = False) -> str:
- prompt = (
- self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context]))
- # .replace("__current_plan__", current_plan)
- .replace("__max_tasks__", str(max_tasks))
+ async def run(self, context: list[Message], max_tasks: int = 5) -> str:
+ task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
+ prompt = self.PROMPT_TEMPLATE.format(
+ context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc
)
rsp = await self._aask(prompt)
rsp = CodeParser.parse_code(block=None, text=rsp)
- if use_tools:
- rsp = await self.assign_task_type(json.loads(rsp))
return rsp
-def rsp_to_tasks(rsp: str) -> list[Task]:
+def update_plan_from_rsp(rsp: str, current_plan: Plan):
rsp = json.loads(rsp)
tasks = [Task(**task_config) for task_config in rsp]
- return tasks
-
-def update_plan_from_rsp(rsp: str, current_plan: Plan):
- tasks = rsp_to_tasks(rsp)
if len(tasks) == 1 or tasks[0].dependent_task_ids:
if tasks[0].dependent_task_ids and len(tasks) > 1:
# tasks[0].dependent_task_ids means the generated tasks are not a complete plan
diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py
index 6dd5690b6..ff030ec87 100644
--- a/metagpt/actions/rebuild_class_view.py
+++ b/metagpt/actions/rebuild_class_view.py
@@ -76,7 +76,7 @@ class RebuildClassView(Action):
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
- filename = str(pathname.with_suffix(".mmd"))
+ filename = str(pathname.with_suffix(".class_diagram.mmd"))
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py
index 227d29872..0e67de908 100644
--- a/metagpt/actions/rebuild_sequence_view.py
+++ b/metagpt/actions/rebuild_sequence_view.py
@@ -12,7 +12,7 @@ from __future__ import annotations
import re
from datetime import datetime
from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Set
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -125,7 +125,7 @@ class RebuildSequenceView(Action):
if prefix in r.subject:
classes.append(r)
await self._rebuild_use_case(r.subject)
- participants = set()
+ participants = await self._search_participants(split_namespace(entry.subject)[0])
class_details = []
class_views = []
for c in classes:
@@ -171,7 +171,8 @@ class RebuildSequenceView(Action):
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
- await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
+ if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
+ await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
@@ -184,7 +185,7 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
- await self.graph_db.save()
+ await self._save_sequence_view(subject=entry.subject, content=sequence_view)
async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
@@ -267,38 +268,6 @@ class RebuildSequenceView(Action):
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
- # class _UseCase(BaseModel):
- # description: str = Field(default="...", description="Describes about what the use case to do")
- # inputs: List[str] = Field(default=["input name 1", "input name 2"],
- # description="Lists the input names of the use case from external sources")
- # outputs: List[str] = Field(default=["output name 1", "output name 2"],
- # description="Lists the output names of the use case to external sources")
- # actors: List[str] = Field(default=["actor name 1", "actor name 2"],
- # description="Lists the participant actors of the use case")
- # steps: List[str] = Field(default=["Step 1", "Step 2"],
- # description="Lists the steps about how the use case works step by step")
- # reason: str = Field(default="Because ...",
- # description="Explaining under what circumstances would the external system execute this use case.")
- #
- #
- # class _UseCaseList(BaseModel):
- # description: str = Field(default="...",
- # description="A summary explains what the whole source code want to do")
- # use_cases: List[_UseCase] = Field(default=[
- # {
- # "description": "Describes about what the use case to do",
- # "inputs": ["input name 1", "input name 2"],
- # "outputs": ["output name 1", "output name 2"],
- # "actors": ["actor name 1", "actor name 2"],
- # "steps": ["Step 1", "Step 2"],
- # "reason": "Because ..."
- # }
- # ], description="List all use cases.")
- # relationship: List[str] = Field(default=["use case 1 ..."],
- # description="Lists all the descriptions of relationship among these use cases")
-
- # rsp = await ActionNode.from_pydantic(_UseCaseList).fill(context=prompt, llm=self.llm)
-
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
@@ -327,7 +296,6 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
- await self.graph_db.save()
@retry(
wait=wait_random_exponential(min=1, max=20),
@@ -347,7 +315,6 @@ class RebuildSequenceView(Action):
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
- await self.graph_db.save()
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
@@ -382,7 +349,6 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
- await self.graph_db.save()
async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
@@ -574,14 +540,12 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
- await self.graph_db.save()
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
- await self.graph_db.save()
return
participant = participants[0]
@@ -619,4 +583,31 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
- await self.graph_db.save()
+ await self._save_sequence_view(subject=entry.subject, content=sequence_view)
+
+ async def _save_sequence_view(self, subject: str, content: str):
+ pattern = re.compile(r"[^a-zA-Z0-9]")
+ name = re.sub(pattern, "_", subject)
+ filename = Path(name).with_suffix(".sequence_diagram.mmd")
+ await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)
+
+ async def _search_participants(self, filename: str) -> Set:
+ content = await self._get_source_code(filename)
+
+ rsp = await self.llm.aask(
+ msg=content,
+ system_msgs=[
+ "You are a tool for listing all class names used in a source file.",
+ "Return a markdown JSON object with: "
+ '- a "class_names" key containing the list of class names used in the file; '
+ '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
+ ],
+ )
+
+ class _Data(BaseModel):
+ class_names: List[str]
+ reasons: List
+
+ json_blocks = parse_json_code_block(rsp)
+ data = _Data.model_validate_json(json_blocks[0])
+ return set(data.class_names)
diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py
index ce8d8a967..2a99a8d99 100644
--- a/metagpt/actions/research.py
+++ b/metagpt/actions/research.py
@@ -134,7 +134,7 @@ class CollectLinks(Action):
break
model_name = config.llm.model
- prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096)
+ prompt = reduce_message_length(gen_msg(), model_name, system_text, config.llm.max_token)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
diff --git a/metagpt/config2.py b/metagpt/config2.py
index bc6af18c6..f3273419f 100644
--- a/metagpt/config2.py
+++ b/metagpt/config2.py
@@ -92,7 +92,7 @@ class Config(CLIParams, YamlModel):
"""
default_config_paths: List[Path] = [
METAGPT_ROOT / "config/config2.yaml",
- Path.home() / ".metagpt/config2.yaml",
+ CONFIG_ROOT / "config2.yaml",
]
dicts = [dict(os.environ)]
@@ -100,6 +100,20 @@ class Config(CLIParams, YamlModel):
final = merge_dict(dicts)
return Config(**final)
+ @classmethod
+ def from_llm_config(cls, llm_config: dict):
+ """user config llm
+ example:
+ llm_config = {"api_type": "xxx", "api_key": "xxx", "model": "xxx"}
+ gpt4 = Config.from_llm_config(llm_config)
+ A = Role(name="A", profile="Democratic candidate", goal="Win the election", actions=[a1], watch=[a2], config=gpt4)
+ """
+ llm_config = LLMConfig.model_validate(llm_config)
+ dicts = [dict(os.environ)]
+ dicts += [{"llm": llm_config}]
+ final = merge_dict(dicts)
+ return Config(**final)
+
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 77bfc8466..af8f56372 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -10,6 +10,7 @@ from typing import Optional
from pydantic import field_validator
+from metagpt.const import LLM_API_TIMEOUT
from metagpt.utils.yaml_model import YamlModel
@@ -29,6 +30,7 @@ class LLMType(Enum):
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
MOONSHOT = "moonshot"
MISTRAL = "mistral"
+ YI = "yi" # lingyiwanwu
def __missing__(self, key):
return self.OPENAI
@@ -73,7 +75,7 @@ class LLMConfig(YamlModel):
stream: bool = False
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
top_logprobs: Optional[int] = None
- timeout: int = 60
+ timeout: int = 600
# For Network
proxy: Optional[str] = None
@@ -87,3 +89,8 @@ class LLMConfig(YamlModel):
if v in ["", None, "YOUR_API_KEY"]:
raise ValueError("Please set your API key in config2.yaml")
return v
+
+ @field_validator("timeout")
+ @classmethod
+ def check_timeout(cls, v):
+ return v or LLM_API_TIMEOUT
diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py
index af928b02a..e28b14c99 100644
--- a/metagpt/configs/search_config.py
+++ b/metagpt/configs/search_config.py
@@ -7,6 +7,8 @@
"""
from typing import Callable, Optional
+from pydantic import Field
+
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@@ -18,3 +20,11 @@ class SearchConfig(YamlModel):
api_key: str = ""
cse_id: str = "" # for google
search_func: Optional[Callable] = None
+ params: dict = Field(
+ default_factory=lambda: {
+ "engine": "google",
+ "google_domain": "google.com",
+ "gl": "us",
+ "hl": "en",
+ }
+ )
diff --git a/metagpt/const.py b/metagpt/const.py
index 6dbbfe0c1..e4cebfd96 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -123,7 +123,6 @@ BASE64_FORMAT = "base64"
# REDIS
REDIS_KEY = "REDIS_KEY"
-LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"
@@ -132,3 +131,7 @@ IGNORED_MESSAGE_ID = "0"
GENERALIZATION = "Generalize"
COMPOSITION = "Composite"
AGGREGATION = "Aggregate"
+
+# Timeout
+USE_CONFIG_TIMEOUT = 0 # Using llm.timeout configuration.
+LLM_API_TIMEOUT = 300
diff --git a/metagpt/environment/README.md b/metagpt/environment/README.md
index 9476ac75a..ff56b73f1 100644
--- a/metagpt/environment/README.md
+++ b/metagpt/environment/README.md
@@ -34,5 +34,5 @@ # do a `tap` action on the screen
## TODO
- add android app operation assistant under `examples/android_assistant`
- migrate roles/actions of werewolf game from old version into current version
-- migrate roles/actions of mincraft game from old version into current version
+- migrate roles/actions of minecraft game from old version into current version
- migrate roles/actions of stanford_town game from old version into current version
diff --git a/metagpt/environment/__init__.py b/metagpt/environment/__init__.py
index 692672fa7..28981f2f8 100644
--- a/metagpt/environment/__init__.py
+++ b/metagpt/environment/__init__.py
@@ -4,10 +4,9 @@
from metagpt.environment.base_env import Environment
from metagpt.environment.android_env.android_env import AndroidEnv
-from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv
from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv
from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv
from metagpt.environment.software_env.software_env import SoftwareEnv
-__all__ = ["AndroidEnv", "MincraftExtEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]
+__all__ = ["AndroidEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]
diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py
index 0e583ffb3..942bf2409 100644
--- a/metagpt/environment/base_env.py
+++ b/metagpt/environment/base_env.py
@@ -26,7 +26,7 @@ class EnvType(Enum):
ANDROID = "Android"
GYM = "Gym"
WEREWOLF = "Werewolf"
- MINCRAFT = "Mincraft"
+ MINECRAFT = "Minecraft"
STANFORDTOWN = "StanfordTown"
@@ -47,7 +47,7 @@ def mark_as_writeable(func):
class ExtEnv(BaseModel):
- """External Env to intergate actual game environment"""
+ """External Env to integrate actual game environment"""
def _check_api_exist(self, rw_api: Optional[str] = None):
if not rw_api:
@@ -129,8 +129,8 @@ class Environment(ExtEnv):
self.roles[role.profile] = role
for role in roles: # setup system message with roles
- role.set_env(self)
role.context = self.context
+ role.set_env(self)
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""
diff --git a/metagpt/environment/mincraft_env/__init__.py b/metagpt/environment/minecraft_env/__init__.py
similarity index 100%
rename from metagpt/environment/mincraft_env/__init__.py
rename to metagpt/environment/minecraft_env/__init__.py
diff --git a/metagpt/environment/mincraft_env/const.py b/metagpt/environment/minecraft_env/const.py
similarity index 92%
rename from metagpt/environment/mincraft_env/const.py
rename to metagpt/environment/minecraft_env/const.py
index a7222f9cd..8ac15decc 100644
--- a/metagpt/environment/mincraft_env/const.py
+++ b/metagpt/environment/minecraft_env/const.py
@@ -4,8 +4,8 @@
from metagpt.const import METAGPT_ROOT
-# For Mincraft Game Agent
-MC_CKPT_DIR = METAGPT_ROOT / "data/mincraft/ckpt"
+# For Minecraft Game Agent
+MC_CKPT_DIR = METAGPT_ROOT / "data/minecraft/ckpt"
MC_LOG_DIR = METAGPT_ROOT / "logs"
MC_DEFAULT_WARMUP = {
"context": 15,
diff --git a/metagpt/environment/mincraft_env/mincraft_env.py b/metagpt/environment/minecraft_env/minecraft_env.py
similarity index 98%
rename from metagpt/environment/mincraft_env/mincraft_env.py
rename to metagpt/environment/minecraft_env/minecraft_env.py
index 6aaa34187..26d4d03a8 100644
--- a/metagpt/environment/mincraft_env/mincraft_env.py
+++ b/metagpt/environment/minecraft_env/minecraft_env.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : MG Mincraft Env
+# @Desc : MG Minecraft Env
# refs to `voyager voyager.py`
import json
@@ -12,15 +12,15 @@ from pydantic import ConfigDict, Field
from metagpt.config2 import config as CONFIG
from metagpt.environment.base_env import Environment
-from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
-from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
+from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
+from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
from metagpt.logs import logger
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
-class MincraftEnv(Environment, MincraftExtEnv):
- """MincraftEnv, including shared memory of cache and infomation between roles"""
+class MinecraftEnv(Environment, MinecraftExtEnv):
+ """MinecraftEnv, including shared memory of cache and information between roles"""
model_config = ConfigDict(arbitrary_types_allowed=True)
diff --git a/metagpt/environment/mincraft_env/mincraft_ext_env.py b/metagpt/environment/minecraft_env/minecraft_ext_env.py
similarity index 95%
rename from metagpt/environment/mincraft_env/mincraft_ext_env.py
rename to metagpt/environment/minecraft_env/minecraft_ext_env.py
index b86250d8c..3b793079f 100644
--- a/metagpt/environment/mincraft_env/mincraft_ext_env.py
+++ b/metagpt/environment/minecraft_env/minecraft_ext_env.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : The Mincraft external environment to integrate with Mincraft game
+# @Desc : The Minecraft external environment to integrate with Minecraft game
# refs to `voyager bridge.py`
import json
@@ -11,18 +11,18 @@ import requests
from pydantic import ConfigDict, Field, model_validator
from metagpt.environment.base_env import ExtEnv, mark_as_writeable
-from metagpt.environment.mincraft_env.const import (
+from metagpt.environment.minecraft_env.const import (
MC_CKPT_DIR,
MC_CORE_INVENTORY_ITEMS,
MC_CURRICULUM_OB,
MC_DEFAULT_WARMUP,
METAGPT_ROOT,
)
-from metagpt.environment.mincraft_env.process_monitor import SubprocessMonitor
+from metagpt.environment.minecraft_env.process_monitor import SubprocessMonitor
from metagpt.logs import logger
-class MincraftExtEnv(ExtEnv):
+class MinecraftExtEnv(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
mc_port: Optional[int] = Field(default=None)
@@ -48,7 +48,7 @@ class MincraftExtEnv(ExtEnv):
self.mineflayer = SubprocessMonitor(
commands=[
"node",
- METAGPT_ROOT.joinpath("metagpt", "environment", "mincraft_env", "mineflayer", "index.js"),
+ METAGPT_ROOT.joinpath("metagpt", "environment", "minecraft_env", "mineflayer", "index.js"),
str(self.server_port),
],
name="mineflayer",
diff --git a/metagpt/environment/mincraft_env/mineflayer/.gitignore b/metagpt/environment/minecraft_env/mineflayer/.gitignore
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/.gitignore
rename to metagpt/environment/minecraft_env/mineflayer/.gitignore
diff --git a/metagpt/environment/mincraft_env/mineflayer/.prettierignore b/metagpt/environment/minecraft_env/mineflayer/.prettierignore
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/.prettierignore
rename to metagpt/environment/minecraft_env/mineflayer/.prettierignore
diff --git a/metagpt/environment/mincraft_env/mineflayer/.prettierrc.json b/metagpt/environment/minecraft_env/mineflayer/.prettierrc.json
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/.prettierrc.json
rename to metagpt/environment/minecraft_env/mineflayer/.prettierrc.json
diff --git a/metagpt/environment/mincraft_env/mineflayer/index.js b/metagpt/environment/minecraft_env/mineflayer/index.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/index.js
rename to metagpt/environment/minecraft_env/mineflayer/index.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/base.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/base.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/base.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/base.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/chests.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/chests.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/chests.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/chests.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/inventory.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/inventory.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/inventory.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/inventory.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/onChat.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/onChat.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/onChat.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/onChat.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/onError.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/onError.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/onError.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/onError.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/onSave.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/onSave.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/onSave.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/onSave.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/status.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/status.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/status.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/status.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/voxels.js b/metagpt/environment/minecraft_env/mineflayer/lib/observation/voxels.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/observation/voxels.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/observation/voxels.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/skillLoader.js b/metagpt/environment/minecraft_env/mineflayer/lib/skillLoader.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/skillLoader.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/skillLoader.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/utils.js b/metagpt/environment/minecraft_env/mineflayer/lib/utils.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/lib/utils.js
rename to metagpt/environment/minecraft_env/mineflayer/lib/utils.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/.gitignore b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/.gitignore
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/.gitignore
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/.gitignore
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/LICENSE b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/LICENSE
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/LICENSE
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/LICENSE
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/README.md b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/README.md
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/README.md
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/README.md
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/_config.yml b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/_config.yml
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/_config.yml
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/_config.yml
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/docs/api.md b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/docs/api.md
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/docs/api.md
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/docs/api.md
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/collector.js b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/examples/collector.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/collector.js
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/examples/collector.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/package.json b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/package.json
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/package.json
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/package.json
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Util.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/Util.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Util.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/Util.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/index.ts b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/index.ts
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/index.ts
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/src/index.ts
diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/tsconfig.json b/metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/tsconfig.json
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/tsconfig.json
rename to metagpt/environment/minecraft_env/mineflayer/mineflayer-collectblock/tsconfig.json
diff --git a/metagpt/environment/mincraft_env/mineflayer/package.json b/metagpt/environment/minecraft_env/mineflayer/package.json
similarity index 100%
rename from metagpt/environment/mincraft_env/mineflayer/package.json
rename to metagpt/environment/minecraft_env/mineflayer/package.json
diff --git a/metagpt/environment/mincraft_env/process_monitor.py b/metagpt/environment/minecraft_env/process_monitor.py
similarity index 100%
rename from metagpt/environment/mincraft_env/process_monitor.py
rename to metagpt/environment/minecraft_env/process_monitor.py
diff --git a/metagpt/learn/skill_loader.py b/metagpt/learn/skill_loader.py
index bcf28bb87..e98f73cf9 100644
--- a/metagpt/learn/skill_loader.py
+++ b/metagpt/learn/skill_loader.py
@@ -9,11 +9,11 @@
from pathlib import Path
from typing import Dict, List, Optional
-import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.context import Context
+from metagpt.utils.common import aread
class Example(BaseModel):
@@ -68,8 +68,7 @@ class SkillsDeclaration(BaseModel):
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
if not skill_yaml_file_name:
skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml"
- async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
- data = await reader.read(-1)
+ data = await aread(filename=skill_yaml_file_name)
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)
diff --git a/metagpt/prompts/di/ml_action.py b/metagpt/prompts/di/ml_action.py
deleted file mode 100644
index 5d27c7ff0..000000000
--- a/metagpt/prompts/di/ml_action.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Time : 2023/11/24 15:43
-# @Author : lidanyang
-# @File : ml_action
-# @Desc :
-UPDATE_DATA_COLUMNS = """
-# Background
-Keep dataset column information updated before model train.
-## Done Tasks
-```python
-{history_code}
-```end
-
-# Task
-Update and print the dataset's column information only if the train or test data has changed. Use the following code:
-```python
-from metagpt.tools.libs.data_preprocess import get_column_info
-
-column_info = get_column_info(df)
-print("column_info")
-print(column_info)
-```end
-
-# Constraints:
-- Use the DataFrame variable from 'Done Tasks' in place of df.
-- Import `get_column_info` only if it's not already imported.
-"""
-
-PRINT_DATA_COLUMNS = {
- "name": "print_column_info",
- "description": "Print the latest column information after 'Done Tasks' code if first read or data changed.",
- "parameters": {
- "type": "object",
- "properties": {
- "code": {
- "type": "string",
- "description": "The code to be added to a new cell in jupyter.",
- },
- },
- "required": ["code"],
- },
-}
-
-ML_COMMON_PROMPT = """
-# Background
-As a data scientist, you need to help user to achieve their goal [{user_requirement}] step-by-step in an continuous Jupyter notebook.
-
-## Done Tasks
-```python
-{history_code}
-```end
-
-## Current Task
-{current_task}
-
-# Latest Data Info
-Latest data info after previous tasks:
-{column_info}
-
-# Task
-Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc.
-Specifically, {tool_type_usage_prompt}
-"""
-
-USE_NO_TOOLS_EXAMPLE = """
-# Output Example:
-when current task is "train a lightgbm model on training data", the code can be like:
-```python
-# Step 1: check data type and convert to numeric
-obj_cols = train.select_dtypes(include='object').columns.tolist()
-
-for col in obj_cols:
- encoder = LabelEncoder()
- train[col] = encoder.fit_transform(train[col].unique().tolist() + ['unknown'])
- test[col] = test[col].apply(lambda x: x if x in encoder.classes_ else 'unknown')
- test[col] = encoder.transform(test[col])
-
-# Step 2: train lightgbm model
-model = LGBMClassifier()
-model.fit(train, y_train)
-```end
-
-# Constraints:
-- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
-"""
-
-USE_TOOLS_EXAMPLE = """
-# Capabilities
-- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
-- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
-
-# Available Tools:
-Each Class tool is described in JSON format. When you call a tool, import the tool from its path first.
-{tool_schemas}
-
-# Output Example:
-when current task is "do data preprocess, like fill missing value, handle outliers, etc.", the code can be like:
-```python
-# Step 1: fill missing value
-# Tools used: ['FillMissingValue']
-from metagpt.tools.libs.data_preprocess import FillMissingValue
-
-train_processed = train.copy()
-test_processed = test.copy()
-num_cols = train_processed.select_dtypes(include='number').columns.tolist()
-if 'label' in num_cols:
- num_cols.remove('label')
-fill_missing_value = FillMissingValue(features=num_cols, strategy='mean')
-fill_missing_value.fit(train_processed)
-train_processed = fill_missing_value.transform(train_processed)
-test_processed = fill_missing_value.transform(test_processed)
-
-# Step 2: handle outliers
-for col in num_cols:
- low, high = train_processed[col].quantile([0.01, 0.99])
- train_processed[col] = train_processed[col].clip(low, high)
- test_processed[col] = test_processed[col].clip(low, high)
-```end
-
-# Constraints:
-- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
-- Always prioritize using pre-defined tools for the same functionality.
-- Always copy the DataFrame before processing it and use the copy to process.
-"""
-
-ML_GENERATE_CODE_PROMPT = ML_COMMON_PROMPT + USE_NO_TOOLS_EXAMPLE
-ML_TOOL_USAGE_PROMPT = ML_COMMON_PROMPT + USE_TOOLS_EXAMPLE
diff --git a/metagpt/prompts/di/write_analysis_code.py b/metagpt/prompts/di/write_analysis_code.py
index 4eccefcd1..e5663d498 100644
--- a/metagpt/prompts/di/write_analysis_code.py
+++ b/metagpt/prompts/di/write_analysis_code.py
@@ -1,93 +1,112 @@
-ASSIGN_TASK_TYPE_PROMPT = """
-Please assign a task type to each task in the list below from the given categories:
-{task_info}
+INTERPRETER_SYSTEM_MSG = """As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function."""
-## All Task Type:
-{task_type_desc}
+STRUCTUAL_PROMPT = """
+# User Requirement
+{user_requirement}
+
+# Plan Status
+{plan_status}
+
+# Tool Info
+{tool_info}
+
+# Constraints
+- Take on Current Task if it is in Plan Status, otherwise, tackle User Requirement directly.
+- Ensure the output new code is executable in the same Jupyter notebook as the previous executed code.
+- Always prioritize using pre-defined tools for the same functionality.
+
+# Output
+While some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:
+```python
+your code
+```
"""
-ASSIGN_TASK_TYPE_CONFIG = {
- "name": "assign_task_type",
- "description": "Assign task type to each task by order.",
- "parameters": {
- "type": "object",
- "properties": {
- "task_type": {
- "type": "array",
- "description": "List of task type. The length should as long as task list",
- "items": {
- "type": "string",
- },
- },
- },
- "required": ["task_type"],
- },
-}
+REFLECTION_SYSTEM_MSG = """You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation."""
-TOOL_RECOMMENDATION_PROMPT = """
-## User Requirement:
-{current_task}
+DEBUG_REFLECTION_EXAMPLE = '''
+[previous impl]:
+assistant:
+```python
+def add(a: int, b: int) -> int:
+ """
+ Given integers a and b, return the total value of a and b.
+ """
+ return a - b
+```
-## Task
-Recommend up to five tools from 'Available Tools' that can help solve the 'User Requirement'.
+user:
+Tests failed:
+assert add(1, 2) == 3 # output: -1
+assert add(1, 2) == 4 # output: -1
-## Available Tools:
-{available_tools}
+[reflection on previous impl]:
+The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
-## Tool Selection and Instructions:
-- Select tools most relevant to completing the 'User Requirement'.
-- If you believe that no tools are suitable, indicate with an empty list.
-- Only list the names of the tools, not the full schema of each tool.
-- Ensure selected tools are listed in 'Available Tools'.
+[improved impl]:
+def add(a: int, b: int) -> int:
+ """
+ Given integers a and b, return the total value of a and b.
+ """
+ return a + b
+'''
+
+REFLECTION_PROMPT = """
+[example]
+Here is an example of debugging with reflection.
+{debug_example}
+[/example]
+
+[context]
+{context}
+
+[previous impl]:
+{previous_impl}
+
+[instruction]
+Analyze your previous code and error in [context] step by step, provide me with improved method and code. Remember to follow [context] requirement. Don't forget to write code for steps behind the error step.
+Output a json following the format:
+```json
+{{
+ "reflection": str = "Reflection on previous implementation",
+ "improved_impl": str = "Refined code after reflection.",
+}}
+```
"""
-SELECT_FUNCTION_TOOLS = {
- "name": "select_function_tools",
- "description": "For current task, select suitable tools for it.",
- "parameters": {
- "type": "object",
- "properties": {
- "recommend_tools": {
- "type": "array",
- "description": "List of tool names. Empty list if no tool is suitable.",
- "items": {
- "type": "string",
- },
- },
- },
- "required": ["recommend_tools"],
- },
-}
+CHECK_DATA_PROMPT = """
+# Background
+Check latest data info to guide subsequent tasks.
-CODE_GENERATOR_WITH_TOOLS = {
- "name": "add_subtask_code",
- "description": "Add new code cell of current task to the end of an active Jupyter notebook.",
- "parameters": {
- "type": "object",
- "properties": {
- "code": {
- "type": "string",
- "description": "The code to be added to a new cell in jupyter.",
- },
- },
- "required": ["code"],
- },
-}
+## Finished Tasks
+```python
+{code_written}
+```end
-TOOL_USAGE_PROMPT = """
-# Instruction
-Write complete code for 'Current Task'. And avoid duplicating code from finished tasks, such as repeated import of packages, reading data, etc.
-Specifically, {tool_type_usage_prompt}
+# Task
+Check code in finished tasks, print key variables to guide your following actions.
+Specifically, if it is a data analysis or machine learning task, print the the latest column information using the following code, with DataFrame variable from 'Finished Tasks' in place of df:
+```python
+from metagpt.tools.libs.data_preprocess import get_column_info
-# Capabilities
-- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
-- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
-
-# Available Tools (can be empty):
-Each Class tool is described in JSON format. When you call a tool, import the tool first.
-{tool_schemas}
+column_info = get_column_info(df)
+print("column_info")
+print(column_info)
+```end
+Otherwise, print out any key variables you see fit. Return an empty string if you think there is no important data to check.
# Constraints:
-- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
-- Always prioritize using pre-defined tools for the same functionality.
+- Your code is to be added to a new cell in jupyter.
+
+# Instruction
+Output code following the format:
+```python
+your code
+```
+"""
+
+DATA_INFO = """
+# Latest Data Info
+Latest data info after previous tasks:
+{info}
"""
diff --git a/metagpt/prompts/tool_types.py b/metagpt/prompts/task_type.py
similarity index 87%
rename from metagpt/prompts/tool_types.py
rename to metagpt/prompts/task_type.py
index 8728d22f3..5b1ffc744 100644
--- a/metagpt/prompts/tool_types.py
+++ b/metagpt/prompts/task_type.py
@@ -1,11 +1,11 @@
-# Prompt for using tools of "eda" type
+# Prompt for taking on "eda" tasks
EDA_PROMPT = """
The current task is about exploratory data analysis, please note the following:
- Distinguish column types with `select_dtypes` for tailored analysis and visualization, such as correlation.
- Remember to `import numpy as np` before using Numpy functions.
"""
-# Prompt for using tools of "data_preprocess" type
+# Prompt for taking on "data_preprocess" tasks
DATA_PREPROCESS_PROMPT = """
The current task is about data preprocessing, please note the following:
- Monitor data types per column, applying appropriate methods.
@@ -15,9 +15,10 @@ The current task is about data preprocessing, please note the following:
- Prefer alternatives to one-hot encoding for categorical data.
- Only encode or scale necessary columns to allow for potential feature-specific engineering tasks (like time_extract, binning, extraction, etc.) later.
- Each step do data preprocessing to train, must do same for test separately at the same time.
+- Always copy the DataFrame before processing it and use the copy to process.
"""
-# Prompt for using tools of "feature_engineering" type
+# Prompt for taking on "feature_engineering" tasks
FEATURE_ENGINEERING_PROMPT = """
The current task is about feature engineering. when performing it, please adhere to the following principles:
- Generate as diverse features as possible to improve the model's performance step-by-step.
@@ -27,9 +28,10 @@ The current task is about feature engineering. when performing it, please adhere
- Each feature engineering operation performed on the train set must also applies to the test separately at the same time.
- Avoid using the label column to create features, except for cat encoding.
- Use the data from previous task result if exist, do not mock or reload data yourself.
+- Always copy the DataFrame before processing it and use the copy to process.
"""
-# Prompt for using tools of "model_train" type
+# Prompt for taking on "model_train" tasks
MODEL_TRAIN_PROMPT = """
The current task is about training a model, please ensure high performance:
- Keep in mind that your user prioritizes results and is highly focused on model performance. So, when needed, feel free to use models of any complexity to improve effectiveness, such as XGBoost, CatBoost, etc.
@@ -38,14 +40,14 @@ The current task is about training a model, please ensure high performance:
- Set suitable hyperparameters for the model, make metrics as high as possible.
"""
-# Prompt for using tools of "model_evaluate" type
+# Prompt for taking on "model_evaluate" tasks
MODEL_EVALUATE_PROMPT = """
The current task is about evaluating a model, please note the following:
- Ensure that the evaluated data is same processed as the training data. If not, remember use object in 'Done Tasks' to transform the data.
- Use trained model from previous task result directly, do not mock or reload model yourself.
"""
-# Prompt for using tools of "vision" type
+# Prompt for taking on "image2webpage" tasks
IMAGE2WEBPAGE_PROMPT = """
The current task is about converting image into webpage code. please note the following:
- Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow.
diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py
index 872f9b2c7..1aeacbe83 100644
--- a/metagpt/provider/anthropic_api.py
+++ b/metagpt/provider/anthropic_api.py
@@ -5,6 +5,7 @@ from anthropic import AsyncAnthropic
from anthropic.types import Message, Usage
from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@@ -41,15 +42,15 @@ class AnthropicLLM(BaseLLM):
def get_choice_text(self, resp: Message) -> str:
return resp.content[0].text
- async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
+ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
self._update_costs(resp.usage, self.model)
return resp
- async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = Usage(input_tokens=0, output_tokens=0)
diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py
index 33331ae94..9aeeda00c 100644
--- a/metagpt/provider/azure_openai_api.py
+++ b/metagpt/provider/azure_openai_api.py
@@ -25,7 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
- self.pricing_plan = self.config.pricing_plan
+ self.pricing_plan = self.config.pricing_plan or self.model
def _make_client_kwargs(self) -> dict:
kwargs = dict(
diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index 9850f4cf5..db2757ec3 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -10,10 +10,9 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
-from typing import Dict, Optional, Union
+from typing import Optional, Union
from openai import AsyncOpenAI
-from openai.types import CompletionUsage
from pydantic import BaseModel
from tenacity import (
after_log,
@@ -24,11 +23,11 @@ from tenacity import (
)
from metagpt.configs.llm_config import LLMConfig
+from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
-from metagpt.utils.exceptions import handle_exception
class BaseLLM(ABC):
@@ -41,7 +40,7 @@ class BaseLLM(ABC):
# OpenAI / Azure / Others
aclient: Optional[Union[AsyncOpenAI]] = None
cost_manager: Optional[CostManager] = None
- model: Optional[str] = None
+ model: Optional[str] = None # deprecated
pricing_plan: Optional[str] = None
@abstractmethod
@@ -75,6 +74,28 @@ class BaseLLM(ABC):
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "system", "content": msg}
+ def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
+ """convert messages to list[dict]."""
+ from metagpt.schema import Message
+
+ if not isinstance(messages, list):
+ messages = [messages]
+
+ processed_messages = []
+ for msg in messages:
+ if isinstance(msg, str):
+ processed_messages.append({"role": "user", "content": msg})
+ elif isinstance(msg, dict):
+ assert set(msg.keys()) == set(["role", "content"])
+ processed_messages.append(msg)
+ elif isinstance(msg, Message):
+ processed_messages.append(msg.to_dict())
+ else:
+ raise ValueError(
+ f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
+ )
+ return processed_messages
+
def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
return [self._system_msg(msg) for msg in msgs]
@@ -88,6 +109,7 @@ class BaseLLM(ABC):
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
"""
calc_usage = self.config.calc_usage and local_calc_usage
+ model = model or self.pricing_plan
model = model or self.model
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
if calc_usage and self.cost_manager:
@@ -105,11 +127,11 @@ class BaseLLM(ABC):
async def aask(
self,
- msg: str,
+ msg: Union[str, list[dict[str, str]]],
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
- timeout=3,
+ timeout=USE_CONFIG_TIMEOUT,
stream=True,
) -> str:
if system_msgs:
@@ -120,34 +142,36 @@ class BaseLLM(ABC):
message = []
if format_msgs:
message.extend(format_msgs)
- message.append(self._user_msg(msg, images=images))
+ if isinstance(msg, str):
+ message.append(self._user_msg(msg, images=images))
+ else:
+ message.extend(msg)
logger.debug(message)
- rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
+ rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
return rsp
def _extract_assistant_rsp(self, context):
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
- async def aask_batch(self, msgs: list, timeout=3) -> str:
+ async def aask_batch(self, msgs: list, timeout=USE_CONFIG_TIMEOUT) -> str:
"""Sequential questioning"""
context = []
for msg in msgs:
umsg = self._user_msg(msg)
context.append(umsg)
- rsp_text = await self.acompletion_text(context, timeout=timeout)
+ rsp_text = await self.acompletion_text(context, timeout=self.get_timeout(timeout))
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
- async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict:
- """FIXME: No code segment filtering has been done here, and all results are actually displayed"""
+ async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict:
raise NotImplementedError
@abstractmethod
- async def _achat_completion(self, messages: list[dict], timeout=3):
+ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""_achat_completion implemented by inherited class"""
@abstractmethod
- async def acompletion(self, messages: list[dict], timeout=3):
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""Asynchronous version of completion
All GPTAPIs are required to provide the standard OpenAI completion interface
[
@@ -158,7 +182,7 @@ class BaseLLM(ABC):
"""
@abstractmethod
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
"""_achat_completion_stream implemented by inherited class"""
@retry(
@@ -168,11 +192,13 @@ class BaseLLM(ABC):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
- async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
+ async def acompletion_text(
+ self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
+ ) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
if stream:
- return await self._achat_completion_stream(messages, timeout=timeout)
- resp = await self._achat_completion(messages, timeout=timeout)
+ return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
+ resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
return self.get_choice_text(resp)
def get_choice_text(self, rsp: dict) -> str:
@@ -223,20 +249,6 @@ class BaseLLM(ABC):
"""
return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
- @handle_exception
- def _update_costs(self, usage: CompletionUsage | Dict):
- """
- Updates the costs based on the provided usage information.
- """
- if self.config.calc_usage and usage and self.cost_manager:
- if isinstance(usage, Dict):
- prompt_tokens = int(usage.get("prompt_tokens", 0))
- completion_tokens = int(usage.get("completion_tokens", 0))
- else:
- prompt_tokens = usage.prompt_tokens
- completion_tokens = usage.completion_tokens
- self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
-
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
@@ -244,3 +256,11 @@ class BaseLLM(ABC):
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""
return [i.to_dict() for i in messages]
+
+ def with_model(self, model: str):
+ """Set model and return self. For example, `with_model("gpt-3.5-turbo")`."""
+ self.config.model = model
+ return self
+
+ def get_timeout(self, timeout: int) -> int:
+ return timeout or self.config.timeout or LLM_API_TIMEOUT
diff --git a/metagpt/provider/constant.py b/metagpt/provider/constant.py
index db67847a8..dee78dc3b 100644
--- a/metagpt/provider/constant.py
+++ b/metagpt/provider/constant.py
@@ -25,6 +25,7 @@ GENERAL_FUNCTION_SCHEMA = {
},
}
+
# tool_choice value for general_function_schema
# https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
GENERAL_TOOL_CHOICE = {"type": "function", "function": {"name": "execute"}}
diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py
index 21f3ef351..82224e893 100644
--- a/metagpt/provider/dashscope_api.py
+++ b/metagpt/provider/dashscope_api.py
@@ -25,6 +25,7 @@ from dashscope.common.error import (
UnsupportedApiProtocol,
)
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM, LLMConfig
from metagpt.provider.llm_provider_registry import LLMType, register_provider
@@ -202,16 +203,16 @@ class DashScopeLLM(BaseLLM):
self._update_costs(dict(resp.usage))
return resp.output
- async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
+ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> GenerationOutput:
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
self._check_response(resp)
self._update_costs(dict(resp.usage))
return resp.output
- async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> GenerationOutput:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py
index 1b9149396..8e5da8f16 100644
--- a/metagpt/provider/general_api_base.py
+++ b/metagpt/provider/general_api_base.py
@@ -573,7 +573,7 @@ class APIRequestor:
total=request_timeout[1],
)
else:
- timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS)
+ timeout = aiohttp.ClientTimeout(total=request_timeout or TIMEOUT_SECS)
if files:
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py
index 09e554205..e4b3a3f17 100644
--- a/metagpt/provider/google_gemini_api.py
+++ b/metagpt/provider/google_gemini_api.py
@@ -1,8 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
-
-from typing import Optional, Union
+import json
+import os
+from dataclasses import asdict
+from typing import List, Optional, Union
import google.generativeai as genai
from google.ai import generativelanguage as glm
@@ -10,14 +12,17 @@ from google.generativeai.generative_models import GenerativeModel
from google.generativeai.types import content_types
from google.generativeai.types.generation_types import (
AsyncGenerateContentResponse,
+ BlockedPromptException,
GenerateContentResponse,
GenerationConfig,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import log_llm_stream
+from metagpt.const import USE_CONFIG_TIMEOUT
+from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.schema import Message
class GeminiGenerativeModel(GenerativeModel):
@@ -51,6 +56,10 @@ class GeminiLLM(BaseLLM):
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: LLMConfig):
+ if config.proxy:
+ logger.info(f"Use proxy: {config.proxy}")
+ os.environ["http_proxy"] = config.proxy
+ os.environ["https_proxy"] = config.proxy
genai.configure(api_key=config.api_key)
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]:
@@ -61,6 +70,35 @@ class GeminiLLM(BaseLLM):
def _assistant_msg(self, msg: str) -> dict[str, str]:
return {"role": "model", "parts": [msg]}
+ def _system_msg(self, msg: str) -> dict[str, str]:
+ return {"role": "user", "parts": [msg]}
+
+ def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
+ """convert messages to list[dict]."""
+ from metagpt.schema import Message
+
+ if not isinstance(messages, list):
+ messages = [messages]
+
+ # REF: https://ai.google.dev/tutorials/python_quickstart
+ # As a dictionary, the message requires `role` and `parts` keys.
+ # The role in a conversation can either be the `user`, which provides the prompts,
+ # or `model`, which provides the responses.
+ processed_messages = []
+ for msg in messages:
+ if isinstance(msg, str):
+ processed_messages.append({"role": "user", "parts": [msg]})
+ elif isinstance(msg, dict):
+ assert set(msg.keys()) == set(["role", "parts"])
+ processed_messages.append(msg)
+ elif isinstance(msg, Message):
+ processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]})
+ else:
+ raise ValueError(
+ f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
+ )
+ return processed_messages
+
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
@@ -88,22 +126,28 @@ class GeminiLLM(BaseLLM):
self._update_costs(usage)
return resp
- async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
+ async def _achat_completion(
+ self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT
+ ) -> "AsyncGenerateContentResponse":
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
usage = await self.aget_usage(messages, resp.text)
self._update_costs(usage)
return resp
- async def acompletion(self, messages: list[dict], timeout=3) -> dict:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
**self._const_kwargs(messages, stream=True)
)
collected_content = []
async for chunk in resp:
- content = chunk.text
+ try:
+ content = chunk.text
+ except Exception as e:
+ logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}")
+ raise BlockedPromptException(str(chunk))
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
@@ -112,3 +156,10 @@ class GeminiLLM(BaseLLM):
usage = await self.aget_usage(messages, full_content)
self._update_costs(usage)
return full_content
+
+ def list_models(self) -> List:
+ models = []
+ for model in genai.list_models(page_size=100):
+ models.append(asdict(model))
+ logger.info(json.dumps(models))
+ return models
diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py
index e5f37c5b9..87dbd105f 100644
--- a/metagpt/provider/human_provider.py
+++ b/metagpt/provider/human_provider.py
@@ -6,6 +6,7 @@ Author: garylin2099
from typing import Optional
from metagpt.configs.llm_config import LLMConfig
+from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
@@ -16,9 +17,9 @@ class HumanProvider(BaseLLM):
"""
def __init__(self, config: LLMConfig):
- pass
+ self.config = config
- def ask(self, msg: str, timeout=3) -> str:
+ def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
logger.info("It's your turn, please type in your response. You may also refer to the context below")
rsp = input(msg)
if rsp in ["exit", "quit"]:
@@ -31,20 +32,23 @@ class HumanProvider(BaseLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
generator: bool = False,
- timeout=3,
+ timeout=USE_CONFIG_TIMEOUT,
) -> str:
- return self.ask(msg, timeout=timeout)
+ return self.ask(msg, timeout=self.get_timeout(timeout))
- async def _achat_completion(self, messages: list[dict], timeout=3):
+ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
- async def acompletion(self, messages: list[dict], timeout=3):
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""dummy implementation of abstract method in base"""
return []
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
- async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+ async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
"""dummy implementation of abstract method in base"""
return ""
+
+ def get_timeout(self, timeout: int) -> int:
+ return timeout or LLM_API_TIMEOUT
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index f65d7e411..2913eb1dd 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -5,7 +5,7 @@
import json
from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.const import LLM_API_TIMEOUT
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
@@ -50,28 +50,28 @@ class OllamaLLM(BaseLLM):
chunk = chunk.decode(encoding)
return json.loads(chunk)
- async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
+ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
params=self._const_kwargs(messages),
- request_timeout=LLM_API_TIMEOUT,
+ request_timeout=self.get_timeout(timeout),
)
resp = self._decode_and_load(resp)
usage = self.get_usage(resp)
self._update_costs(usage)
return resp
- async def acompletion(self, messages: list[dict], timeout=3) -> dict:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
stream_resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
stream=True,
params=self._const_kwargs(messages, stream=True),
- request_timeout=LLM_API_TIMEOUT,
+ request_timeout=self.get_timeout(timeout),
)
collected_content = []
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 82206f738..dbfed72df 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -25,11 +25,11 @@ from tenacity import (
)
from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
@@ -40,7 +40,7 @@ from metagpt.utils.token_counter import (
)
-@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL])
+@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL, LLMType.YI])
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
@@ -75,15 +75,17 @@ class OpenAILLM(BaseLLM):
return params
- async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
- **self._cons_kwargs(messages, timeout=timeout), stream=True
+ **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True
)
usage = None
collected_messages = []
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
- finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
+ finish_reason = (
+ chunk.choices[0].finish_reason if chunk.choices and hasattr(chunk.choices[0], "finish_reason") else None
+ )
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
if finish_reason:
@@ -103,7 +105,7 @@ class OpenAILLM(BaseLLM):
self._update_costs(usage)
return full_reply_content
- def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
+ def _cons_kwargs(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, **extra_kwargs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self._get_max_tokens(messages),
@@ -111,20 +113,20 @@ class OpenAILLM(BaseLLM):
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": self.config.temperature,
"model": self.model,
- "timeout": max(self.config.timeout, timeout),
+ "timeout": self.get_timeout(timeout),
}
if extra_kwargs:
kwargs.update(extra_kwargs)
return kwargs
- async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
- kwargs = self._cons_kwargs(messages, timeout=timeout)
+ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
+ kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
- async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
@retry(
wait=wait_random_exponential(min=1, max=60),
@@ -133,52 +135,24 @@ class OpenAILLM(BaseLLM):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
- async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+ async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
"""when streaming, print each token in place."""
if stream:
- await self._achat_completion_stream(messages, timeout=timeout)
+ return await self._achat_completion_stream(messages, timeout=timeout)
- rsp = await self._achat_completion(messages, timeout=timeout)
+ rsp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
return self.get_choice_text(rsp)
- def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
- """Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
- if "tools" not in kwargs:
- configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
- kwargs.update(configs)
-
- return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
-
- def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
- """convert messages to list[dict]."""
- # 全部转成list
- if not isinstance(messages, list):
- messages = [messages]
-
- # 转成list[dict]
- processed_messages = []
- for msg in messages:
- if isinstance(msg, str):
- processed_messages.append({"role": "user", "content": msg})
- elif isinstance(msg, dict):
- assert set(msg.keys()) == set(["role", "content"])
- processed_messages.append(msg)
- elif isinstance(msg, Message):
- processed_messages.append(msg.to_dict())
- else:
- raise ValueError(
- f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
- )
- return processed_messages
-
- async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
- messages = self._process_message(messages)
- kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
+ async def _achat_completion_function(
+ self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs
+ ) -> ChatCompletion:
+ messages = self.format_msg(messages)
+ kwargs = self._cons_kwargs(messages=messages, timeout=self.get_timeout(timeout), **chat_configs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
- async def aask_code(self, messages: list[dict], **kwargs) -> dict:
+ async def aask_code(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
@@ -188,12 +162,15 @@ class OpenAILLM(BaseLLM):
>>> rsp = await llm.aask_code(msg)
# -> {'language': 'python', 'code': "print('Hello, World!')"}
"""
+ if "tools" not in kwargs:
+ configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
+ kwargs.update(configs)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
def _parse_arguments(self, arguments: str) -> dict:
"""parse arguments in openai function call"""
- if "langugae" not in arguments and "code" not in arguments:
+ if "language" not in arguments and "code" not in arguments:
logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ")
return {"language": "python", "code": arguments}
diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py
index 50916fa3e..3d78c8bfc 100644
--- a/metagpt/provider/qianfan_api.py
+++ b/metagpt/provider/qianfan_api.py
@@ -9,6 +9,7 @@ from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@@ -107,15 +108,15 @@ class QianFanLLM(BaseLLM):
self._update_costs(resp.body.get("usage", {}))
return resp.body
- async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
+ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
- async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
collected_content = []
usage = {}
diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py
index 882c6ce85..594267259 100644
--- a/metagpt/provider/spark_api.py
+++ b/metagpt/provider/spark_api.py
@@ -17,6 +17,7 @@ from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@@ -31,19 +32,19 @@ class SparkLLM(BaseLLM):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
- async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
- async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
+ async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
# 不支持
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
- async def _achat_completion(self, messages: list[dict], timeout=3):
+ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
- async def acompletion(self, messages: list[dict], timeout=3):
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
# 不支持异步
w = GetMessageFromWeb(messages, self.config)
return w.run()
diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py
index 14ad1a36b..2db441991 100644
--- a/metagpt/provider/zhipuai_api.py
+++ b/metagpt/provider/zhipuai_api.py
@@ -8,6 +8,7 @@ from typing import Optional
from zhipuai.types.chat.chat_completion import Completion
from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@@ -45,22 +46,22 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
- def completion(self, messages: list[dict], timeout=3) -> dict:
+ def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp.model_dump()
- async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
+ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
resp = await self.llm.acreate(**self._const_kwargs(messages))
usage = resp.get("usage", {})
self._update_costs(usage)
return resp
- async def acompletion(self, messages: list[dict], timeout=3) -> dict:
- return await self._achat_completion(messages, timeout=timeout)
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
- async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
+ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py
index 15842fdfb..bc3bae662 100644
--- a/metagpt/repo_parser.py
+++ b/metagpt/repo_parser.py
@@ -722,14 +722,19 @@ class RepoParser(BaseModel):
path = Path(path)
if not path.exists():
return
+ init_file = path / "__init__.py"
+ if not init_file.exists():
+ raise ValueError("Failed to import module __init__ with error:No module named __init__.")
command = f"pyreverse {str(path)} -o dot"
- result = subprocess.run(command, shell=True, check=True, cwd=str(path))
+ output_dir = path / "__dot__"
+ output_dir.mkdir(parents=True, exist_ok=True)
+ result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
if result.returncode != 0:
raise ValueError(f"{result}")
- class_view_pathname = path / "classes.dot"
+ class_view_pathname = output_dir / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
- packages_pathname = path / "packages.dot"
+ packages_pathname = output_dir / "packages.dot"
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
@@ -975,6 +980,8 @@ class RepoParser(BaseModel):
file_ns = file_ns[0:ix]
continue
break
+ if file_ns == "":
+ return ""
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns
diff --git a/metagpt/roles/__init__.py b/metagpt/roles/__init__.py
index f033a5dfa..08a0406b3 100644
--- a/metagpt/roles/__init__.py
+++ b/metagpt/roles/__init__.py
@@ -14,7 +14,6 @@ from metagpt.roles.engineer import Engineer
from metagpt.roles.qa_engineer import QaEngineer
from metagpt.roles.searcher import Searcher
from metagpt.roles.sales import Sales
-from metagpt.roles.customer_service import CustomerService
__all__ = [
@@ -26,5 +25,4 @@ __all__ = [
"QaEngineer",
"Searcher",
"Sales",
- "CustomerService",
]
diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py
index b3a2e789f..a8534b710 100644
--- a/metagpt/roles/di/data_interpreter.py
+++ b/metagpt/roles/di/data_interpreter.py
@@ -1,48 +1,97 @@
from __future__ import annotations
-from pydantic import Field
+import json
+from typing import Literal, Union
+
+from pydantic import Field, model_validator
from metagpt.actions.di.ask_review import ReviewConst
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
-from metagpt.actions.di.write_analysis_code import (
- WriteCodeWithoutTools,
- WriteCodeWithTools,
-)
+from metagpt.actions.di.write_analysis_code import CheckData, WriteAnalysisCode
from metagpt.logs import logger
+from metagpt.prompts.di.write_analysis_code import DATA_INFO
from metagpt.roles import Role
from metagpt.schema import Message, Task, TaskResult
+from metagpt.strategy.task_type import TaskType
+from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
+from metagpt.utils.common import CodeParser
+
+REACT_THINK_PROMPT = """
+# User Requirement
+{user_requirement}
+# Context
+{context}
+
+Output a json following the format:
+```json
+{{
+ "thoughts": str = "Thoughts on current situation, reflect on how you should proceed to fulfill the user requirement",
+ "state": bool = "Decide whether you need to take more actions to complete the user requirement. Return true if you think so. Return false if you think the requirement has been completely fulfilled."
+}}
+```
+"""
class DataInterpreter(Role):
name: str = "David"
profile: str = "DataInterpreter"
auto_run: bool = True
- use_tools: bool = False
+ use_plan: bool = True
+ use_reflection: bool = False
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
- tools: list[str] = []
+ tools: Union[str, list[str]] = [] # Use special symbol [""] to indicate use of all registered tools
+ tool_recommender: ToolRecommender = None
+ react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
+ max_react_loop: int = 10 # used for react mode
- def __init__(
- self,
- auto_run=True,
- use_tools=False,
- tools=[],
- **kwargs,
- ):
- super().__init__(auto_run=auto_run, use_tools=use_tools, tools=tools, **kwargs)
- self._set_react_mode(react_mode="plan_and_act", auto_run=auto_run, use_tools=use_tools)
- if use_tools and tools:
- from metagpt.tools.tool_registry import (
- validate_tool_names, # import upon use
- )
-
- self.tools = validate_tool_names(tools)
- logger.info(f"will only use {self.tools} as tools")
+ @model_validator(mode="after")
+ def set_plan_and_tool(self) -> "Interpreter":
+ self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop, auto_run=self.auto_run)
+ self.use_plan = (
+ self.react_mode == "plan_and_act"
+ ) # create a flag for convenience, overwrite any passed-in value
+ if self.tools:
+ self.tool_recommender = BM25ToolRecommender(tools=self.tools)
+ self.set_actions([WriteAnalysisCode])
+ self._set_state(0)
+ return self
@property
def working_memory(self):
return self.rc.working_memory
+ async def _think(self) -> bool:
+ """Useful in 'react' mode. Use LLM to decide whether and what to do next."""
+ user_requirement = self.get_memories()[0].content
+ context = self.working_memory.get()
+
+ if not context:
+ # just started the run, we need action certainly
+ self.working_memory.add(self.get_memories()[0]) # add user requirement to working memory
+ self._set_state(0)
+ return True
+
+ prompt = REACT_THINK_PROMPT.format(user_requirement=user_requirement, context=context)
+ rsp = await self.llm.aask(prompt)
+ rsp_dict = json.loads(CodeParser.parse_code(block=None, text=rsp))
+ self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant"))
+ need_action = rsp_dict["state"]
+ self._set_state(0) if need_action else self._set_state(-1)
+
+ return need_action
+
+ async def _act(self) -> Message:
+ """Useful in 'react' mode. Return a Message conforming to Role._act interface."""
+ code, _, _ = await self._write_and_exec_code()
+ return Message(content=code, role="assistant", cause_by=WriteAnalysisCode)
+
+ async def _plan_and_act(self) -> Message:
+ rsp = await super()._plan_and_act()
+ await self.execute_code.terminate()
+ return rsp
+
async def _act_on_task(self, current_task: Task) -> TaskResult:
+ """Useful in 'plan_and_act' mode. Wrap the output in a TaskResult for review and confirmation."""
code, result, is_success = await self._write_and_exec_code()
task_result = TaskResult(code=code, result=result, is_success=is_success)
return task_result
@@ -51,14 +100,30 @@ class DataInterpreter(Role):
counter = 0
success = False
+ # plan info
+ plan_status = self.planner.get_plan_status() if self.use_plan else ""
+
+ # tool info
+ if self.tools:
+ context = (
+ self.working_memory.get()[-1].content if self.working_memory.get() else ""
+ ) # thoughts from _think stage in 'react' mode
+ plan = self.planner.plan if self.use_plan else None
+ tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
+ else:
+ tool_info = ""
+
+ # data info
+ await self._check_data()
+
while not success and counter < max_retry:
### write code ###
- code, cause_by = await self._write_code()
+ code, cause_by = await self._write_code(counter, plan_status, tool_info)
- self.working_memory.add(Message(content=code["code"], role="assistant", cause_by=cause_by))
+ self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
### execute code ###
- result, success = await self.execute_code.run(**code)
+ result, success = await self.execute_code.run(code)
print(result)
self.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode))
@@ -72,14 +137,48 @@ class DataInterpreter(Role):
if ReviewConst.CHANGE_WORDS[0] in review:
counter = 0 # redo the task again with help of human suggestions
- return code["code"], result, success
+ return code, result, success
- async def _write_code(self):
- todo = WriteCodeWithoutTools() if not self.use_tools else WriteCodeWithTools(selected_tools=self.tools)
+ async def _write_code(
+ self,
+ counter: int,
+ plan_status: str = "",
+ tool_info: str = "",
+ ):
+ todo = self.rc.todo # todo is WriteAnalysisCode
logger.info(f"ready to {todo.name}")
+ use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial
- context = self.planner.get_useful_memories()
- # print(*context, sep="\n***\n")
- code = await todo.run(context=context, plan=self.planner.plan, temperature=0.0)
+ user_requirement = self.get_memories()[0].content
+
+ code = await todo.run(
+ user_requirement=user_requirement,
+ plan_status=plan_status,
+ tool_info=tool_info,
+ working_memory=self.working_memory.get(),
+ use_reflection=use_reflection,
+ )
return code, todo
+
+ async def _check_data(self):
+ if (
+ not self.use_plan
+ or not self.planner.plan.get_finished_tasks()
+ or self.planner.plan.current_task.task_type
+ not in [
+ TaskType.DATA_PREPROCESS.type_name,
+ TaskType.FEATURE_ENGINEERING.type_name,
+ TaskType.MODEL_TRAIN.type_name,
+ ]
+ ):
+ return
+ logger.info("Check updated data")
+ code = await CheckData().run(self.planner.plan)
+ if not code.strip():
+ return
+ result, success = await self.execute_code.run(code)
+ if success:
+ print(result)
+ data_info = DATA_INFO.format(info=result)
+ self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))
diff --git a/metagpt/roles/di/ml_engineer.py b/metagpt/roles/di/ml_engineer.py
deleted file mode 100644
index b33b166cf..000000000
--- a/metagpt/roles/di/ml_engineer.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from metagpt.actions.di.debug_code import DebugCode
-from metagpt.actions.di.execute_nb_code import ExecuteNbCode
-from metagpt.actions.di.ml_action import UpdateDataColumns, WriteCodeWithToolsML
-from metagpt.logs import logger
-from metagpt.roles.di.data_interpreter import DataInterpreter
-from metagpt.tools.tool_type import ToolType
-from metagpt.utils.common import any_to_str
-
-
-class MLEngineer(DataInterpreter):
- name: str = "Mark"
- profile: str = "MLEngineer"
- debug_context: list = []
- latest_code: str = ""
-
- async def _write_code(self):
- if not self.use_tools:
- return await super()._write_code()
-
- # In a trial and errors settings, check whether this is our first attempt to tackle the task. If there is no code execution before, then it is.
- is_first_trial = any_to_str(ExecuteNbCode) not in [msg.cause_by for msg in self.working_memory.get()]
-
- if is_first_trial:
- # For the first trial, write task code from scratch
- column_info = await self._update_data_columns()
-
- logger.info("Write code with tools")
- tool_context, code = await WriteCodeWithToolsML(selected_tools=self.tools).run(
- context=[], # context assembled inside the Action
- plan=self.planner.plan,
- column_info=column_info,
- )
- self.debug_context = tool_context
- cause_by = WriteCodeWithToolsML
-
- else:
- # Previous trials resulted in error, debug and rewrite the code
- logger.warning("We got a bug, now start to debug...")
- code = await DebugCode().run(
- code=self.latest_code,
- runtime_result=self.working_memory.get(),
- context=self.debug_context,
- )
- logger.info(f"new code \n{code}")
- cause_by = DebugCode
-
- self.latest_code = code["code"]
-
- return code, cause_by
-
- async def _update_data_columns(self):
- current_task = self.planner.plan.current_task
- if current_task.task_type not in [
- ToolType.DATA_PREPROCESS.type_name,
- ToolType.FEATURE_ENGINEERING.type_name,
- ToolType.MODEL_TRAIN.type_name,
- ]:
- return ""
- logger.info("Check columns in updated data")
- code = await UpdateDataColumns().run(self.planner.plan)
- success = False
- result, success = await self.execute_code.run(**code)
- print(result)
- return result if success else ""
diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py
index 329b21553..9d8f6884f 100644
--- a/metagpt/roles/engineer.py
+++ b/metagpt/roles/engineer.py
@@ -240,8 +240,8 @@ class Engineer(Role):
async def _think(self) -> Action | None:
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
- write_plan_and_change_filters = any_to_str_set([WriteTasks])
- write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode, FixBug])
+ write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
+ write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
if not self.rc.news:
return None
diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py
index 4e8cc5cb8..e0f8a7ea6 100644
--- a/metagpt/roles/role.py
+++ b/metagpt/roles/role.py
@@ -169,6 +169,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self._check_actions()
self.llm.system_prompt = self._get_prefix()
+ self.llm.cost_manager = self.context.cost_manager
self._watch(kwargs.pop("watch", [UserRequirement]))
if self.latest_observed_msg:
@@ -277,7 +278,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self.actions.append(i)
self.states.append(f"{len(self.actions) - 1}. {action}")
- def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
+ def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True):
"""Set strategy of the Role reacting to observed Message. Variation lies in how
this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions.
@@ -298,9 +299,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if react_mode == RoleReactMode.REACT:
self.rc.max_react_loop = max_react_loop
elif react_mode == RoleReactMode.PLAN_AND_ACT:
- self.planner = Planner(
- goal=self.goal, working_memory=self.rc.working_memory, auto_run=auto_run, use_tools=use_tools
- )
+ self.planner = Planner(goal=self.goal, working_memory=self.rc.working_memory, auto_run=auto_run)
def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
@@ -333,6 +332,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if env:
env.set_addresses(self, self.addresses)
self.llm.system_prompt = self._get_prefix()
+ self.llm.cost_manager = self.context.cost_manager
self.set_actions(self.actions) # reset actions to update llm and prefix
def _get_prefix(self):
diff --git a/metagpt/startup.py b/metagpt/startup.py
new file mode 100644
index 000000000..bb6f6abf2
--- /dev/null
+++ b/metagpt/startup.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/3/11 19:16
+@Author : alexanderwu
+@File : startup.py
+"""
+
+# DEPRECATED: This file is deprecated and will be removed in the future.
+# The startup.py implementation has been moved to software_company.py
diff --git a/metagpt/strategy/planner.py b/metagpt/strategy/planner.py
index 44294be00..fbf784837 100644
--- a/metagpt/strategy/planner.py
+++ b/metagpt/strategy/planner.py
@@ -13,6 +13,8 @@ from metagpt.actions.di.write_plan import (
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import Message, Plan, Task, TaskResult
+from metagpt.strategy.task_type import TaskType
+from metagpt.utils.common import remove_comments
STRUCTURAL_CONTEXT = """
## User Requirement
@@ -25,6 +27,24 @@ STRUCTURAL_CONTEXT = """
{current_task}
"""
+PLAN_STATUS = """
+## Finished Tasks
+### code
+```python
+{code_written}
+```
+
+### execution result
+{task_results}
+
+## Current Task
+{current_task}
+
+## Task Guidance
+Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
+Specifically, {guidance}
+"""
+
class Planner(BaseModel):
plan: Plan
@@ -32,7 +52,6 @@ class Planner(BaseModel):
default_factory=Memory
) # memory for working on each task, discarded each time a task is done
auto_run: bool = False
- use_tools: bool = False
def __init__(self, goal: str = "", plan: Plan = None, **kwargs):
plan = plan or Plan(goal=goal)
@@ -53,7 +72,7 @@ class Planner(BaseModel):
plan_confirmed = False
while not plan_confirmed:
context = self.get_useful_memories()
- rsp = await WritePlan().run(context, max_tasks=max_tasks, use_tools=self.use_tools)
+ rsp = await WritePlan().run(context, max_tasks=max_tasks)
self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
# precheck plan before asking reviews
@@ -137,3 +156,24 @@ class Planner(BaseModel):
context_msg = [Message(content=context, role="user")]
return context_msg + self.working_memory.get()
+
+ def get_plan_status(self) -> str:
+ # prepare components of a plan status
+ finished_tasks = self.plan.get_finished_tasks()
+ code_written = [remove_comments(task.code) for task in finished_tasks]
+ code_written = "\n\n".join(code_written)
+ task_results = [task.result for task in finished_tasks]
+ task_results = "\n\n".join(task_results)
+ task_type_name = self.current_task.task_type
+ task_type = TaskType.get_type(task_type_name)
+ guidance = task_type.guidance if task_type else ""
+
+ # combine components in a prompt
+ prompt = PLAN_STATUS.format(
+ code_written=code_written,
+ task_results=task_results,
+ current_task=self.current_task.instruction,
+ guidance=guidance,
+ )
+
+ return prompt
diff --git a/metagpt/strategy/task_type.py b/metagpt/strategy/task_type.py
new file mode 100644
index 000000000..d21705c16
--- /dev/null
+++ b/metagpt/strategy/task_type.py
@@ -0,0 +1,80 @@
+from enum import Enum
+
+from pydantic import BaseModel
+
+from metagpt.prompts.task_type import (
+ DATA_PREPROCESS_PROMPT,
+ EDA_PROMPT,
+ FEATURE_ENGINEERING_PROMPT,
+ IMAGE2WEBPAGE_PROMPT,
+ MODEL_EVALUATE_PROMPT,
+ MODEL_TRAIN_PROMPT,
+)
+
+
+class TaskTypeDef(BaseModel):
+ name: str
+ desc: str = ""
+ guidance: str = ""
+
+
+class TaskType(Enum):
+ """By identifying specific types of tasks, we can inject human priors (guidance) to help task solving"""
+
+ EDA = TaskTypeDef(
+ name="eda",
+ desc="For performing exploratory data analysis",
+ guidance=EDA_PROMPT,
+ )
+ DATA_PREPROCESS = TaskTypeDef(
+ name="data preprocessing",
+ desc="For preprocessing dataset in a data analysis or machine learning task ONLY,"
+ "general data operation doesn't fall into this type",
+ guidance=DATA_PREPROCESS_PROMPT,
+ )
+ FEATURE_ENGINEERING = TaskTypeDef(
+ name="feature engineering",
+ desc="Only for creating new columns for input data.",
+ guidance=FEATURE_ENGINEERING_PROMPT,
+ )
+ MODEL_TRAIN = TaskTypeDef(
+ name="model train",
+ desc="Only for training model.",
+ guidance=MODEL_TRAIN_PROMPT,
+ )
+ MODEL_EVALUATE = TaskTypeDef(
+ name="model evaluate",
+ desc="Only for evaluating model.",
+ guidance=MODEL_EVALUATE_PROMPT,
+ )
+ IMAGE2WEBPAGE = TaskTypeDef(
+ name="image2webpage",
+ desc="For converting image into webpage code.",
+ guidance=IMAGE2WEBPAGE_PROMPT,
+ )
+ OTHER = TaskTypeDef(name="other", desc="Any tasks not in the defined categories")
+
+ # Legacy TaskType to support tool recommendation using type match. You don't need to define task types if you have no human priors to inject.
+ TEXT2IMAGE = TaskTypeDef(
+ name="text2image",
+ desc="Related to text2image, image2image using stable diffusion model.",
+ )
+ WEBSCRAPING = TaskTypeDef(
+ name="web scraping",
+ desc="For scraping data from web pages.",
+ )
+ EMAIL_LOGIN = TaskTypeDef(
+ name="email login",
+ desc="For logging to an email.",
+ )
+
+ @property
+ def type_name(self):
+ return self.value.name
+
+ @classmethod
+ def get_type(cls, type_name):
+ for member in cls:
+ if member.type_name == type_name:
+ return member.value
+ return None
diff --git a/metagpt/tools/libs/data_preprocess.py b/metagpt/tools/libs/data_preprocess.py
index 7a3d019bf..aa9070689 100644
--- a/metagpt/tools/libs/data_preprocess.py
+++ b/metagpt/tools/libs/data_preprocess.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import json
+from typing import Literal
import numpy as np
import pandas as pd
@@ -16,9 +17,8 @@ from sklearn.preprocessing import (
)
from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.tool_type import ToolType
-TOOL_TYPE = ToolType.DATA_PREPROCESS.type_name
+TAGS = ["data preprocessing", "machine learning"]
class MLProcess:
@@ -85,20 +85,22 @@ class DataPreprocessTool(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class FillMissingValue(DataPreprocessTool):
"""
Completing missing values with simple strategies.
"""
- def __init__(self, features: list, strategy: str = "mean", fill_value=None):
+ def __init__(
+ self, features: list, strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean", fill_value=None
+ ):
"""
Initialize self.
Args:
features (list): Columns to be processed.
- strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
- be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
+ strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice 'mean' and 'median' can only
+ be used for numeric features. Defaults to 'mean'.
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
Defaults to None.
"""
@@ -106,7 +108,7 @@ class FillMissingValue(DataPreprocessTool):
self.model = SimpleImputer(strategy=strategy, fill_value=fill_value)
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class MinMaxScale(DataPreprocessTool):
"""
Transform features by scaling each feature to a range, which is (0, 1).
@@ -117,7 +119,7 @@ class MinMaxScale(DataPreprocessTool):
self.model = MinMaxScaler()
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class StandardScale(DataPreprocessTool):
"""
Standardize features by removing the mean and scaling to unit variance.
@@ -128,7 +130,7 @@ class StandardScale(DataPreprocessTool):
self.model = StandardScaler()
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class MaxAbsScale(DataPreprocessTool):
"""
Scale each feature by its maximum absolute value.
@@ -139,7 +141,7 @@ class MaxAbsScale(DataPreprocessTool):
self.model = MaxAbsScaler()
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class RobustScale(DataPreprocessTool):
"""
Apply the RobustScaler to scale features using statistics that are robust to outliers.
@@ -150,7 +152,7 @@ class RobustScale(DataPreprocessTool):
self.model = RobustScaler()
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class OrdinalEncode(DataPreprocessTool):
"""
Encode categorical features as ordinal integers.
@@ -161,7 +163,7 @@ class OrdinalEncode(DataPreprocessTool):
self.model = OrdinalEncoder()
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class OneHotEncode(DataPreprocessTool):
"""
Apply one-hot encoding to specified categorical columns, the original columns will be dropped.
@@ -180,7 +182,7 @@ class OneHotEncode(DataPreprocessTool):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class LabelEncode(DataPreprocessTool):
"""
Apply label encoding to specified categorical columns in-place.
diff --git a/metagpt/tools/libs/email_login.py b/metagpt/tools/libs/email_login.py
index 79734665e..32626ac55 100644
--- a/metagpt/tools/libs/email_login.py
+++ b/metagpt/tools/libs/email_login.py
@@ -1,7 +1,6 @@
from imap_tools import MailBox
from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.tool_type import ToolType
# Define a dictionary mapping email domains to their IMAP server addresses
IMAP_SERVERS = {
@@ -24,7 +23,7 @@ IMAP_SERVERS = {
}
-@register_tool(tool_type=ToolType.EMAIL_LOGIN.type_name)
+@register_tool(tags=["email login"])
def email_login_imap(email_address, email_password):
"""
Use imap_tools package to log in to your email (the email that supports IMAP protocol) to verify and return the account object.
diff --git a/metagpt/tools/libs/feature_engineering.py b/metagpt/tools/libs/feature_engineering.py
index 40bfb2fc7..3013e1594 100644
--- a/metagpt/tools/libs/feature_engineering.py
+++ b/metagpt/tools/libs/feature_engineering.py
@@ -19,12 +19,11 @@ from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
from metagpt.tools.libs.data_preprocess import MLProcess
from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.tool_type import ToolType
-TOOL_TYPE = ToolType.FEATURE_ENGINEERING.type_name
+TAGS = ["feature engineering", "machine learning"]
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class PolynomialExpansion(MLProcess):
"""
Add polynomial and interaction features from selected numeric columns to input DataFrame.
@@ -67,7 +66,7 @@ class PolynomialExpansion(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class CatCount(MLProcess):
"""
Add value counts of a categorical column as new feature.
@@ -92,7 +91,7 @@ class CatCount(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class TargetMeanEncoder(MLProcess):
"""
Encode a categorical column by the mean of the label column, and adds the result as a new feature.
@@ -119,7 +118,7 @@ class TargetMeanEncoder(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class KFoldTargetMeanEncoder(MLProcess):
"""
Add a new feature to the DataFrame by k-fold mean encoding of a categorical column using the label column.
@@ -159,7 +158,7 @@ class KFoldTargetMeanEncoder(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class CatCross(MLProcess):
"""
Add pairwise crossed features and convert them to numerical features.
@@ -216,7 +215,7 @@ class CatCross(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class GroupStat(MLProcess):
"""
Aggregate specified column in a DataFrame grouped by another column, adding new features named '__by_'.
@@ -248,7 +247,7 @@ class GroupStat(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class SplitBins(MLProcess):
"""
Inplace binning of continuous data into intervals, returning integer-encoded bin identifiers directly.
@@ -276,7 +275,7 @@ class SplitBins(MLProcess):
return new_df
-# @register_tool(tool_type=TOOL_TYPE)
+# @register_tool(tags=TAGS)
class ExtractTimeComps(MLProcess):
"""
Extract time components from a datetime column and add them as new features.
@@ -316,7 +315,7 @@ class ExtractTimeComps(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class GeneralSelection(MLProcess):
"""
Drop all nan feats and feats with only one unique value.
@@ -349,7 +348,7 @@ class GeneralSelection(MLProcess):
# skip for now because lgb is needed
-# @register_tool(tool_type=TOOL_TYPE)
+# @register_tool(tags=TAGS)
class TreeBasedSelection(MLProcess):
"""
Select features based on tree-based model and remove features with low importance.
@@ -403,7 +402,7 @@ class TreeBasedSelection(MLProcess):
return new_df
-@register_tool(tool_type=TOOL_TYPE)
+@register_tool(tags=TAGS)
class VarianceBasedSelection(MLProcess):
"""
Select features based on variance and remove features with low variance.
diff --git a/metagpt/tools/libs/gpt_v_generator.py b/metagpt/tools/libs/gpt_v_generator.py
index 3b17fc596..4eba3d5ee 100644
--- a/metagpt/tools/libs/gpt_v_generator.py
+++ b/metagpt/tools/libs/gpt_v_generator.py
@@ -5,13 +5,13 @@
@Author : mannaandpoem
@File : gpt_v_generator.py
"""
-import os
+import re
from pathlib import Path
from metagpt.const import DEFAULT_WORKSPACE_ROOT
+from metagpt.logs import logger
from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.tool_type import ToolType
-from metagpt.utils.common import encode_image
+from metagpt.utils.common import CodeParser, encode_image
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX designer, please generate layout information for this image:
@@ -28,11 +28,9 @@ As the design pays tribute to large companies, sometimes it is normal for some c
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
-@register_tool(
- tool_type=ToolType.IMAGE2WEBPAGE.type_name, include_functions=["__init__", "generate_webpages", "save_webpages"]
-)
+@register_tool(tags=["image2webpage"], include_functions=["__init__", "generate_webpages", "save_webpages"])
class GPTvGenerator:
- """Class for generating webpages at once.
+ """Class for generating webpage code from a given webpage screenshot.
This class provides methods to generate webpages including all code (HTML, CSS, and JavaScript) based on an image.
It utilizes a vision model to analyze the layout from an image and generate webpage codes accordingly.
@@ -75,50 +73,34 @@ class GPTvGenerator:
return await self.llm.aask(msg=prompt, images=[encode_image(image_path)])
@staticmethod
- def save_webpages(image_path: str, webpages: str) -> Path:
+ def save_webpages(webpages: str, save_folder_name: str = "example") -> Path:
"""Save webpages including all code (HTML, CSS, and JavaScript) at once.
Args:
- image_path (str): The path of the image file.
webpages (str): The generated webpages content.
+ save_folder_name (str, optional): The name of the folder to save the webpages. Defaults to 'example'.
Returns:
Path: The path of the saved webpages.
"""
# Create a folder called webpages in the workspace directory to store HTML, CSS, and JavaScript files
- webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / Path(image_path).stem
- os.makedirs(webpages_path, exist_ok=True)
+ webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / save_folder_name
+ logger.info(f"code will be saved at {webpages_path}")
+ webpages_path.mkdir(parents=True, exist_ok=True)
index_path = webpages_path / "index.html"
- try:
- index = webpages.split("```html")[1].split("```")[0]
- style_path = None
- if "styles.css" in index:
- style_path = webpages_path / "styles.css"
- elif "style.css" in index:
- style_path = webpages_path / "style.css"
- style = webpages.split("```css")[1].split("```")[0] if style_path else ""
+ index_path.write_text(CodeParser.parse_code(block=None, text=webpages, lang="html"))
- js_path = None
- if "scripts.js" in index:
- js_path = webpages_path / "scripts.js"
- elif "script.js" in index:
- js_path = webpages_path / "script.js"
+ extract_and_save_code(folder=webpages_path, text=webpages, pattern="styles?.css", language="css")
- js = webpages.split("```javascript")[1].split("```")[0] if js_path else ""
- except IndexError:
- raise ValueError(f"No html or css or js code found in the result. \nWebpages: {webpages}")
-
- try:
- with open(index_path, "w", encoding="utf-8") as f:
- f.write(index)
- if style_path:
- with open(style_path, "w", encoding="utf-8") as f:
- f.write(style)
- if js_path:
- with open(js_path, "w", encoding="utf-8") as f:
- f.write(js)
- except FileNotFoundError as e:
- raise FileNotFoundError(f"Cannot save the webpages to {str(webpages_path)}") from e
+ extract_and_save_code(folder=webpages_path, text=webpages, pattern="scripts?.js", language="javascript")
return webpages_path
+
+
+def extract_and_save_code(folder, text, pattern, language):
+ word = re.search(pattern, text)
+ if word:
+ path = folder / word.group(0)
+ code = CodeParser.parse_code(block=None, text=text, lang=language)
+ path.write_text(code, encoding="utf-8")
diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py
index 347f4a430..b62e39db8 100644
--- a/metagpt/tools/libs/sd_engine.py
+++ b/metagpt/tools/libs/sd_engine.py
@@ -14,11 +14,9 @@ import requests
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
-#
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
from metagpt.logs import logger
from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.tool_type import ToolType
payload = {
"prompt": "",
@@ -55,7 +53,7 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
@register_tool(
- tool_type=ToolType.STABLE_DIFFUSION.type_name,
+ tags=["text2image", "multimodal"],
include_functions=["__init__", "simple_run_t2i", "run_t2i", "construct_payload", "save"],
)
class SDEngine:
diff --git a/metagpt/tools/libs/web_scraping.py b/metagpt/tools/libs/web_scraping.py
index d01e69d09..bc34b1306 100644
--- a/metagpt/tools/libs/web_scraping.py
+++ b/metagpt/tools/libs/web_scraping.py
@@ -1,9 +1,8 @@
from metagpt.tools.tool_registry import register_tool
-from metagpt.tools.tool_type import ToolType
from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
-@register_tool(tool_type=ToolType.WEBSCRAPING.type_name)
+@register_tool(tags=["web scraping", "web"])
async def scrape_web_playwright(url):
"""
Asynchronously Scrape and save the HTML structure and inner text content of a web page using Playwright.
diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py
index fc29d0693..42c65b9e7 100644
--- a/metagpt/tools/tool_convert.py
+++ b/metagpt/tools/tool_convert.py
@@ -2,14 +2,18 @@ import inspect
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
+PARSER = GoogleDocstringParser
-def convert_code_to_tool_schema(obj, include: list[str] = []):
+
+def convert_code_to_tool_schema(obj, include: list[str] = None):
docstring = inspect.getdoc(obj)
assert docstring, "no docstring found for the objects, skip registering"
if inspect.isclass(obj):
schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}}
for name, method in inspect.getmembers(obj, inspect.isfunction):
+ if name.startswith("_") and name != "__init__": # skip private methodss
+ continue
if include and name not in include:
continue
# method_doc = inspect.getdoc(method)
@@ -23,54 +27,31 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
return schema
-def function_docstring_to_schema(fn_obj, docstring):
+def function_docstring_to_schema(fn_obj, docstring) -> dict:
+ """
+ Converts a function's docstring into a schema dictionary.
+
+ Args:
+ fn_obj: The function object.
+ docstring: The docstring of the function.
+
+ Returns:
+ A dictionary representing the schema of the function's docstring.
+ The dictionary contains the following keys:
+ - 'type': The type of the function ('function' or 'async_function').
+ - 'description': The first section of the docstring describing the function overall. Provided to LLMs for both recommending and using the function.
+ - 'signature': The signature of the function, which helps LLMs understand how to call the function.
+ - 'parameters': Docstring section describing parameters including args and returns, served as extra details for LLM perception.
+ """
+ signature = inspect.signature(fn_obj)
+
+ docstring = remove_spaces(docstring)
+
+ overall_desc, param_desc = PARSER.parse(docstring)
+
function_type = "function" if not inspect.iscoroutinefunction(fn_obj) else "async_function"
- return {"type": function_type, **docstring_to_schema(docstring)}
-
-def docstring_to_schema(docstring: str):
- if docstring is None:
- return {}
-
- parser = GoogleDocstringParser(docstring=docstring)
-
- # 匹配简介部分
- description = parser.parse_desc()
-
- # 匹配Args部分
- params = parser.parse_params()
- parameter_schema = {"properties": {}, "required": []}
- for param in params:
- param_name, param_type, param_desc = param
- # check required or optional
- is_optional, param_type = parser.check_and_parse_optional(param_type)
- if not is_optional:
- parameter_schema["required"].append(param_name)
- # type and desc
- param_dict = {"type": param_type, "description": remove_spaces(param_desc)}
- # match Default for optional args
- has_default_val, default_val = parser.check_and_parse_default_value(param_desc)
- if has_default_val:
- param_dict["default"] = default_val
- # match Enum
- has_enum, enum_vals = parser.check_and_parse_enum(param_desc)
- if has_enum:
- param_dict["enum"] = enum_vals
- # add to parameter schema
- parameter_schema["properties"].update({param_name: param_dict})
-
- # 匹配Returns部分
- returns = parser.parse_returns()
-
- # 构建YAML字典
- schema = {
- "description": description,
- "parameters": parameter_schema,
- }
- if returns:
- schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns]
-
- return schema
+ return {"type": function_type, "description": overall_desc, "signature": str(signature), "parameters": param_desc}
def get_class_method_docstring(cls, method_name):
diff --git a/metagpt/tools/tool_data_type.py b/metagpt/tools/tool_data_type.py
index 0ae46fa5c..1a31b03e7 100644
--- a/metagpt/tools/tool_data_type.py
+++ b/metagpt/tools/tool_data_type.py
@@ -1,12 +1,6 @@
from pydantic import BaseModel
-class ToolTypeDef(BaseModel):
- name: str
- desc: str = ""
- usage_prompt: str = ""
-
-
class ToolSchema(BaseModel):
description: str
@@ -16,3 +10,4 @@ class Tool(BaseModel):
path: str
schemas: dict = {}
code: str = ""
+ tags: list[str] = []
diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py
new file mode 100644
index 000000000..69b9a4b5d
--- /dev/null
+++ b/metagpt/tools/tool_recommend.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+import json
+from typing import Any
+
+import jieba
+import numpy as np
+from pydantic import BaseModel, field_validator
+from rank_bm25 import BM25Okapi
+
+from metagpt.llm import LLM
+from metagpt.logs import logger
+from metagpt.schema import Plan
+from metagpt.tools import TOOL_REGISTRY
+from metagpt.tools.tool_data_type import Tool
+from metagpt.tools.tool_registry import validate_tool_names
+from metagpt.utils.common import CodeParser
+
+TOOL_INFO_PROMPT = """
+## Capabilities
+- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
+- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
+
+## Available Tools:
+Each tool is described in JSON format. When you call a tool, import the tool from its path first.
+{tool_schemas}
+"""
+
+
+TOOL_RECOMMENDATION_PROMPT = """
+## User Requirement:
+{current_task}
+
+## Task
+Recommend up to {topk} tools from 'Available Tools' that can help solve the 'User Requirement'.
+
+## Available Tools:
+{available_tools}
+
+## Tool Selection and Instructions:
+- Select tools most relevant to completing the 'User Requirement'.
+- If you believe that no tools are suitable, indicate with an empty list.
+- Only list the names of the tools, not the full schema of each tool.
+- Ensure selected tools are listed in 'Available Tools'.
+- Output a json list of tool names:
+```json
+["tool_name1", "tool_name2", ...]
+```
+"""
+
+
+class ToolRecommender(BaseModel):
+ """
+ The default ToolRecommender:
+ 1. Recall: To be implemented in subclasses. Recall tools based on the given context and plan.
+ 2. Rank: Use LLM to select final candidates from recalled set.
+ """
+
+ tools: dict[str, Tool] = {}
+ force: bool = False # whether to forcedly recommend the specified tools
+
+ @field_validator("tools", mode="before")
+ @classmethod
+ def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
+ # One can use special symbol [""] to indicate use of all registered tools
+ if v == [""]:
+ return TOOL_REGISTRY.get_all_tools()
+ else:
+ return validate_tool_names(v)
+
+ async def recommend_tools(
+ self, context: str = "", plan: Plan = None, recall_topk: int = 20, topk: int = 5
+ ) -> list[Tool]:
+ """
+ Recommends a list of tools based on the given context and plan. The recommendation process includes two stages: recall from a large pool and rank the recalled tools to select the final set.
+
+ Args:
+ context (str): The context for tool recommendation.
+ plan (Plan): The plan for tool recommendation.
+ recall_topk (int): The number of tools to recall in the initial step.
+ topk (int): The number of tools to return after rank as final recommendations.
+
+ Returns:
+ list[Tool]: A list of recommended tools.
+ """
+
+ if not self.tools:
+ return []
+
+ if self.force or (not context and not plan):
+ # directly use what users have specified as result for forced recommendation;
+ # directly use the whole set if there is no useful information
+ return list(self.tools.values())
+
+ recalled_tools = await self.recall_tools(context=context, plan=plan, topk=recall_topk)
+ if not recalled_tools:
+ return []
+
+ ranked_tools = await self.rank_tools(recalled_tools=recalled_tools, context=context, plan=plan, topk=topk)
+
+ logger.info(f"Recommended tools: \n{[tool.name for tool in ranked_tools]}")
+
+ return ranked_tools
+
+ async def get_recommended_tool_info(self, **kwargs) -> str:
+ """
+ Wrap recommended tools with their info in a string, which can be used directly in a prompt.
+ """
+ recommended_tools = await self.recommend_tools(**kwargs)
+ if not recommended_tools:
+ return ""
+ tool_schemas = {tool.name: tool.schemas for tool in recommended_tools}
+ return TOOL_INFO_PROMPT.format(tool_schemas=tool_schemas)
+
+ async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
+ """
+ Retrieves a list of relevant tools from a large pool, based on the given context and plan.
+ """
+ raise NotImplementedError
+
+ async def rank_tools(
+ self, recalled_tools: list[Tool], context: str = "", plan: Plan = None, topk: int = 5
+ ) -> list[Tool]:
+ """
+ Default rank methods for a ToolRecommender. Use LLM to rank the recalled tools based on the given context, plan, and topk value.
+ """
+ current_task = plan.current_task.instruction if plan else context
+
+ available_tools = {tool.name: tool.schemas["description"] for tool in recalled_tools}
+ prompt = TOOL_RECOMMENDATION_PROMPT.format(
+ current_task=current_task,
+ available_tools=available_tools,
+ topk=topk,
+ )
+ rsp = await LLM().aask(prompt)
+ rsp = CodeParser.parse_code(block=None, text=rsp)
+ ranked_tools = json.loads(rsp)
+
+ valid_tools = validate_tool_names(ranked_tools)
+
+ return list(valid_tools.values())[:topk]
+
+
+class TypeMatchToolRecommender(ToolRecommender):
+ """
+ A legacy ToolRecommender using task type matching at the recall stage:
+ 1. Recall: Find tools based on exact match between task type and tool tag;
+ 2. Rank: LLM rank, the same as the default ToolRecommender.
+ """
+
+ async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
+ if not plan:
+ return list(self.tools.values())[:topk]
+
+ # find tools based on exact match between task type and tool tag
+ task_type = plan.current_task.task_type
+ candidate_tools = TOOL_REGISTRY.get_tools_by_tag(task_type)
+ candidate_tool_names = set(self.tools.keys()) & candidate_tools.keys()
+ recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names][:topk]
+
+ logger.info(f"Recalled tools: \n{[tool.name for tool in recalled_tools]}")
+
+ return recalled_tools
+
+
+class BM25ToolRecommender(ToolRecommender):
+ """
+ A ToolRecommender using BM25 at the recall stage:
+ 1. Recall: Querying tool descriptions with task instruction if plan exists. Otherwise, return all user-specified tools;
+ 2. Rank: LLM rank, the same as the default ToolRecommender.
+ """
+
+ bm25: Any = None
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._init_corpus()
+
+ def _init_corpus(self):
+ corpus = [f"{tool.name} {tool.tags}: {tool.schemas['description']}" for tool in self.tools.values()]
+ tokenized_corpus = [self._tokenize(doc) for doc in corpus]
+ self.bm25 = BM25Okapi(tokenized_corpus)
+
+ def _tokenize(self, text):
+ return jieba.lcut(text) # FIXME: needs more sophisticated tokenization
+
+ async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
+ query = plan.current_task.instruction if plan else context
+
+ query_tokens = self._tokenize(query)
+ doc_scores = self.bm25.get_scores(query_tokens)
+ top_indexes = np.argsort(doc_scores)[::-1][:topk]
+ recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
+
+ logger.info(
+ f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
+ )
+
+ return recalled_tools
+
+
+class EmbeddingToolRecommender(ToolRecommender):
+ """
+ NOTE: To be implemented.
+ A ToolRecommender using embeddings at the recall stage:
+ 1. Recall: Use embeddings to calculate the similarity between query and tool info;
+ 2. Rank: LLM rank, the same as the default ToolRecommender.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
+ pass
diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py
index 02eecb0f3..11269cb0f 100644
--- a/metagpt/tools/tool_registry.py
+++ b/metagpt/tools/tool_registry.py
@@ -10,26 +10,20 @@ from __future__ import annotations
import inspect
import os
from collections import defaultdict
+from typing import Union
import yaml
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
from metagpt.const import TOOL_SCHEMA_PATH
from metagpt.logs import logger
from metagpt.tools.tool_convert import convert_code_to_tool_schema
-from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolTypeDef
-from metagpt.tools.tool_type import ToolType
+from metagpt.tools.tool_data_type import Tool, ToolSchema
class ToolRegistry(BaseModel):
tools: dict = {}
- tool_types: dict = {}
- tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
-
- @field_validator("tool_types", mode="before")
- @classmethod
- def init_tool_types(cls, tool_types: ToolType):
- return {tool_type.type_name: tool_type.value for tool_type in tool_types}
+ tools_by_tags: dict = defaultdict(dict) # two-layer k-v, {tag: {tool_name: {...}, ...}, ...}
def register_tool(
self,
@@ -37,25 +31,15 @@ class ToolRegistry(BaseModel):
tool_path,
schema_path="",
tool_code="",
- tool_type="other",
+ tags=None,
tool_source_object=None,
- include_functions=[],
+ include_functions=None,
verbose=False,
):
if self.has_tool(tool_name):
return
- if tool_type not in self.tool_types:
- # register new tool type on the fly
- logger.warning(
- f"{tool_type} not previously defined, will create a temporary tool type with just a name. This tool type is only effective during this runtime. You may consider add this tool type with more configs permanently at metagpt.tools.tool_type"
- )
- temp_tool_type_obj = ToolTypeDef(name=tool_type)
- self.tool_types[tool_type] = temp_tool_type_obj
- if verbose:
- logger.info(f"tool type {tool_type} registered")
-
- schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
+ schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml"
schemas = make_schema(tool_source_object, include_functions, schema_path)
@@ -70,10 +54,11 @@ class ToolRegistry(BaseModel):
# logger.warning(
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
# )
-
- tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
+ tags = tags or []
+ tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code, tags=tags)
self.tools[tool_name] = tool
- self.tools_by_types[tool_type][tool_name] = tool
+ for tag in tags:
+ self.tools_by_tags[tag].update({tool_name: tool})
if verbose:
logger.info(f"{tool_name} registered")
logger.info(f"schema made at {str(schema_path)}, can be used for checking")
@@ -84,24 +69,24 @@ class ToolRegistry(BaseModel):
def get_tool(self, key) -> Tool:
return self.tools.get(key)
- def get_tools_by_type(self, key) -> dict[str, Tool]:
- return self.tools_by_types.get(key, {})
+ def get_tools_by_tag(self, key) -> dict[str, Tool]:
+ return self.tools_by_tags.get(key, {})
- def has_tool_type(self, key) -> bool:
- return key in self.tool_types
+ def get_all_tools(self) -> dict[str, Tool]:
+ return self.tools
- def get_tool_type(self, key) -> ToolType:
- return self.tool_types.get(key)
+ def has_tool_tag(self, key) -> bool:
+ return key in self.tools_by_tags
- def get_tool_types(self) -> dict[str, ToolType]:
- return self.tool_types
+ def get_tool_tags(self) -> list[str]:
+ return list(self.tools_by_tags.keys())
# Registry instance
-TOOL_REGISTRY = ToolRegistry(tool_types=ToolType)
+TOOL_REGISTRY = ToolRegistry()
-def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
+def register_tool(tags: list[str] = None, schema_path: str = "", **kwargs):
"""register a tool to registry"""
def decorator(cls):
@@ -117,7 +102,7 @@ def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
tool_path=file_path,
schema_path=schema_path,
tool_code=source_code,
- tool_type=tool_type,
+ tags=tags,
tool_source_object=cls,
**kwargs,
)
@@ -142,14 +127,15 @@ def make_schema(tool_source_object, include, path):
return schema
-def validate_tool_names(tools: list[str], return_tool_object=False) -> list[str]:
- valid_tools = []
- for tool_name in tools:
- if not TOOL_REGISTRY.has_tool(tool_name):
- logger.warning(
- f"Specified tool {tool_name} not found and was skipped. Check if you have registered it properly"
- )
+def validate_tool_names(tools: Union[list[str], str]) -> str:
+ assert isinstance(tools, list), "tools must be a list of str"
+ valid_tools = {}
+ for key in tools:
+ # one can define either tool names or tool type names, take union to get the whole set
+ if TOOL_REGISTRY.has_tool(key):
+ valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
+ elif TOOL_REGISTRY.has_tool_tag(key):
+ valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
else:
- valid_tool = TOOL_REGISTRY.get_tool(tool_name) if return_tool_object else tool_name
- valid_tools.append(valid_tool)
+ logger.warning(f"invalid tool name or tool type name: {key}, skipped")
return valid_tools
diff --git a/metagpt/tools/tool_type.py b/metagpt/tools/tool_type.py
deleted file mode 100644
index e9c9a3482..000000000
--- a/metagpt/tools/tool_type.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from enum import Enum
-
-from metagpt.prompts.tool_types import (
- DATA_PREPROCESS_PROMPT,
- EDA_PROMPT,
- FEATURE_ENGINEERING_PROMPT,
- IMAGE2WEBPAGE_PROMPT,
- MODEL_EVALUATE_PROMPT,
- MODEL_TRAIN_PROMPT,
-)
-from metagpt.tools.tool_data_type import ToolTypeDef
-
-
-class ToolType(Enum):
- EDA = ToolTypeDef(
- name="eda",
- desc="For performing exploratory data analysis",
- usage_prompt=EDA_PROMPT,
- )
- DATA_PREPROCESS = ToolTypeDef(
- name="data_preprocess",
- desc="Only for changing value inplace.",
- usage_prompt=DATA_PREPROCESS_PROMPT,
- )
- EMAIL_LOGIN = ToolTypeDef(
- name="email_login",
- desc="For logging to an email.",
- )
- FEATURE_ENGINEERING = ToolTypeDef(
- name="feature_engineering",
- desc="Only for creating new columns for input data.",
- usage_prompt=FEATURE_ENGINEERING_PROMPT,
- )
- MODEL_TRAIN = ToolTypeDef(
- name="model_train",
- desc="Only for training model.",
- usage_prompt=MODEL_TRAIN_PROMPT,
- )
- MODEL_EVALUATE = ToolTypeDef(
- name="model_evaluate",
- desc="Only for evaluating model.",
- usage_prompt=MODEL_EVALUATE_PROMPT,
- )
- STABLE_DIFFUSION = ToolTypeDef(
- name="stable_diffusion",
- desc="Related to text2image, image2image using stable diffusion model.",
- )
- IMAGE2WEBPAGE = ToolTypeDef(
- name="image2webpage",
- desc="For converting image into webpage code.",
- usage_prompt=IMAGE2WEBPAGE_PROMPT,
- )
- WEBSCRAPING = ToolTypeDef(
- name="web_scraping",
- desc="For scraping data from web pages.",
- )
- OTHER = ToolTypeDef(name="other", desc="Any tools not in the defined categories")
-
- def __missing__(self, key):
- return self.OTHER
-
- @property
- def type_name(self):
- return self.value.name
diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py
index 3b1682291..786715461 100644
--- a/metagpt/tools/web_browser_engine_selenium.py
+++ b/metagpt/tools/web_browser_engine_selenium.py
@@ -103,7 +103,7 @@ class WDMHttpProxyClient(WDMHttpClient):
def get(self, url, **kwargs):
if "proxies" not in kwargs and self.proxy:
- kwargs["proxies"] = {"all_proxy": self.proxy}
+ kwargs["proxies"] = {"all": self.proxy}
return super().get(url, **kwargs)
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index b6296c4fa..e443c3466 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -18,6 +18,7 @@ import csv
import importlib
import inspect
import json
+import mimetypes
import os
import platform
import re
@@ -29,6 +30,7 @@ from typing import Any, Callable, List, Literal, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
+import chardet
import loguru
import requests
from PIL import Image
@@ -361,16 +363,6 @@ def parse_recipient(text):
return ""
-def create_func_call_config(func_schema: dict) -> dict:
- """Create new function call config"""
- tools = [{"type": "function", "function": func_schema}]
- tool_choice = {"type": "function", "function": {"name": func_schema["name"]}}
- return {
- "tools": tools,
- "tool_choice": tool_choice,
- }
-
-
def remove_comments(code_str: str) -> str:
"""Remove comments from code."""
pattern = r"(\".*?\"|\'.*?\')|(\#.*?$)"
@@ -673,14 +665,21 @@ def role_raise_decorator(func):
@handle_exception
-async def aread(filename: str | Path, encoding=None) -> str:
+async def aread(filename: str | Path, encoding="utf-8") -> str:
"""Read file asynchronously."""
- async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
- content = await reader.read()
+ try:
+ async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
+ content = await reader.read()
+ except UnicodeDecodeError:
+ async with aiofiles.open(str(filename), mode="rb") as reader:
+ raw = await reader.read()
+ result = chardet.detect(raw)
+ detected_encoding = result["encoding"]
+ content = raw.decode(detected_encoding)
return content
-async def awrite(filename: str | Path, data: str, encoding=None):
+async def awrite(filename: str | Path, data: str, encoding="utf-8"):
"""Write file asynchronously."""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
@@ -775,7 +774,7 @@ def is_coroutine_func(func: Callable) -> bool:
def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) -> list[str]:
- """load mincraft skill from js files"""
+ """load minecraft skill from js files"""
if not skills_dir:
skills_dir = Path(__file__).parent.absolute()
if skill_names is None:
@@ -821,3 +820,21 @@ See FAQ 5.8
"""
)
raise retry_state.outcome.exception()
+
+
+def get_markdown_codeblock_type(filename: str) -> str:
+ """Return the markdown code-block type corresponding to the file extension."""
+ mime_type, _ = mimetypes.guess_type(filename)
+ mappings = {
+ "text/x-shellscript": "bash",
+ "text/x-c++src": "cpp",
+ "text/css": "css",
+ "text/html": "html",
+ "text/x-java": "java",
+ "application/javascript": "javascript",
+ "application/json": "json",
+ "text/x-python": "python",
+ "text/x-ruby": "ruby",
+ "application/sql": "sql",
+ }
+ return mappings.get(mime_type, "text")
diff --git a/metagpt/utils/dependency_file.py b/metagpt/utils/dependency_file.py
index d3add1171..0a375051c 100644
--- a/metagpt/utils/dependency_file.py
+++ b/metagpt/utils/dependency_file.py
@@ -13,9 +13,7 @@ import re
from pathlib import Path
from typing import Set
-import aiofiles
-
-from metagpt.utils.common import aread
+from metagpt.utils.common import aread, awrite
from metagpt.utils.exceptions import handle_exception
@@ -45,8 +43,7 @@ class DependencyFile:
async def save(self):
"""Save dependencies to the file asynchronously."""
data = json.dumps(self._dependencies)
- async with aiofiles.open(str(self._filename), mode="w") as writer:
- await writer.write(data)
+ await awrite(filename=self._filename, data=data)
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
"""Update dependencies for a file asynchronously.
diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py
index d2a06963a..d19f2b705 100644
--- a/metagpt/utils/file_repository.py
+++ b/metagpt/utils/file_repository.py
@@ -14,11 +14,9 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set
-import aiofiles
-
from metagpt.logs import logger
from metagpt.schema import Document
-from metagpt.utils.common import aread
+from metagpt.utils.common import aread, awrite
from metagpt.utils.json_to_markdown import json_to_markdown
@@ -55,8 +53,7 @@ class FileRepository:
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
content = content if content else "" # avoid `argument must be str, not None` to make it continue
- async with aiofiles.open(str(pathname), mode="w") as writer:
- await writer.write(content)
+ await awrite(filename=str(pathname), data=content)
logger.info(f"save to: {str(pathname)}")
if dependencies is not None:
diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py
index ae3c5118f..e1d140e84 100644
--- a/metagpt/utils/mermaid.py
+++ b/metagpt/utils/mermaid.py
@@ -9,11 +9,9 @@ import asyncio
import os
from pathlib import Path
-import aiofiles
-
from metagpt.config2 import config
from metagpt.logs import logger
-from metagpt.utils.common import check_cmd_exists
+from metagpt.utils.common import awrite, check_cmd_exists
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
@@ -30,9 +28,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
tmp = Path(f"{output_file_without_suffix}.mmd")
- async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
- await f.write(mermaid_code)
- # tmp.write_text(mermaid_code, encoding="utf-8")
+ await awrite(filename=tmp, data=mermaid_code)
if engine == "nodejs":
if check_cmd_exists(config.mermaid.path) != 0:
diff --git a/metagpt/utils/parse_docstring.py b/metagpt/utils/parse_docstring.py
index e91be8e75..63c0e6890 100644
--- a/metagpt/utils/parse_docstring.py
+++ b/metagpt/utils/parse_docstring.py
@@ -1,45 +1,23 @@
import re
from typing import Tuple
-from pydantic import BaseModel
-
def remove_spaces(text):
return re.sub(r"\s+", " ", text).strip()
-class DocstringParser(BaseModel):
- docstring: str
+class DocstringParser:
+ @staticmethod
+ def parse(docstring: str) -> Tuple[str, str]:
+ """Parse the docstring and return the overall description and the parameter description.
- def parse_desc(self) -> str:
- """Parse and return the description from the docstring."""
-
- def parse_params(self) -> list[Tuple[str, str, str]]:
- """Parse and return the parameters from the docstring.
+ Args:
+ docstring (str): The docstring to be parsed.
Returns:
- list[Tuple[str, str, str]]: A list of input paramter info. Each info is a triple of (param name, param type, param description)
+ Tuple[str, str]: A tuple of (overall description, parameter description)
"""
- def parse_returns(self) -> list[Tuple[str, str]]:
- """Parse and return the output information from the docstring.
-
- Returns:
- list[Tuple[str, str]]: A list of output info. Each info is a tuple of (return type, return description)
- """
-
- @staticmethod
- def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
- """Check if a parameter is optional and return a processed param_type rid of the optionality info if so"""
-
- @staticmethod
- def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
- """Check if a parameter has a default value and return the default value if so"""
-
- @staticmethod
- def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
- """Check if a parameter description includes an enum and return enum values if so"""
-
class reSTDocstringParser(DocstringParser):
"""A parser for reStructuredText (reST) docstring"""
@@ -48,40 +26,18 @@ class reSTDocstringParser(DocstringParser):
class GoogleDocstringParser(DocstringParser):
"""A parser for Google-stype docstring"""
- docstring: str
-
- def parse_desc(self) -> str:
- description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", self.docstring, re.DOTALL)
- description = remove_spaces(description_match.group(1)) if description_match else ""
- return description
-
- def parse_params(self) -> list[Tuple[str, str, str]]:
- args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", self.docstring, re.DOTALL)
- _args = args_match.group(1).strip() if args_match else ""
- # variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)")
- variable_pattern = re.compile(
- r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL
- ) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter (indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z).
- params = variable_pattern.findall(_args)
- return params
-
- def parse_returns(self) -> list[Tuple[str, str]]:
- returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", self.docstring, re.DOTALL)
- returns = returns_match.group(1).strip() if returns_match else ""
- return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$")
- returns = return_pattern.findall(returns)
- return returns
-
@staticmethod
- def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
- return "optional" in param_type, param_type.replace(", optional", "")
+ def parse(docstring: str) -> Tuple[str, str]:
+ if not docstring:
+ return "", ""
- @staticmethod
- def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
- default_val = re.search(r"Defaults to (.+?)\.", param_desc)
- return (True, default_val.group(1)) if default_val else (False, "")
+ docstring = remove_spaces(docstring)
- @staticmethod
- def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
- enum_val = re.search(r"Enum: \[(.+?)\]", param_desc)
- return (True, [e.strip() for e in enum_val.group(1).split(",")]) if enum_val else (False, [])
+ if "Args:" in docstring:
+ overall_desc, param_desc = docstring.split("Args:")
+ param_desc = "Args:" + param_desc
+ else:
+ overall_desc = docstring
+ param_desc = ""
+
+ return overall_desc, param_desc
diff --git a/metagpt/utils/recovery_util.py b/metagpt/utils/recovery_util.py
index d0b197e69..2089ae018 100644
--- a/metagpt/utils/recovery_util.py
+++ b/metagpt/utils/recovery_util.py
@@ -54,5 +54,5 @@ def save_history(role: Role, save_dir: str = ""):
with open(save_path / "plan.json", "w", encoding="utf-8") as plan_file:
json.dump(plan, plan_file, indent=4, ensure_ascii=False)
- save_code_file(name=Path(record_time) / "history_nb", code_context=role.execute_code.nb, file_format="ipynb")
+ save_code_file(name=Path(record_time), code_context=role.execute_code.nb, file_format="ipynb")
return save_path
diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py
index b8756e8c6..17e095c5f 100644
--- a/metagpt/utils/repair_llm_raw_output.py
+++ b/metagpt/utils/repair_llm_raw_output.py
@@ -340,7 +340,9 @@ def extract_state_value_from_output(content: str) -> str:
content (str): llm's output from `Role._think`
"""
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
- pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
+ pattern = (
+ r"(? 0 else "-1"
diff --git a/metagpt/utils/repo_to_markdown.py b/metagpt/utils/repo_to_markdown.py
new file mode 100644
index 000000000..76dfe1b82
--- /dev/null
+++ b/metagpt/utils/repo_to_markdown.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+This file provides functionality to convert a local repository into a markdown representation.
+"""
+from __future__ import annotations
+
+import mimetypes
+from pathlib import Path
+
+from gitignore_parser import parse_gitignore
+
+from metagpt.logs import logger
+from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files
+from metagpt.utils.tree import tree
+
+
+async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str:
+ """
+ Convert a local repository into a markdown representation.
+
+ This function takes a path to a local repository and generates a markdown representation of the repository structure,
+ including directory trees and file listings.
+
+ Args:
+ repo_path (str | Path): The path to the local repository.
+ output (str | Path, optional): The path to save the generated markdown file. Defaults to None.
+ gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None.
+
+ Returns:
+ str: The markdown representation of the repository.
+ """
+ repo_path = Path(repo_path)
+ gitignore = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve()
+
+ markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore)
+
+ gitignore_rules = parse_gitignore(full_path=str(gitignore))
+ markdown += await _write_files(repo_path=repo_path, gitignore_rules=gitignore_rules)
+
+ if output:
+ await awrite(filename=str(output), data=markdown, encoding="utf-8")
+ return markdown
+
+
+async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str:
+ try:
+ content = tree(repo_path, gitignore, run_command=True)
+ except Exception as e:
+ logger.info(f"{e}, using safe mode.")
+ content = tree(repo_path, gitignore, run_command=False)
+
+ doc = f"## Directory Tree\n```text\n{content}\n```\n---\n\n"
+ return doc
+
+
+async def _write_files(repo_path, gitignore_rules) -> str:
+ filenames = list_files(repo_path)
+ markdown = ""
+ for filename in filenames:
+ if gitignore_rules(str(filename)):
+ continue
+ markdown += await _write_file(filename=filename, repo_path=repo_path)
+ return markdown
+
+
+async def _write_file(filename: Path, repo_path: Path) -> str:
+ relative_path = filename.relative_to(repo_path)
+ markdown = f"## {relative_path}\n"
+
+ mime_type, _ = mimetypes.guess_type(filename.name)
+ if "text/" not in mime_type:
+ logger.info(f"Ignore content: {filename}")
+ markdown += "\n---\n\n"
+ return markdown
+ content = await aread(filename, encoding="utf-8")
+ content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
+ code_block_type = get_markdown_codeblock_type(filename.name)
+ markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
+ return markdown
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index f7c53706b..0ba2daa89 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -21,6 +21,7 @@ TOKEN_COSTS = {
"gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
"gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
+ "gpt-3.5-turbo-0125": {"prompt": 0.001, "completion": 0.002},
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
"gpt-4": {"prompt": 0.03, "completion": 0.06},
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
@@ -48,6 +49,8 @@ TOKEN_COSTS = {
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
+ "yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
+ "yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
}
@@ -140,25 +143,24 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
+# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
TOKEN_MAX = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0301": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-3.5-turbo-16k-0613": 16384,
- "gpt-35-turbo": 4096,
- "gpt-35-turbo-16k": 16384,
- "gpt-3.5-turbo-1106": 16384,
- "gpt-4-0314": 8192,
- "gpt-4": 8192,
- "gpt-4-32k": 32768,
- "gpt-4-32k-0314": 32768,
- "gpt-4-0613": 8192,
- "gpt-4-turbo-preview": 128000,
"gpt-4-0125-preview": 128000,
+ "gpt-4-turbo-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
+ "gpt-4": 8192,
+ "gpt-4-0613": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4-32k-0613": 32768,
+ "gpt-3.5-turbo-0125": 16385,
+ "gpt-3.5-turbo": 16385,
+ "gpt-3.5-turbo-1106": 16385,
+ "gpt-3.5-turbo-instruct": 4096,
+ "gpt-3.5-turbo-16k": 16385,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k-0613": 16385,
"text-embedding-ada-002": 8192,
"glm-3-turbo": 128000,
"glm-4": 128000,
@@ -176,10 +178,12 @@ TOKEN_MAX = {
"claude-2.1": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-opus-20240229": 200000,
+ "yi-34b-chat-0205": 4000,
+ "yi-34b-chat-200k": 200000,
}
-def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
+def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
@@ -193,6 +197,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-35-turbo-16k",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-1106",
+ "gpt-3.5-turbo-0125",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
@@ -209,8 +214,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" == model:
- print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
- return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
+ print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
+ return count_message_tokens(messages, model="gpt-3.5-turbo-0125")
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
@@ -224,7 +229,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
- f"See https://github.com/openai/openai-python/blob/main/chatml.md "
+ f"See https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken "
f"for information on how messages are converted to tokens."
)
num_tokens = 0
diff --git a/metagpt/utils/tree.py b/metagpt/utils/tree.py
new file mode 100644
index 000000000..bd7922290
--- /dev/null
+++ b/metagpt/utils/tree.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/3/11
+@Author : mashenquan
+@File : tree.py
+@Desc : Implement the same functionality as the `tree` command.
+ Example:
+ >>> print_tree(".")
+ utils
+ +-- serialize.py
+ +-- project_repo.py
+ +-- tree.py
+ +-- mmdc_playwright.py
+ +-- cost_manager.py
+ +-- __pycache__
+ | +-- __init__.cpython-39.pyc
+ | +-- redis.cpython-39.pyc
+ | +-- singleton.cpython-39.pyc
+ | +-- embedding.cpython-39.pyc
+ | +-- make_sk_kernel.cpython-39.pyc
+ | +-- file_repository.cpython-39.pyc
+ +-- file.py
+ +-- save_code.py
+ +-- common.py
+ +-- redis.py
+"""
+from __future__ import annotations
+
+import subprocess
+from pathlib import Path
+from typing import Callable, Dict, List
+
+from gitignore_parser import parse_gitignore
+
+
+def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
+ """
+ Recursively traverses the directory structure and prints it out in a tree-like format.
+
+ Args:
+ root (str or Path): The root directory from which to start traversing.
+ gitignore (str or Path): The filename of gitignore file.
+ run_command (bool): Whether to execute `tree` command. Execute the `tree` command and return the result if True,
+ otherwise execute python code instead.
+
+ Returns:
+ str: A string representation of the directory tree.
+
+ Example:
+ >>> tree(".")
+ utils
+ +-- serialize.py
+ +-- project_repo.py
+ +-- tree.py
+ +-- mmdc_playwright.py
+ +-- __pycache__
+ | +-- __init__.cpython-39.pyc
+ | +-- redis.cpython-39.pyc
+ | +-- singleton.cpython-39.pyc
+ +-- parse_docstring.py
+
+ >>> tree(".", gitignore="../../.gitignore")
+ utils
+ +-- serialize.py
+ +-- project_repo.py
+ +-- tree.py
+ +-- mmdc_playwright.py
+ +-- parse_docstring.py
+
+ >>> tree(".", gitignore="../../.gitignore", run_command=True)
+ utils
+ ├── serialize.py
+ ├── project_repo.py
+ ├── tree.py
+ ├── mmdc_playwright.py
+ └── parse_docstring.py
+
+
+ """
+ root = Path(root).resolve()
+ if run_command:
+ return _execute_tree(root, gitignore)
+
+ git_ignore_rules = parse_gitignore(gitignore) if gitignore else None
+ dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)}
+ v = _print_tree(dir_)
+ return "\n".join(v)
+
+
+def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]:
+ dir_ = {}
+ for i in root.iterdir():
+ if git_ignore_rules and git_ignore_rules(str(i)):
+ continue
+ try:
+ if i.is_file():
+ dir_[i.name] = {}
+ else:
+ dir_[i.name] = _list_children(root=i, git_ignore_rules=git_ignore_rules)
+ except (FileNotFoundError, PermissionError, OSError):
+ dir_[i.name] = {}
+ return dir_
+
+
+def _print_tree(dir_: Dict[str:Dict]) -> List[str]:
+ ret = []
+ for name, children in dir_.items():
+ ret.append(name)
+ if not children:
+ continue
+ lines = _print_tree(children)
+ for j, v in enumerate(lines):
+ if v[0] not in ["+", " ", "|"]:
+ ret = _add_line(ret)
+ row = f"+-- {v}"
+ else:
+ row = f" {v}"
+ ret.append(row)
+ return ret
+
+
+def _add_line(rows: List[str]) -> List[str]:
+ for i in range(len(rows) - 1, -1, -1):
+ v = rows[i]
+ if v[0] != " ":
+ return rows
+ rows[i] = "|" + v[1:]
+ return rows
+
+
+def _execute_tree(root: Path, gitignore: str | Path) -> str:
+ args = ["--gitfile", str(gitignore)] if gitignore else []
+ try:
+ result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
+ if result.returncode != 0:
+ raise ValueError(f"tree exits with code {result.returncode}")
+ return result.stdout
+ except subprocess.CalledProcessError as e:
+ raise e
diff --git a/requirements.txt b/requirements.txt
index 6e84f4612..fef56e810 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,7 +26,7 @@ numpy==1.24.3
openai==1.6.1
openpyxl
beautifulsoup4==4.12.3
-pandas==2.0.3
+pandas==2.1.1
pydantic==2.5.3
#pygame==2.1.3
#pymilvus==2.2.8
@@ -36,7 +36,7 @@ PyYAML==6.0.1
# sentence_transformers==2.2.2
setuptools==65.6.3
tenacity==8.2.3
-tiktoken==0.5.2
+tiktoken==0.6.0
tqdm==4.66.2
#unstructured[local-inference]
# selenium>4
@@ -70,7 +70,7 @@ gitignore-parser==0.1.9
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
websockets~=11.0
networkx~=3.2.1
-google-generativeai==0.3.2
+google-generativeai==0.4.1
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
anytree
ipywidgets==8.1.1
@@ -79,4 +79,4 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
-jieba==0.42.1 # for tool recommendation
+jieba==0.42.1 # for tool recommendation
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 7a14c6182..f834b4c44 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,16 @@ extras_require = {
"search-google": ["google-api-python-client==2.94.0"],
"search-ddg": ["duckduckgo-search~=4.1.1"],
"ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"],
+ "rag": [
+ "llama-index-core==0.10.15",
+ "llama-index-embeddings-azure-openai==0.1.6",
+ "llama-index-embeddings-openai==0.1.5",
+ "llama-index-llms-azure-openai==0.1.4",
+ "llama-index-readers-file==0.1.4",
+ "llama-index-retrievers-bm25==0.1.3",
+ "llama-index-vector-stores-faiss==0.1.1",
+ "chromadb==0.4.23",
+ ],
}
extras_require["test"] = [
@@ -42,7 +52,6 @@ extras_require["test"] = [
"connexion[uvicorn]~=3.0.5",
"azure-cognitiveservices-speech~=1.31.0",
"aioboto3~=11.3.0",
- "chromadb==0.4.23",
"gradio==3.0.0",
"grpcio-status==1.48.2",
"pylint==3.0.3",
@@ -57,7 +66,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr
setup(
name="metagpt",
- version="0.7.4",
+ version="0.7.6",
description="The Multi-Agent Framework",
long_description=long_description,
long_description_content_type="text/markdown",
diff --git a/tests/conftest.py b/tests/conftest.py
index efd782417..8603c752a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -113,12 +113,13 @@ def proxy():
while not reader.at_eof():
writer.write(await reader.read(2048))
writer.close()
+ await writer.wait_closed()
async def handle_client(reader, writer):
data = await reader.readuntil(b"\r\n\r\n")
- print(f"Proxy: {data}") # checking with capfd fixture
infos = pattern.match(data)
host, port = infos.group("host"), infos.group("port")
+ print(f"Proxy: {host}") # checking with capfd fixture
port = int(port) if port else 80
remote_reader, remote_writer = await asyncio.open_connection(host, port)
if data.startswith(b"CONNECT"):
@@ -257,10 +258,10 @@ def http_server():
server = aiohttp.web.Server(handler)
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
- site = aiohttp.web.TCPSite(runner, "localhost", 0)
+ site = aiohttp.web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
- host, port = site._server.sockets[0].getsockname()
- return site, f"http://{host}:{port}"
+ _, port, *_ = site._server.sockets[0].getsockname()
+ return site, f"http://127.0.0.1:{port}"
return start
diff --git a/tests/data/rsp_cache.json b/tests/data/rsp_cache.json
index 8f0567c56..565241779 100644
--- a/tests/data/rsp_cache.json
+++ b/tests/data/rsp_cache.json
@@ -420,5 +420,12 @@
"user: \n## User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n## Context\n\n## Current Plan\n[\n {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Perform exploratory data analysis on the sklearn Wine recognition dataset including summary statistics and a plot.\",\n \"task_type\": \"eda\",\n \"code\": \"import numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn.datasets import load_wine\\n\\n# Load the wine dataset\\nwine = load_wine()\\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\\n columns=wine['feature_names'] + ['target'])\\n\\n# Summary statistics for numerical features\\nnumerical_summary = wine_df.describe()\\n\\n# Summary statistics for categorical features\\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\\n\\n# Correlation matrix for numerical features\\ncorrelation_matrix = wine_df.corr()\\n\\n# Plotting a histogram for each numerical feature\\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\\nplt.tight_layout()\\nplt.show()\\n\\n# Displaying the summary statistics\\nprint(\\\"Numerical Summary:\\\\n\\\", numerical_summary)\\nprint(\\\"\\\\nCategorical Summary:\\\\n\\\", categorical_summary)\\nprint(\\\"\\\\nCorrelation Matrix:\\\\n\\\", correlation_matrix)\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"2\",\n \"dependent_task_ids\": [\n \"1\"\n ],\n \"instruction\": \"Preprocess the dataset by splitting it into training and validation sets with a 80-20 split.\",\n \"task_type\": \"data preprocessing\",\n \"code\": \"from sklearn.model_selection import train_test_split\\n\\n# Split the data into training and validation sets\\nX_train, X_val, y_train, y_val = train_test_split(\\n wine.data, wine.target, test_size=0.2, random_state=42\\n)\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"3\",\n \"dependent_task_ids\": [\n \"2\"\n ],\n \"instruction\": \"Train a model using the training set to predict wine class.\",\n \"task_type\": \"model train\",\n \"code\": \"from sklearn.ensemble import RandomForestClassifier\\nfrom sklearn.metrics import accuracy_score\\n\\n# Initialize the model\\nmodel = RandomForestClassifier(random_state=42)\\n\\n# Train the model\\nmodel.fit(X_train, y_train)\\n\\n# Predict the validation set\\ny_pred = model.predict(X_val)\\n\\n# Calculate the accuracy\\naccuracy = accuracy_score(y_val, y_pred)\\nprint(f\\\"Validation Accuracy: {accuracy:.2f}\\\")\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"4\",\n \"dependent_task_ids\": [\n \"3\"\n ],\n \"instruction\": \"Evaluate the model on the validation set and show the validation accuracy.\",\n \"task_type\": \"model evaluate\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n }\n]\n## Current Task\n{\"task_id\":\"4\",\"dependent_task_ids\":[\"3\"],\"instruction\":\"Evaluate the model on the validation set and show the validation accuracy.\",\"task_type\":\"model evaluate\",\"code\":\"\",\"result\":\"\",\"is_success\":false,\"is_finished\":false}\n\nassistant: # Since the model has already been trained and predictions have been made on the validation set,\n# we only need to evaluate the model using the predictions and the true labels from the validation set.\n\n# The accuracy has already been calculated and printed in the previous task.\n# However, to follow the task guidance and to ensure the output is only related to the current task,\n# we will calculate and print the validation accuracy again.\n\nvalidation_accuracy = accuracy_score(y_val, y_pred)\nprint(f\"Validation Accuracy: {validation_accuracy:.2f}\")\n\nuser: a successful runThis is a review. Please review output from metagpt.actions.di.execute_nb_code.ExecuteNbCode\nIf you want to change, add, delete a task or merge tasks in the plan, say 'change task task_id or current task, ... (things to change)' If you confirm the output from the current task and wish to continue, type: confirmIf you think user requirement has been fulfilled completedly, you can finish the process by typing: finish\nIf you want to terminate the process, type: exit\nPlease type your review below:\n": "confirm",
"As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function.#SYSTEM_MSG_END#\n# User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n\n# Plan Status\n\n\n# Tool Info\n\n\n# Constraints\n- Take on Current Task if it is in Plan Status, otherwise, tackle User Requirement directly.\n- Ensure the output new code is executable in the same Jupyter notebook as the previous executed code.\n- Always prioritize using pre-defined tools for the same functionality.\n\n# Output\nWhile some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:\n```python\nyour code\n```\n#MSG_SEP#Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.": "To achieve the user requirement, we will follow these steps:\n\n1. Load the Wine recognition dataset from `sklearn.datasets`.\n2. Perform a basic data analysis to understand the dataset.\n3. Visualize the data to gain insights.\n4. Split the dataset into training and validation sets.\n5. Train a classification model using the training set.\n6. Evaluate the model on the validation set and report the accuracy.\n\nLet's start by loading the dataset and performing the initial data analysis:\n\n```python\nimport pandas as pd\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_wine\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import accuracy_score\n\n# Load the wine dataset\nwine = load_wine()\nX = wine.data\ny = wine.target\nfeature_names = wine.feature_names\nclass_names = wine.target_names\n\n# Convert to a DataFrame for better visualization and analysis\nwine_df = pd.DataFrame(X, columns=feature_names)\nwine_df['target'] = y\n\n# Basic data analysis\nprint(wine_df.describe())\n\n# Plotting the data\n# Let's plot the distribution of alcohol content for each class\nfor class_idx, class_name in enumerate(class_names):\n plt.hist(wine_df[wine_df['target'] == class_idx]['alcohol'], label=class_name, alpha=0.5)\nplt.title('Alcohol content distribution by wine class')\nplt.xlabel('Alcohol content')\nplt.ylabel('Frequency')\nplt.legend()\nplt.show()\n\n# Split the dataset into training and validation sets (80% train, 20% validation)\nX_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n\n# Train a RandomForestClassifier\nmodel = RandomForestClassifier(random_state=42)\nmodel.fit(X_train, y_train)\n\n# Predict on the validation set\ny_pred = model.predict(X_val)\n\n# Calculate the accuracy on the validation set\naccuracy = accuracy_score(y_val, y_pred)\nprint(f'Validation accuracy: {accuracy:.2f}')\n```\n\nThis code will load the dataset, perform a basic analysis, plot the distribution of alcohol content for each wine class, split the data, train a RandomForestClassifier, and finally, evaluate its performance on the validation set.",
"\n# User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n# Context\n[user: Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy., assistant: import pandas as pd\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_wine\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import accuracy_score\n\n# Load the wine dataset\nwine = load_wine()\nX = wine.data\ny = wine.target\nfeature_names = wine.feature_names\nclass_names = wine.target_names\n\n# Convert to a DataFrame for better visualization and analysis\nwine_df = pd.DataFrame(X, columns=feature_names)\nwine_df['target'] = y\n\n# Basic data analysis\nprint(wine_df.describe())\n\n# Plotting the data\n# Let's plot the distribution of alcohol content for each class\nfor class_idx, class_name in enumerate(class_names):\n plt.hist(wine_df[wine_df['target'] == class_idx]['alcohol'], label=class_name, alpha=0.5)\nplt.title('Alcohol content distribution by wine class')\nplt.xlabel('Alcohol content')\nplt.ylabel('Frequency')\nplt.legend()\nplt.show()\n\n# Split the dataset into training and validation sets (80% train, 20% validation)\nX_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n\n# Train a RandomForestClassifier\nmodel = RandomForestClassifier(random_state=42)\nmodel.fit(X_train, y_train)\n\n# Predict on the validation set\ny_pred = model.predict(X_val)\n\n# Calculate the accuracy on the validation set\naccuracy = accuracy_score(y_val, y_pred)\nprint(f'Validation accuracy: {accuracy:.2f}')\n, user: a successful run]\n\nOutput a json following the format:\n```json\n{\n \"thoughts\": str = \"Thoughts on current situation, reflect on how you should proceed to fulfill the user requirement\",\n \"state\": bool = \"Decide whether you need to take more actions to complete the user requirement. Return true if you think so. Return false if you think the requirement has been completely fulfilled.\"\n}\n```\n": "```json\n{\n \"thoughts\": \"The user has requested a data analysis on the sklearn Wine recognition dataset, including a plot and training a model to predict wine class with validation accuracy. The provided code snippet appears to fulfill the user's requirements. It loads the dataset, performs basic data analysis, creates a plot of alcohol content distribution by wine class, splits the data into training and validation sets, trains a RandomForestClassifier, predicts on the validation set, and prints the validation accuracy. Assuming the code runs successfully without errors, the user's requirements should be met.\",\n \"state\": false\n}\n```",
- "\n## User Requirement:\nconduct feature engineering, add new features on the dataset\n\n## Task\nRecommend up to 5 tools from 'Available Tools' that can help solve the 'User Requirement'. \n\n## Available Tools:\n{'PolynomialExpansion': 'Add polynomial and interaction features from selected numeric columns to input DataFrame.', 'scrape_web_playwright': 'Asynchronously Scrape and save the HTML structure and inner text content of a web page using Playwright. ', 'FillMissingValue': 'Completing missing values with simple strategies.'}\n\n## Tool Selection and Instructions:\n- Select tools most relevant to completing the 'User Requirement'.\n- If you believe that no tools are suitable, indicate with an empty list.\n- Only list the names of the tools, not the full schema of each tool.\n- Ensure selected tools are listed in 'Available Tools'.\n- Output a json list of tool names:\n```json\n[\"tool_name1\", \"tool_name2\", ...]\n```\n": "```json\n[\"PolynomialExpansion\", \"FillMissingValue\"]\n```"
+ "\n## User Requirement:\nconduct feature engineering, add new features on the dataset\n\n## Task\nRecommend up to 5 tools from 'Available Tools' that can help solve the 'User Requirement'. \n\n## Available Tools:\n{'PolynomialExpansion': 'Add polynomial and interaction features from selected numeric columns to input DataFrame.', 'scrape_web_playwright': 'Asynchronously Scrape and save the HTML structure and inner text content of a web page using Playwright. ', 'FillMissingValue': 'Completing missing values with simple strategies.'}\n\n## Tool Selection and Instructions:\n- Select tools most relevant to completing the 'User Requirement'.\n- If you believe that no tools are suitable, indicate with an empty list.\n- Only list the names of the tools, not the full schema of each tool.\n- Ensure selected tools are listed in 'Available Tools'.\n- Output a json list of tool names:\n```json\n[\"tool_name1\", \"tool_name2\", ...]\n```\n": "```json\n[\"PolynomialExpansion\", \"FillMissingValue\"]\n```",
+ "As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function.#SYSTEM_MSG_END#\n# User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n\n# Plan Status\n\n## Finished Tasks\n### code\n```python\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_wine\nwine = load_wine()\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\n columns=wine['feature_names'] + ['target'])\nnumerical_summary = wine_df.describe()\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\ncorrelation_matrix = wine_df.corr()\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\nplt.tight_layout()\nplt.show()\nprint(\"Numerical Summary:\\n\", numerical_summary)\nprint(\"\\nCategorical Summary:\\n\", categorical_summary)\nprint(\"\\nCorrelation Matrix:\\n\", correlation_matrix)\n```\n\n### execution result\na successful run\n\n## Current Task\nPreprocess the dataset by splitting it into training and validation sets with a 80-20 split.\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about data preprocessing, please note the following:\n- Monitor data types per column, applying appropriate methods.\n- Ensure operations are on existing dataset columns.\n- Avoid writing processed data to files.\n- Avoid any change to label column, such as standardization, etc.\n- Prefer alternatives to one-hot encoding for categorical data.\n- Only encode or scale necessary columns to allow for potential feature-specific engineering tasks (like time_extract, binning, extraction, etc.) later.\n- Each step do data preprocessing to train, must do same for test separately at the same time.\n- Always copy the DataFrame before processing it and use the copy to process.\n\n\n\n# Tool Info\n\n\n# Constraints\n- Take on Current Task if it is in Plan Status, otherwise, tackle User Requirement directly.\n- Ensure the output new code is executable in the same Jupyter notebook as the previous executed code.\n- Always prioritize using pre-defined tools for the same functionality.\n\n# Output\nWhile some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:\n```python\nyour code\n```\n#MSG_SEP#\n# Latest Data Info\nLatest data info after previous tasks:\na successful run\n": "```python\nfrom sklearn.model_selection import train_test_split\n\n# Split the data into training and validation sets (80-20 split)\nX_train, X_val, y_train, y_val = train_test_split(\n wine_df.drop('target', axis=1), # features\n wine_df['target'], # target variable\n test_size=0.2, # 20% for validation\n random_state=42 # seed for reproducibility\n)\n\n# Output the shapes of the resulting data splits\nprint(f\"Training set shape: {X_train.shape}\")\nprint(f\"Validation set shape: {X_val.shape}\")\nprint(f\"Training target shape: {y_train.shape}\")\nprint(f\"Validation target shape: {y_val.shape}\")\n```",
+ "\n# Background\nCheck latest data info to guide subsequent tasks.\n\n## Finished Tasks\n```python\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_wine\nwine = load_wine()\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\n columns=wine['feature_names'] + ['target'])\nnumerical_summary = wine_df.describe()\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\ncorrelation_matrix = wine_df.corr()\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\nplt.tight_layout()\nplt.show()\nprint(\"Numerical Summary:\\n\", numerical_summary)\nprint(\"\\nCategorical Summary:\\n\", categorical_summary)\nprint(\"\\nCorrelation Matrix:\\n\", correlation_matrix)\n\nfrom sklearn.model_selection import train_test_split\nX_train, X_val, y_train, y_val = train_test_split(\n wine_df.drop('target', axis=1),\n wine_df['target'],\n test_size=0.2,\n random_state=42\n)\nprint(f\"Training set shape: {X_train.shape}\")\nprint(f\"Validation set shape: {X_val.shape}\")\nprint(f\"Training target shape: {y_train.shape}\")\nprint(f\"Validation target shape: {y_val.shape}\")\n```end\n\n# Task\nCheck code in finished tasks, print key variables to guide your following actions.\nSpecifically, if it is a data analysis or machine learning task, print the the latest column information using the following code, with DataFrame variable from 'Finished Tasks' in place of df:\n```python\nfrom metagpt.tools.libs.data_preprocess import get_column_info\n\ncolumn_info = get_column_info(df)\nprint(\"column_info\")\nprint(column_info)\n```end\nOtherwise, print out any key variables you see fit. Return an empty string if you think there is no important data to check.\n\n# Constraints:\n- Your code is to be added to a new cell in jupyter.\n\n# Instruction\nOutput code following the format:\n```python\nyour code\n```\n": "```python\nfrom metagpt.tools.libs.data_preprocess import get_column_info\n\ncolumn_info = get_column_info(wine_df)\nprint(\"column_info\")\nprint(column_info)\n```",
+ "As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function.#SYSTEM_MSG_END#\n# User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n\n# Plan Status\n\n## Finished Tasks\n### code\n```python\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_wine\nwine = load_wine()\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\n columns=wine['feature_names'] + ['target'])\nnumerical_summary = wine_df.describe()\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\ncorrelation_matrix = wine_df.corr()\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\nplt.tight_layout()\nplt.show()\nprint(\"Numerical Summary:\\n\", numerical_summary)\nprint(\"\\nCategorical Summary:\\n\", categorical_summary)\nprint(\"\\nCorrelation Matrix:\\n\", correlation_matrix)\n\nfrom sklearn.model_selection import train_test_split\nX_train, X_val, y_train, y_val = train_test_split(\n wine_df.drop('target', axis=1),\n wine_df['target'],\n test_size=0.2,\n random_state=42\n)\nprint(f\"Training set shape: {X_train.shape}\")\nprint(f\"Validation set shape: {X_val.shape}\")\nprint(f\"Training target shape: {y_train.shape}\")\nprint(f\"Validation target shape: {y_val.shape}\")\n```\n\n### execution result\na successful run\n\na successful run\n\n## Current Task\nTrain a model using the training set to predict wine class.\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about training a model, please ensure high performance:\n- Keep in mind that your user prioritizes results and is highly focused on model performance. So, when needed, feel free to use models of any complexity to improve effectiveness, such as XGBoost, CatBoost, etc.\n- If non-numeric columns exist, perform label encode together with all steps.\n- Use the data from previous task result directly, do not mock or reload data yourself.\n- Set suitable hyperparameters for the model, make metrics as high as possible.\n\n\n\n# Tool Info\n\n\n# Constraints\n- Take on Current Task if it is in Plan Status, otherwise, tackle User Requirement directly.\n- Ensure the output new code is executable in the same Jupyter notebook as the previous executed code.\n- Always prioritize using pre-defined tools for the same functionality.\n\n# Output\nWhile some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:\n```python\nyour code\n```\n#MSG_SEP#\n# Latest Data Info\nLatest data info after previous tasks:\na successful run\n": "```python\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import accuracy_score\n\n# Initialize the RandomForestClassifier\nmodel = RandomForestClassifier(n_estimators=100, random_state=42)\n\n# Train the model on the training data\nmodel.fit(X_train, y_train)\n\n# Predict the wine classes on the validation set\ny_val_pred = model.predict(X_val)\n\n# Calculate the accuracy on the validation set\nval_accuracy = accuracy_score(y_val, y_val_pred)\n\nprint(f\"Validation Accuracy: {val_accuracy:.4f}\")\n```",
+ "As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function.#SYSTEM_MSG_END#\n# User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n\n# Plan Status\n\n## Finished Tasks\n### code\n```python\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import load_wine\nwine = load_wine()\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\n columns=wine['feature_names'] + ['target'])\nnumerical_summary = wine_df.describe()\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\ncorrelation_matrix = wine_df.corr()\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\nplt.tight_layout()\nplt.show()\nprint(\"Numerical Summary:\\n\", numerical_summary)\nprint(\"\\nCategorical Summary:\\n\", categorical_summary)\nprint(\"\\nCorrelation Matrix:\\n\", correlation_matrix)\n\nfrom sklearn.model_selection import train_test_split\nX_train, X_val, y_train, y_val = train_test_split(\n wine_df.drop('target', axis=1),\n wine_df['target'],\n test_size=0.2,\n random_state=42\n)\nprint(f\"Training set shape: {X_train.shape}\")\nprint(f\"Validation set shape: {X_val.shape}\")\nprint(f\"Training target shape: {y_train.shape}\")\nprint(f\"Validation target shape: {y_val.shape}\")\n\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import accuracy_score\nmodel = RandomForestClassifier(n_estimators=100, random_state=42)\nmodel.fit(X_train, y_train)\ny_val_pred = model.predict(X_val)\nval_accuracy = accuracy_score(y_val, y_val_pred)\nprint(f\"Validation Accuracy: {val_accuracy:.4f}\")\n```\n\n### execution result\na successful run\n\na successful run\n\na successful run\n\n## Current Task\nEvaluate the model on the validation set and show the validation accuracy.\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about evaluating a model, please note the following:\n- Ensure that the evaluated data is same processed as the training data. If not, remember use object in 'Done Tasks' to transform the data.\n- Use trained model from previous task result directly, do not mock or reload model yourself.\n\n\n\n# Tool Info\n\n\n# Constraints\n- Take on Current Task if it is in Plan Status, otherwise, tackle User Requirement directly.\n- Ensure the output new code is executable in the same Jupyter notebook as the previous executed code.\n- Always prioritize using pre-defined tools for the same functionality.\n\n# Output\nWhile some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:\n```python\nyour code\n```\n": "```python\n# Evaluate the model on the validation set and show the validation accuracy\nval_accuracy = accuracy_score(y_val, y_val_pred)\nprint(f\"Validation Accuracy: {val_accuracy:.4f}\")\n```",
+ "user: \n## User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n## Context\n\n## Current Plan\n[\n {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Perform exploratory data analysis on the sklearn Wine recognition dataset including summary statistics and a plot.\",\n \"task_type\": \"eda\",\n \"code\": \"import numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn.datasets import load_wine\\n\\n# Load the wine dataset\\nwine = load_wine()\\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\\n columns=wine['feature_names'] + ['target'])\\n\\n# Summary statistics for numerical features\\nnumerical_summary = wine_df.describe()\\n\\n# Summary statistics for categorical features\\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\\n\\n# Correlation matrix for numerical features\\ncorrelation_matrix = wine_df.corr()\\n\\n# Plotting a histogram for each numerical feature\\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\\nplt.tight_layout()\\nplt.show()\\n\\n# Displaying the summary statistics\\nprint(\\\"Numerical Summary:\\\\n\\\", numerical_summary)\\nprint(\\\"\\\\nCategorical Summary:\\\\n\\\", categorical_summary)\\nprint(\\\"\\\\nCorrelation Matrix:\\\\n\\\", correlation_matrix)\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"2\",\n \"dependent_task_ids\": [\n \"1\"\n ],\n \"instruction\": \"Preprocess the dataset by splitting it into training and validation sets with a 80-20 split.\",\n \"task_type\": \"data preprocessing\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n },\n {\n \"task_id\": \"3\",\n \"dependent_task_ids\": [\n \"2\"\n ],\n \"instruction\": \"Train a model using the training set to predict wine class.\",\n \"task_type\": \"model train\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n },\n {\n \"task_id\": \"4\",\n \"dependent_task_ids\": [\n \"3\"\n ],\n \"instruction\": \"Evaluate the model on the validation set and show the validation accuracy.\",\n \"task_type\": \"model evaluate\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n }\n]\n## Current Task\n{\"task_id\":\"2\",\"dependent_task_ids\":[\"1\"],\"instruction\":\"Preprocess the dataset by splitting it into training and validation sets with a 80-20 split.\",\"task_type\":\"data preprocessing\",\"code\":\"\",\"result\":\"\",\"is_success\":false,\"is_finished\":false}\n\nuser: \n# Latest Data Info\nLatest data info after previous tasks:\na successful run\n\nassistant: from sklearn.model_selection import train_test_split\n\n# Split the data into training and validation sets (80-20 split)\nX_train, X_val, y_train, y_val = train_test_split(\n wine_df.drop('target', axis=1), # features\n wine_df['target'], # target variable\n test_size=0.2, # 20% for validation\n random_state=42 # seed for reproducibility\n)\n\n# Output the shapes of the resulting data splits\nprint(f\"Training set shape: {X_train.shape}\")\nprint(f\"Validation set shape: {X_val.shape}\")\nprint(f\"Training target shape: {y_train.shape}\")\nprint(f\"Validation target shape: {y_val.shape}\")\n\nuser: a successful runThis is a review. Please review output from metagpt.actions.di.execute_nb_code.ExecuteNbCode\nIf you want to change, add, delete a task or merge tasks in the plan, say 'change task task_id or current task, ... (things to change)' If you confirm the output from the current task and wish to continue, type: confirmIf you think user requirement has been fulfilled completedly, you can finish the process by typing: finish\nIf you want to terminate the process, type: exit\nPlease type your review below:\n": "confirm",
+ "user: \n## User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n## Context\n\n## Current Plan\n[\n {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Perform exploratory data analysis on the sklearn Wine recognition dataset including summary statistics and a plot.\",\n \"task_type\": \"eda\",\n \"code\": \"import numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn.datasets import load_wine\\n\\n# Load the wine dataset\\nwine = load_wine()\\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\\n columns=wine['feature_names'] + ['target'])\\n\\n# Summary statistics for numerical features\\nnumerical_summary = wine_df.describe()\\n\\n# Summary statistics for categorical features\\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\\n\\n# Correlation matrix for numerical features\\ncorrelation_matrix = wine_df.corr()\\n\\n# Plotting a histogram for each numerical feature\\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\\nplt.tight_layout()\\nplt.show()\\n\\n# Displaying the summary statistics\\nprint(\\\"Numerical Summary:\\\\n\\\", numerical_summary)\\nprint(\\\"\\\\nCategorical Summary:\\\\n\\\", categorical_summary)\\nprint(\\\"\\\\nCorrelation Matrix:\\\\n\\\", correlation_matrix)\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"2\",\n \"dependent_task_ids\": [\n \"1\"\n ],\n \"instruction\": \"Preprocess the dataset by splitting it into training and validation sets with a 80-20 split.\",\n \"task_type\": \"data preprocessing\",\n \"code\": \"from sklearn.model_selection import train_test_split\\n\\n# Split the data into training and validation sets (80-20 split)\\nX_train, X_val, y_train, y_val = train_test_split(\\n wine_df.drop('target', axis=1), # features\\n wine_df['target'], # target variable\\n test_size=0.2, # 20% for validation\\n random_state=42 # seed for reproducibility\\n)\\n\\n# Output the shapes of the resulting data splits\\nprint(f\\\"Training set shape: {X_train.shape}\\\")\\nprint(f\\\"Validation set shape: {X_val.shape}\\\")\\nprint(f\\\"Training target shape: {y_train.shape}\\\")\\nprint(f\\\"Validation target shape: {y_val.shape}\\\")\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"3\",\n \"dependent_task_ids\": [\n \"2\"\n ],\n \"instruction\": \"Train a model using the training set to predict wine class.\",\n \"task_type\": \"model train\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n },\n {\n \"task_id\": \"4\",\n \"dependent_task_ids\": [\n \"3\"\n ],\n \"instruction\": \"Evaluate the model on the validation set and show the validation accuracy.\",\n \"task_type\": \"model evaluate\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n }\n]\n## Current Task\n{\"task_id\":\"3\",\"dependent_task_ids\":[\"2\"],\"instruction\":\"Train a model using the training set to predict wine class.\",\"task_type\":\"model train\",\"code\":\"\",\"result\":\"\",\"is_success\":false,\"is_finished\":false}\n\nuser: \n# Latest Data Info\nLatest data info after previous tasks:\na successful run\n\nassistant: from sklearn.ensemble import RandomForestClassifier\nfrom sklearn.metrics import accuracy_score\n\n# Initialize the RandomForestClassifier\nmodel = RandomForestClassifier(n_estimators=100, random_state=42)\n\n# Train the model on the training data\nmodel.fit(X_train, y_train)\n\n# Predict the wine classes on the validation set\ny_val_pred = model.predict(X_val)\n\n# Calculate the accuracy on the validation set\nval_accuracy = accuracy_score(y_val, y_val_pred)\n\nprint(f\"Validation Accuracy: {val_accuracy:.4f}\")\n\nuser: a successful runThis is a review. Please review output from metagpt.actions.di.execute_nb_code.ExecuteNbCode\nIf you want to change, add, delete a task or merge tasks in the plan, say 'change task task_id or current task, ... (things to change)' If you confirm the output from the current task and wish to continue, type: confirmIf you think user requirement has been fulfilled completedly, you can finish the process by typing: finish\nIf you want to terminate the process, type: exit\nPlease type your review below:\n": "confirm",
+ "user: \n## User Requirement\nRun data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy.\n## Context\n\n## Current Plan\n[\n {\n \"task_id\": \"1\",\n \"dependent_task_ids\": [],\n \"instruction\": \"Perform exploratory data analysis on the sklearn Wine recognition dataset including summary statistics and a plot.\",\n \"task_type\": \"eda\",\n \"code\": \"import numpy as np\\nimport pandas as pd\\nimport matplotlib.pyplot as plt\\nfrom sklearn.datasets import load_wine\\n\\n# Load the wine dataset\\nwine = load_wine()\\nwine_df = pd.DataFrame(data=np.c_[wine['data'], wine['target']],\\n columns=wine['feature_names'] + ['target'])\\n\\n# Summary statistics for numerical features\\nnumerical_summary = wine_df.describe()\\n\\n# Summary statistics for categorical features\\ncategorical_summary = wine_df.select_dtypes(include=['object', 'category']).describe()\\n\\n# Correlation matrix for numerical features\\ncorrelation_matrix = wine_df.corr()\\n\\n# Plotting a histogram for each numerical feature\\nwine_df.hist(bins=15, figsize=(15, 10), layout=(4, 4))\\nplt.tight_layout()\\nplt.show()\\n\\n# Displaying the summary statistics\\nprint(\\\"Numerical Summary:\\\\n\\\", numerical_summary)\\nprint(\\\"\\\\nCategorical Summary:\\\\n\\\", categorical_summary)\\nprint(\\\"\\\\nCorrelation Matrix:\\\\n\\\", correlation_matrix)\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"2\",\n \"dependent_task_ids\": [\n \"1\"\n ],\n \"instruction\": \"Preprocess the dataset by splitting it into training and validation sets with a 80-20 split.\",\n \"task_type\": \"data preprocessing\",\n \"code\": \"from sklearn.model_selection import train_test_split\\n\\n# Split the data into training and validation sets (80-20 split)\\nX_train, X_val, y_train, y_val = train_test_split(\\n wine_df.drop('target', axis=1), # features\\n wine_df['target'], # target variable\\n test_size=0.2, # 20% for validation\\n random_state=42 # seed for reproducibility\\n)\\n\\n# Output the shapes of the resulting data splits\\nprint(f\\\"Training set shape: {X_train.shape}\\\")\\nprint(f\\\"Validation set shape: {X_val.shape}\\\")\\nprint(f\\\"Training target shape: {y_train.shape}\\\")\\nprint(f\\\"Validation target shape: {y_val.shape}\\\")\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"3\",\n \"dependent_task_ids\": [\n \"2\"\n ],\n \"instruction\": \"Train a model using the training set to predict wine class.\",\n \"task_type\": \"model train\",\n \"code\": \"from sklearn.ensemble import RandomForestClassifier\\nfrom sklearn.metrics import accuracy_score\\n\\n# Initialize the RandomForestClassifier\\nmodel = RandomForestClassifier(n_estimators=100, random_state=42)\\n\\n# Train the model on the training data\\nmodel.fit(X_train, y_train)\\n\\n# Predict the wine classes on the validation set\\ny_val_pred = model.predict(X_val)\\n\\n# Calculate the accuracy on the validation set\\nval_accuracy = accuracy_score(y_val, y_val_pred)\\n\\nprint(f\\\"Validation Accuracy: {val_accuracy:.4f}\\\")\\n\",\n \"result\": \"a successful run\",\n \"is_success\": true,\n \"is_finished\": true\n },\n {\n \"task_id\": \"4\",\n \"dependent_task_ids\": [\n \"3\"\n ],\n \"instruction\": \"Evaluate the model on the validation set and show the validation accuracy.\",\n \"task_type\": \"model evaluate\",\n \"code\": \"\",\n \"result\": \"\",\n \"is_success\": false,\n \"is_finished\": false\n }\n]\n## Current Task\n{\"task_id\":\"4\",\"dependent_task_ids\":[\"3\"],\"instruction\":\"Evaluate the model on the validation set and show the validation accuracy.\",\"task_type\":\"model evaluate\",\"code\":\"\",\"result\":\"\",\"is_success\":false,\"is_finished\":false}\n\nassistant: # Evaluate the model on the validation set and show the validation accuracy\nval_accuracy = accuracy_score(y_val, y_val_pred)\nprint(f\"Validation Accuracy: {val_accuracy:.4f}\")\n\nuser: a successful runThis is a review. Please review output from metagpt.actions.di.execute_nb_code.ExecuteNbCode\nIf you want to change, add, delete a task or merge tasks in the plan, say 'change task task_id or current task, ... (things to change)' If you confirm the output from the current task and wish to continue, type: confirmIf you think user requirement has been fulfilled completedly, you can finish the process by typing: finish\nIf you want to terminate the process, type: exit\nPlease type your review below:\n": "confirm"
}
\ No newline at end of file
diff --git a/tests/metagpt/actions/di/test_debug_code.py b/tests/metagpt/actions/di/test_debug_code.py
deleted file mode 100644
index 67f72ad63..000000000
--- a/tests/metagpt/actions/di/test_debug_code.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Date : 1/11/2024 8:51 PM
-# @Author : stellahong (stellahong@fuzhi.ai)
-# @Desc :
-
-import pytest
-
-from metagpt.actions.di.debug_code import DebugCode
-from metagpt.schema import Message
-
-ErrorStr = """Tested passed:
-
-Tests failed:
-assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
-"""
-
-CODE = """
-def sort_array(arr):
- # Helper function to count the number of ones in the binary representation
- def count_ones(n):
- return bin(n).count('1')
-
- # Sort the array using a custom key function
- # The key function returns a tuple (number of ones, value) for each element
- # This ensures that if two elements have the same number of ones, they are sorted by their value
- sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
-
- return sorted_arr
-```
-"""
-
-DebugContext = '''Solve the problem in Python:
-def sort_array(arr):
- """
- In this Kata, you have to sort an array of non-negative integers according to
- number of ones in their binary representation in ascending order.
- For similar number of ones, sort based on decimal value.
-
- It must be implemented like this:
- >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
- >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
- >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
- """
-'''
-
-
-@pytest.mark.asyncio
-async def test_debug_code():
- debug_context = Message(content=DebugContext)
- new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
- assert "def sort_array(arr)" in new_code["code"]
diff --git a/tests/metagpt/actions/di/test_execute_nb_code.py b/tests/metagpt/actions/di/test_execute_nb_code.py
index b491dd212..b206046d7 100644
--- a/tests/metagpt/actions/di/test_execute_nb_code.py
+++ b/tests/metagpt/actions/di/test_execute_nb_code.py
@@ -1,6 +1,6 @@
import pytest
-from metagpt.actions.di.execute_nb_code import ExecuteNbCode, truncate
+from metagpt.actions.di.execute_nb_code import ExecuteNbCode
@pytest.mark.asyncio
@@ -8,6 +8,7 @@ async def test_code_running():
executor = ExecuteNbCode()
output, is_success = await executor.run("print('hello world!')")
assert is_success
+ await executor.terminate()
@pytest.mark.asyncio
@@ -17,6 +18,7 @@ async def test_split_code_running():
_ = await executor.run("z=x+y")
output, is_success = await executor.run("assert z==3")
assert is_success
+ await executor.terminate()
@pytest.mark.asyncio
@@ -24,6 +26,7 @@ async def test_execute_error():
executor = ExecuteNbCode()
output, is_success = await executor.run("z=1/0")
assert not is_success
+ await executor.terminate()
PLOT_CODE = """
@@ -52,21 +55,7 @@ async def test_plotting_code():
executor = ExecuteNbCode()
output, is_success = await executor.run(PLOT_CODE)
assert is_success
-
-
-def test_truncate():
- # 代码执行成功
- output, is_success = truncate("hello world", 5, True)
- assert "Truncated to show only first 5 characters\nhello" in output
- assert is_success
- # 代码执行失败
- output, is_success = truncate("hello world", 5, False)
- assert "Truncated to show only last 5 characters\nworld" in output
- assert not is_success
- # 异步
- output, is_success = truncate("1 to test terminate under continuous run
+async def test_terminate(k):
+ for _ in range(k):
+ executor = ExecuteNbCode()
+ await executor.run(code='print("This is a code!")', language="python")
+ is_kernel_alive = await executor.nb_client.km.is_alive()
+ assert is_kernel_alive
+ await executor.terminate()
+ assert executor.nb_client.km is None
+ assert executor.nb_client.kc is None
@pytest.mark.asyncio
@@ -114,3 +107,22 @@ async def test_reset():
assert is_kernel_alive
await executor.reset()
assert executor.nb_client.km is None
+ await executor.terminate()
+
+
+@pytest.mark.asyncio
+async def test_parse_outputs():
+ executor = ExecuteNbCode()
+ code = """
+ import pandas as pd
+ df = pd.DataFrame({'ID': [1,2,3], 'NAME': ['a', 'b', 'c']})
+ print(df.columns)
+ print(f"columns num:{len(df.columns)}")
+ print(df['DUMMPY_ID'])
+ """
+ output, is_success = await executor.run(code)
+ assert not is_success
+ assert "Index(['ID', 'NAME'], dtype='object')" in output
+ assert "KeyError: 'DUMMPY_ID'" in output
+ assert "columns num:2" in output
+ await executor.terminate()
diff --git a/tests/metagpt/actions/di/test_ml_action.py b/tests/metagpt/actions/di/test_ml_action.py
deleted file mode 100644
index 826a7fcf2..000000000
--- a/tests/metagpt/actions/di/test_ml_action.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import pytest
-
-from metagpt.actions.di.ml_action import WriteCodeWithToolsML
-from metagpt.schema import Plan, Task
-
-
-@pytest.mark.asyncio
-async def test_write_code_with_tools():
- write_code_ml = WriteCodeWithToolsML()
-
- task_map = {
- "1": Task(
- task_id="1",
- instruction="随机生成一个pandas DataFrame数据集",
- task_type="other",
- dependent_task_ids=[],
- code="""
- import pandas as pd
- df = pd.DataFrame({
- 'a': [1, 2, 3, 4, 5],
- 'b': [1.1, 2.2, 3.3, 4.4, np.nan],
- 'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
- 'd': [1, 2, 3, 4, 5]
- })
- """,
- is_finished=True,
- ),
- "2": Task(
- task_id="2",
- instruction="对数据集进行数据清洗",
- task_type="data_preprocess",
- dependent_task_ids=["1"],
- ),
- }
- plan = Plan(
- goal="构造数据集并进行数据清洗",
- tasks=list(task_map.values()),
- task_map=task_map,
- current_task_id="2",
- )
- column_info = ""
-
- _, code_with_ml = await write_code_ml.run([], plan, column_info)
- code_with_ml = code_with_ml["code"]
- assert len(code_with_ml) > 0
- print(code_with_ml)
diff --git a/tests/metagpt/actions/di/test_write_analysis_code.py b/tests/metagpt/actions/di/test_write_analysis_code.py
index d6bbae1c6..2996f31f7 100644
--- a/tests/metagpt/actions/di/test_write_analysis_code.py
+++ b/tests/metagpt/actions/di/test_write_analysis_code.py
@@ -1,134 +1,61 @@
-import asyncio
-
import pytest
-from metagpt.actions.di.execute_nb_code import ExecuteNbCode
-from metagpt.actions.di.write_analysis_code import (
- WriteCodeWithoutTools,
- WriteCodeWithTools,
-)
-from metagpt.logs import logger
-from metagpt.schema import Message, Plan, Task
-from metagpt.strategy.planner import STRUCTURAL_CONTEXT
-
-
-@pytest.mark.skip
-@pytest.mark.asyncio
-async def test_write_code_by_list_plan():
- write_code = WriteCodeWithoutTools()
- execute_code = ExecuteNbCode()
- messages = []
- plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "回顾已完成的任务", "求均值", "总结"]
- for task in plan:
- print(f"\n任务: {task}\n\n")
- messages.append(Message(task, role="assistant"))
- code = await write_code.run(messages)
- if task.startswith(("回顾", "总结")):
- assert code["language"] == "markdown"
- else:
- assert code["language"] == "python"
- messages.append(Message(code["code"], role="assistant"))
- assert len(code) > 0
- output, _ = await execute_code.run(**code)
- print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
- messages.append(output)
+from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
+from metagpt.schema import Message
@pytest.mark.asyncio
-async def test_tool_recommendation():
- task = "clean and preprocess the data"
- available_tools = {
- "FillMissingValue": "Filling missing values",
- "SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
- }
- write_code = WriteCodeWithTools()
- tools = await write_code._recommend_tool(task, available_tools)
+async def test_write_code_with_plan():
+ write_code = WriteAnalysisCode()
- assert len(tools) == 1
- assert "FillMissingValue" in tools
+ user_requirement = "Run data analysis on sklearn Iris dataset, include a plot"
+ plan_status = "\n## Finished Tasks\n### code\n```python\n\n```\n\n### execution result\n\n\n## Current Task\nLoad the sklearn Iris dataset and perform exploratory data analysis\n\n## Task Guidance\nWrite complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.\nSpecifically, \nThe current task is about exploratory data analysis, please note the following:\n- Distinguish column types with `select_dtypes` for tailored analysis and visualization, such as correlation.\n- Remember to `import numpy as np` before using Numpy functions.\n\n"
+
+ code = await write_code.run(user_requirement=user_requirement, plan_status=plan_status)
+ assert len(code) > 0
+ assert "sklearn" in code
@pytest.mark.asyncio
async def test_write_code_with_tools():
- write_code = WriteCodeWithTools()
+ write_code = WriteAnalysisCode()
- requirement = "构造数据集并进行数据清洗"
- task_map = {
- "1": Task(
- task_id="1",
- instruction="随机生成一个pandas DataFrame数据集",
- task_type="other",
- dependent_task_ids=[],
- code="""
- import pandas as pd
- df = pd.DataFrame({
- 'a': [1, 2, 3, 4, 5],
- 'b': [1.1, 2.2, 3.3, 4.4, np.nan],
- 'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
- 'd': [1, 2, 3, 4, 5]
- })
- """,
- is_finished=True,
- ),
- "2": Task(
- task_id="2",
- instruction="对数据集进行数据清洗",
- task_type="data_preprocess",
- dependent_task_ids=["1"],
- ),
- }
- plan = Plan(
- goal="构造数据集并进行数据清洗",
- tasks=list(task_map.values()),
- task_map=task_map,
- current_task_id="2",
- )
+ user_requirement = "Preprocess sklearn Wine recognition dataset and train a model to predict wine class (20% as validation), and show validation accuracy."
+ tool_info = """
+ ## Capabilities
+ - You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
+ - You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
- context = STRUCTURAL_CONTEXT.format(
- user_requirement=requirement,
- context=plan.context,
- tasks=list(task_map.values()),
- current_task=plan.current_task.model_dump_json(),
- )
- context_msg = [Message(content=context, role="user")]
+ ## Available Tools:
+ Each tool is described in JSON format. When you call a tool, import the tool from its path first.
+ {'FillMissingValue': {'type': 'class', 'description': 'Completing missing values with simple strategies.', 'methods': {'__init__': {'type': 'function', 'description': 'Initialize self. ', 'signature': '(self, features: \'list\', strategy: "Literal[\'mean\', \'median\', \'most_frequent\', \'constant\']" = \'mean\', fill_value=None)', 'parameters': 'Args: features (list): Columns to be processed. strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice \'mean\' and \'median\' can only be used for numeric features. Defaults to \'mean\'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.'}, 'fit': {'type': 'function', 'description': 'Fit a model to be used in subsequent transform. ', 'signature': "(self, df: 'pd.DataFrame')", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame.'}, 'fit_transform': {'type': 'function', 'description': 'Fit and transform the input DataFrame. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}, 'transform': {'type': 'function', 'description': 'Transform the input DataFrame with the fitted model. ', 'signature': "(self, df: 'pd.DataFrame') -> 'pd.DataFrame'", 'parameters': 'Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.'}}, 'tool_path': 'metagpt/tools/libs/data_preprocess.py'}
+ """
- code = await write_code.run(context_msg, plan)
- code = code["code"]
+ code = await write_code.run(user_requirement=user_requirement, tool_info=tool_info)
assert len(code) > 0
- print(code)
+ assert "metagpt.tools.libs" in code
@pytest.mark.asyncio
-async def test_write_code_to_correct_error():
- structural_context = """
- ## User Requirement
- read a dataset test.csv and print its head
- ## Current Plan
- [
- {
- "task_id": "1",
- "dependent_task_ids": [],
- "instruction": "import pandas and load the dataset from 'test.csv'.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- },
- {
- "task_id": "2",
- "dependent_task_ids": [
- "1"
- ],
- "instruction": "Print the head of the dataset to display the first few rows.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- }
- ]
+async def test_debug_with_reflection():
+ user_requirement = "read a dataset test.csv and print its head"
+
+ plan_status = """
+ ## Finished Tasks
+ ### code
+ ```python
+ ```
+
+ ### execution result
+
## Current Task
- {"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
+ import pandas and load the dataset from 'test.csv'.
+
+ ## Task Guidance
+ Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
+ Specifically,
"""
+
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
error = """
Traceback (most recent call last):
@@ -139,186 +66,14 @@ async def test_write_code_to_correct_error():
raise ValueError(
ValueError: Excel file format cannot be determined, you must specify an engine manually.
"""
- context = [
- Message(content=structural_context, role="user"),
+ working_memory = [
Message(content=wrong_code, role="assistant"),
Message(content=error, role="user"),
]
- new_code = await WriteCodeWithoutTools().run(context=context)
- new_code = new_code["code"]
- print(new_code)
+ new_code = await WriteAnalysisCode().run(
+ user_requirement=user_requirement,
+ plan_status=plan_status,
+ working_memory=working_memory,
+ use_reflection=True,
+ )
assert "read_csv" in new_code # should correct read_excel to read_csv
-
-
-@pytest.mark.asyncio
-async def test_write_code_reuse_code_simple():
- structural_context = """
- ## User Requirement
- read a dataset test.csv and print its head
- ## Current Plan
- [
- {
- "task_id": "1",
- "dependent_task_ids": [],
- "instruction": "import pandas and load the dataset from 'test.csv'.",
- "task_type": "",
- "code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
- "result": "",
- "is_finished": true
- },
- {
- "task_id": "2",
- "dependent_task_ids": [
- "1"
- ],
- "instruction": "Print the head of the dataset to display the first few rows.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- }
- ]
- ## Current Task
- {"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
- """
- context = [
- Message(content=structural_context, role="user"),
- ]
- code = await WriteCodeWithoutTools().run(context=context)
- code = code["code"]
- print(code)
- assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
-
-
-@pytest.mark.skip
-@pytest.mark.asyncio
-async def test_write_code_reuse_code_long():
- """test code reuse for long context"""
-
- structural_context = """
- ## User Requirement
- Run data analysis on sklearn Iris dataset, include a plot
- ## Current Plan
- [
- {
- "task_id": "1",
- "dependent_task_ids": [],
- "instruction": "Load the Iris dataset from sklearn.",
- "task_type": "",
- "code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
- "result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
- "is_finished": true
- },
- {
- "task_id": "2",
- "dependent_task_ids": [
- "1"
- ],
- "instruction": "Perform exploratory data analysis on the Iris dataset.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- },
- {
- "task_id": "3",
- "dependent_task_ids": [
- "2"
- ],
- "instruction": "Create a plot visualizing the Iris dataset features.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- }
- ]
- ## Current Task
- {"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
- """
- context = [
- Message(content=structural_context, role="user"),
- ]
- trials_num = 5
- trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
- trial_results = await asyncio.gather(*trials)
- print(*trial_results, sep="\n\n***\n\n")
- success = [
- "load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
- ] # should reuse iris_data from previous tasks
- success_rate = sum(success) / trials_num
- logger.info(f"success rate: {success_rate :.2f}")
- assert success_rate >= 0.8
-
-
-@pytest.mark.skip
-@pytest.mark.asyncio
-async def test_write_code_reuse_code_long_for_wine():
- """test code reuse for long context"""
-
- structural_context = """
- ## User Requirement
- Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy
- ## Current Plan
- [
- {
- "task_id": "1",
- "dependent_task_ids": [],
- "instruction": "Load the sklearn Wine recognition dataset and perform exploratory data analysis."
- "task_type": "",
- "code": "from sklearn.datasets import load_wine\n# Load the Wine recognition dataset\nwine_data = load_wine()\n# Perform exploratory data analysis\nwine_data.keys()",
- "result": "Truncated to show only the last 1000 characters\ndict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])",
- "is_finished": true
- },
- {
- "task_id": "2",
- "dependent_task_ids": ["1"],
- "instruction": "Create a plot to visualize some aspect of the wine dataset."
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- },
- {
- "task_id": "3",
- "dependent_task_ids": ["1"],
- "instruction": "Split the dataset into training and validation sets with a 20% validation size.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- },
- {
- "task_id": "4",
- "dependent_task_ids": ["3"],
- "instruction": "Train a model on the training set to predict wine class.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- },
- {
- "task_id": "5",
- "dependent_task_ids": ["4"],
- "instruction": "Evaluate the model on the validation set and report the accuracy.",
- "task_type": "",
- "code": "",
- "result": "",
- "is_finished": false
- }
- ]
- ## Current Task
- {"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Create a plot to visualize some aspect of the Wine dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
- """
- context = [
- Message(content=structural_context, role="user"),
- ]
- trials_num = 5
- trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
- trial_results = await asyncio.gather(*trials)
- print(*trial_results, sep="\n\n***\n\n")
- success = [
- "load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
- ] # should reuse iris_data from previous tasks
- success_rate = sum(success) / trials_num
- logger.info(f"success rate: {success_rate :.2f}")
- assert success_rate >= 0.8
diff --git a/tests/metagpt/actions/di/test_write_plan.py b/tests/metagpt/actions/di/test_write_plan.py
index 80b3399b8..cad0c8a71 100644
--- a/tests/metagpt/actions/di/test_write_plan.py
+++ b/tests/metagpt/actions/di/test_write_plan.py
@@ -23,12 +23,10 @@ def test_precheck_update_plan_from_rsp():
@pytest.mark.asyncio
-@pytest.mark.parametrize("use_tools", [(False), (True)])
-async def test_write_plan(use_tools):
+async def test_write_plan():
rsp = await WritePlan().run(
- context=[Message("run analysis on sklearn iris dataset", role="user")], use_tools=use_tools
+ context=[Message("Run data analysis on sklearn Iris dataset, include a plot", role="user")]
)
assert "task_id" in rsp
assert "instruction" in rsp
- assert "json" not in rsp # the output should be the content inside ```json ```
diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py
index 4414c2000..3731cd598 100644
--- a/tests/metagpt/actions/test_rebuild_class_view.py
+++ b/tests/metagpt/actions/test_rebuild_class_view.py
@@ -14,7 +14,6 @@ from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.llm import LLM
-@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context):
action = RebuildClassView(
diff --git a/tests/metagpt/actions/test_rebuild_sequence_view.py b/tests/metagpt/actions/test_rebuild_sequence_view.py
index 1daea22a4..9be3e8a99 100644
--- a/tests/metagpt/actions/test_rebuild_sequence_view.py
+++ b/tests/metagpt/actions/test_rebuild_sequence_view.py
@@ -18,6 +18,7 @@ from metagpt.utils.git_repository import ChangeType
from metagpt.utils.graph_repository import SPO
+@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context, mocker):
# Mock
@@ -47,6 +48,8 @@ async def test_rebuild(context, mocker):
context=context,
)
await action.run()
+ rows = await action.graph_db.select()
+ assert rows
assert context.repo.docs.graph_repo.changed_files
diff --git a/tests/metagpt/environment/mincraft_env/test_mincraft_ext_env.py b/tests/metagpt/environment/mincraft_env/test_mincraft_ext_env.py
deleted file mode 100644
index ad3376141..000000000
--- a/tests/metagpt/environment/mincraft_env/test_mincraft_ext_env.py
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc : the unittest of MincraftExtEnv
-
-
-from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
-from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
-
-
-def test_mincraft_ext_env():
- ext_env = MincraftExtEnv()
- assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
- assert MC_CKPT_DIR.joinpath("skill/code").exists()
- assert ext_env.warm_up.get("optional_inventory_items") == 7
diff --git a/tests/metagpt/environment/mincraft_env/__init__.py b/tests/metagpt/environment/minecraft_env/__init__.py
similarity index 100%
rename from tests/metagpt/environment/mincraft_env/__init__.py
rename to tests/metagpt/environment/minecraft_env/__init__.py
diff --git a/tests/metagpt/environment/minecraft_env/test_minecraft_ext_env.py b/tests/metagpt/environment/minecraft_env/test_minecraft_ext_env.py
new file mode 100644
index 000000000..ed08b719f
--- /dev/null
+++ b/tests/metagpt/environment/minecraft_env/test_minecraft_ext_env.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : the unittest of MinecraftExtEnv
+
+
+from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
+from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
+
+
+def test_minecraft_ext_env():
+ ext_env = MinecraftExtEnv()
+ assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
+ assert MC_CKPT_DIR.joinpath("skill/code").exists()
+ assert ext_env.warm_up.get("optional_inventory_items") == 7
diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py
index bff8dbde4..40a9fda92 100644
--- a/tests/metagpt/provider/test_base_llm.py
+++ b/tests/metagpt/provider/test_base_llm.py
@@ -11,6 +11,7 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
+from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
@@ -22,7 +23,7 @@ name = "GPT"
class MockBaseLLM(BaseLLM):
def __init__(self, config: LLMConfig = None):
- pass
+ self.config = config or mock_llm_config
def completion(self, messages: list[dict], timeout=3):
return get_part_chat_completion(name)
diff --git a/tests/metagpt/roles/di/test_data_interpreter.py b/tests/metagpt/roles/di/test_data_interpreter.py
index ba50f473b..e5cc5b29b 100644
--- a/tests/metagpt/roles/di/test_data_interpreter.py
+++ b/tests/metagpt/roles/di/test_data_interpreter.py
@@ -10,10 +10,9 @@ async def test_interpreter(mocker, auto_run):
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
mocker.patch("builtins.input", return_value="confirm")
- requirement = "Run data analysis on sklearn Iris dataset, include a plot"
- tools = []
+ requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
- di = DataInterpreter(auto_run=auto_run, use_tools=True, tools=tools)
+ di = DataInterpreter(auto_run=auto_run)
rsp = await di.run(requirement)
logger.info(rsp)
assert len(rsp.content) > 0
@@ -21,3 +20,15 @@ async def test_interpreter(mocker, auto_run):
finished_tasks = di.planner.plan.get_finished_tasks()
assert len(finished_tasks) > 0
assert len(finished_tasks[0].code) > 0 # check one task to see if code is recorded
+
+
+@pytest.mark.asyncio
+async def test_interpreter_react_mode(mocker):
+ mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
+
+ requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
+
+ di = DataInterpreter(react_mode="react")
+ rsp = await di.run(requirement)
+ logger.info(rsp)
+ assert len(rsp.content) > 0
diff --git a/tests/metagpt/roles/di/test_ml_engineer.py b/tests/metagpt/roles/di/test_ml_engineer.py
deleted file mode 100644
index 08b92cd27..000000000
--- a/tests/metagpt/roles/di/test_ml_engineer.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import pytest
-
-from metagpt.actions.di.execute_nb_code import ExecuteNbCode
-from metagpt.logs import logger
-from metagpt.roles.di.ml_engineer import MLEngineer
-from metagpt.schema import Message, Plan, Task
-from metagpt.tools.tool_type import ToolType
-from tests.metagpt.actions.di.test_debug_code import CODE, DebugContext, ErrorStr
-
-
-def test_mle_init():
- mle = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
- assert mle.tools == []
-
-
-MockPlan = Plan(
- goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
- context="",
- tasks=[
- Task(
- task_id="1",
- dependent_task_ids=[],
- instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
- task_type="eda",
- code="",
- result="",
- is_success=False,
- is_finished=False,
- )
- ],
- task_map={
- "1": Task(
- task_id="1",
- dependent_task_ids=[],
- instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
- task_type="eda",
- code="",
- result="",
- is_success=False,
- is_finished=False,
- )
- },
- current_task_id="1",
-)
-
-
-@pytest.mark.asyncio
-async def test_mle_write_code(mocker):
- data_path = "tests/data/ml_datasets/titanic"
-
- mle = MLEngineer(auto_run=True, use_tools=True)
- mle.planner.plan = MockPlan
-
- code, _ = await mle._write_code()
- assert data_path in code["code"]
-
-
-@pytest.mark.asyncio
-async def test_mle_update_data_columns(mocker):
- mle = MLEngineer(auto_run=True, use_tools=True)
- mle.planner.plan = MockPlan
-
- # manually update task type to test update
- mle.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value
-
- result = await mle._update_data_columns()
- assert result is not None
-
-
-@pytest.mark.asyncio
-async def test_mle_debug_code(mocker):
- mle = MLEngineer(auto_run=True, use_tools=True)
- mle.working_memory.add(Message(content=ErrorStr, cause_by=ExecuteNbCode))
- mle.latest_code = CODE
- mle.debug_context = DebugContext
- code, _ = await mle._write_code()
- assert len(code) > 0
-
-
-@pytest.mark.skip
-@pytest.mark.asyncio
-async def test_ml_engineer():
- data_path = "tests/data/ml_datasets/titanic"
- requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
- tools = ["FillMissingValue", "CatCross", "dummy_tool"]
-
- mle = MLEngineer(auto_run=True, use_tools=True, tools=tools)
- rsp = await mle.run(requirement)
- logger.info(rsp)
- assert len(rsp.content) > 0
diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py
index c12c2b26e..732f346fd 100644
--- a/tests/metagpt/roles/test_tutorial_assistant.py
+++ b/tests/metagpt/roles/test_tutorial_assistant.py
@@ -6,11 +6,11 @@
@File : test_tutorial_assistant.py
"""
-import aiofiles
import pytest
from metagpt.const import TUTORIAL_PATH
from metagpt.roles.tutorial_assistant import TutorialAssistant
+from metagpt.utils.common import aread
@pytest.mark.asyncio
@@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
msg = await role.run(topic)
assert TUTORIAL_PATH.exists()
filename = msg.content
- async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
- content = await reader.read()
- assert "pip" in content
+ content = await aread(filename=filename)
+ assert "pip" in content
if __name__ == "__main__":
diff --git a/tests/metagpt/strategy/test_planner.py b/tests/metagpt/strategy/test_planner.py
new file mode 100644
index 000000000..ff1c6da3f
--- /dev/null
+++ b/tests/metagpt/strategy/test_planner.py
@@ -0,0 +1,37 @@
+from metagpt.schema import Plan, Task
+from metagpt.strategy.planner import Planner
+from metagpt.strategy.task_type import TaskType
+
+MOCK_TASK_MAP = {
+ "1": Task(
+ task_id="1",
+ instruction="test instruction for finished task",
+ task_type=TaskType.EDA.type_name,
+ dependent_task_ids=[],
+ code="some finished test code",
+ result="some finished test result",
+ is_finished=True,
+ ),
+ "2": Task(
+ task_id="2",
+ instruction="test instruction for current task",
+ task_type=TaskType.DATA_PREPROCESS.type_name,
+ dependent_task_ids=["1"],
+ ),
+}
+MOCK_PLAN = Plan(
+ goal="test goal",
+ tasks=list(MOCK_TASK_MAP.values()),
+ task_map=MOCK_TASK_MAP,
+ current_task_id="2",
+)
+
+
+def test_planner_get_plan_status():
+ planner = Planner(plan=MOCK_PLAN)
+ status = planner.get_plan_status()
+
+ assert "some finished test code" in status
+ assert "some finished test result" in status
+ assert "test instruction for current task" in status
+ assert TaskType.DATA_PREPROCESS.value.guidance in status # current task guidance
diff --git a/tests/metagpt/tools/libs/test_gpt_v_generator.py b/tests/metagpt/tools/libs/test_gpt_v_generator.py
index 907006765..4a2e68682 100644
--- a/tests/metagpt/tools/libs/test_gpt_v_generator.py
+++ b/tests/metagpt/tools/libs/test_gpt_v_generator.py
@@ -60,18 +60,24 @@ async def test_generate_webpages(mock_webpage_filename_with_styles_and_scripts,
async def test_save_webpages_with_styles_and_scripts(mock_webpage_filename_with_styles_and_scripts, image_path):
generator = GPTvGenerator()
webpages = await generator.generate_webpages(image_path)
- webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
+ webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_1")
logs.logger.info(webpages_dir)
assert webpages_dir.exists()
+ assert (webpages_dir / "index.html").exists()
+ assert (webpages_dir / "styles.css").exists()
+ assert (webpages_dir / "scripts.js").exists()
@pytest.mark.asyncio
async def test_save_webpages_with_style_and_script(mock_webpage_filename_with_style_and_script, image_path):
generator = GPTvGenerator()
webpages = await generator.generate_webpages(image_path)
- webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
+ webpages_dir = generator.save_webpages(webpages=webpages, save_folder_name="test_2")
logs.logger.info(webpages_dir)
assert webpages_dir.exists()
+ assert (webpages_dir / "index.html").exists()
+ assert (webpages_dir / "style.css").exists()
+ assert (webpages_dir / "script.js").exists()
@pytest.mark.asyncio
diff --git a/tests/metagpt/tools/libs/test_web_scraping.py b/tests/metagpt/tools/libs/test_web_scraping.py
index c11960e68..3d8877b8d 100644
--- a/tests/metagpt/tools/libs/test_web_scraping.py
+++ b/tests/metagpt/tools/libs/test_web_scraping.py
@@ -4,8 +4,8 @@ from metagpt.tools.libs.web_scraping import scrape_web_playwright
@pytest.mark.asyncio
-async def test_scrape_web_playwright():
- test_url = "https://www.deepwisdom.ai"
+async def test_scrape_web_playwright(http_server):
+ server, test_url = await http_server()
result = await scrape_web_playwright(test_url)
@@ -21,3 +21,4 @@ async def test_scrape_web_playwright():
assert not result["inner_text"].endswith(" ")
assert not result["html"].startswith(" ")
assert not result["html"].endswith(" ")
+ await server.stop()
diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py
index a1f03ef7b..964ead02f 100644
--- a/tests/metagpt/tools/test_search_engine.py
+++ b/tests/metagpt/tools/test_search_engine.py
@@ -11,7 +11,6 @@ from typing import Callable
import pytest
-from metagpt.config2 import config
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
@@ -53,14 +52,11 @@ async def test_search_engine(
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
- assert config.search
search_engine_config["api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
- assert config.search
search_engine_config["api_key"] = "mock-google-key"
search_engine_config["cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
- assert config.search
search_engine_config["api_key"] = "mock-serper-key"
async def test(search_engine):
diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py
index 8f26a211c..061a619ce 100644
--- a/tests/metagpt/tools/test_tool_convert.py
+++ b/tests/metagpt/tools/test_tool_convert.py
@@ -1,44 +1,8 @@
+from typing import Literal, Union
+
import pandas as pd
-from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
-
-
-def test_docstring_to_schema():
- docstring = """
- Some test desc.
-
- Args:
- features (list): Columns to be processed.
- strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
- used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
- fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
- Defaults to None.
- Returns:
- pd.DataFrame: The transformed DataFrame.
- """
- expected = {
- "description": "Some test desc.",
- "parameters": {
- "properties": {
- "features": {"type": "list", "description": "Columns to be processed."},
- "strategy": {
- "type": "str",
- "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
- "default": "'mean'",
- "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
- },
- "fill_value": {
- "type": "int",
- "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
- "default": "None",
- },
- },
- "required": ["features"],
- },
- "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
- }
- schema = docstring_to_schema(docstring)
- assert schema == expected
+from metagpt.tools.tool_convert import convert_code_to_tool_schema
class DummyClass:
@@ -81,12 +45,25 @@ class DummyClass:
pass
-def dummy_fn(df: pd.DataFrame) -> dict:
+def dummy_fn(
+ df: pd.DataFrame,
+ s: str,
+ k: int = 5,
+ type: Literal["a", "b", "c"] = "a",
+ test_dict: dict[str, int] = None,
+ test_union: Union[str, list[str]] = "",
+) -> dict:
"""
Analyzes a DataFrame and categorizes its columns based on data types.
Args:
- df (pd.DataFrame): The DataFrame to be analyzed.
+ df: The DataFrame to be analyzed.
+ Another line for df.
+ s (str): Some test string param.
+ Another line for s.
+ k (int, optional): Some test integer param. Defaults to 5.
+ type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
+ more_args: will be omitted here for testing
Returns:
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
@@ -115,41 +92,21 @@ def test_convert_code_to_tool_schema_class():
"methods": {
"__init__": {
"type": "function",
- "description": "Initialize self.",
- "parameters": {
- "properties": {
- "features": {"type": "list", "description": "Columns to be processed."},
- "strategy": {
- "type": "str",
- "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
- "default": "'mean'",
- "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
- },
- "fill_value": {
- "type": "int",
- "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
- "default": "None",
- },
- },
- "required": ["features"],
- },
+ "description": "Initialize self. ",
+ "signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
+ "parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
},
"fit": {
"type": "function",
- "description": "Fit the FillMissingValue model.",
- "parameters": {
- "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
- "required": ["df"],
- },
+ "description": "Fit the FillMissingValue model. ",
+ "signature": "(self, df: pandas.core.frame.DataFrame)",
+ "parameters": "Args: df (pd.DataFrame): The input DataFrame.",
},
"transform": {
"type": "function",
- "description": "Transform the input DataFrame with the fitted model.",
- "parameters": {
- "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
- "required": ["df"],
- },
- "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
+ "description": "Transform the input DataFrame with the fitted model. ",
+ "signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
+ "parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
},
},
}
@@ -160,11 +117,9 @@ def test_convert_code_to_tool_schema_class():
def test_convert_code_to_tool_schema_function():
expected = {
"type": "function",
- "description": "Analyzes a DataFrame and categorizes its columns based on data types.",
- "parameters": {
- "properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
- "required": ["df"],
- },
+ "description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
+ "signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
+ "parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
}
schema = convert_code_to_tool_schema(dummy_fn)
assert schema == expected
diff --git a/tests/metagpt/tools/test_tool_recommend.py b/tests/metagpt/tools/test_tool_recommend.py
new file mode 100644
index 000000000..fafe0a638
--- /dev/null
+++ b/tests/metagpt/tools/test_tool_recommend.py
@@ -0,0 +1,90 @@
+import pytest
+
+from metagpt.schema import Plan, Task
+from metagpt.tools import TOOL_REGISTRY
+from metagpt.tools.tool_recommend import (
+ BM25ToolRecommender,
+ ToolRecommender,
+ TypeMatchToolRecommender,
+)
+
+
+@pytest.fixture
+def mock_plan(mocker):
+ task_map = {
+ "1": Task(
+ task_id="1",
+ instruction="conduct feature engineering, add new features on the dataset",
+ task_type="feature engineering",
+ )
+ }
+ plan = Plan(
+ goal="test requirement",
+ tasks=list(task_map.values()),
+ task_map=task_map,
+ current_task_id="1",
+ )
+ return plan
+
+
+@pytest.fixture
+def mock_bm25_tr(mocker):
+ tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
+ return tr
+
+
+def test_tr_init():
+ tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"])
+ # web_scraping is a tool tag, it has one tool scrape_web_playwright
+ assert list(tr.tools.keys()) == [
+ "FillMissingValue",
+ "PolynomialExpansion",
+ "scrape_web_playwright",
+ ]
+
+
+def test_tr_init_default_tools_value():
+ tr = ToolRecommender()
+ assert tr.tools == {}
+
+
+def test_tr_init_tools_all():
+ tr = ToolRecommender(tools=[""])
+ assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys())
+
+
+@pytest.mark.asyncio
+async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
+ result = await mock_bm25_tr.recall_tools(plan=mock_plan)
+ assert len(result) == 3
+ assert result[0].name == "PolynomialExpansion"
+
+
+@pytest.mark.asyncio
+async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
+ result = await mock_bm25_tr.recall_tools(
+ context="conduct feature engineering, add new features on the dataset", plan=None
+ )
+ assert len(result) == 3
+ assert result[0].name == "PolynomialExpansion"
+
+
+@pytest.mark.asyncio
+async def test_bm25_recommend_tools(mock_bm25_tr):
+ result = await mock_bm25_tr.recommend_tools(context="conduct feature engineering, add new features on the dataset")
+ assert len(result) == 2 # web scraping tool should be filtered out at rank stage
+ assert result[0].name == "PolynomialExpansion"
+
+
+@pytest.mark.asyncio
+async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
+ result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
+ assert isinstance(result, str)
+
+
+@pytest.mark.asyncio
+async def test_tm_tr_recall_with_plan(mock_plan, mock_bm25_tr):
+ tr = TypeMatchToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
+ result = await tr.recall_tools(plan=mock_plan)
+ assert len(result) == 1
+ assert result[0].name == "PolynomialExpansion"
diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py
index 2fd487fb7..f44dfea0b 100644
--- a/tests/metagpt/tools/test_tool_registry.py
+++ b/tests/metagpt/tools/test_tool_registry.py
@@ -1,7 +1,6 @@
import pytest
from metagpt.tools.tool_registry import ToolRegistry
-from metagpt.tools.tool_type import ToolType
@pytest.fixture
@@ -9,25 +8,11 @@ def tool_registry():
return ToolRegistry()
-@pytest.fixture
-def tool_registry_full():
- return ToolRegistry(tool_types=ToolType)
-
-
# Test Initialization
def test_initialization(tool_registry):
assert isinstance(tool_registry, ToolRegistry)
assert tool_registry.tools == {}
- assert tool_registry.tool_types == {}
- assert tool_registry.tools_by_types == {}
-
-
-# Test Initialization with tool types
-def test_initialize_with_tool_types(tool_registry_full):
- assert isinstance(tool_registry_full, ToolRegistry)
- assert tool_registry_full.tools == {}
- assert tool_registry_full.tools_by_types == {}
- assert "data_preprocess" in tool_registry_full.tool_types
+ assert tool_registry.tools_by_tags == {}
class TestClassTool:
@@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
assert "description" in tool.schemas
-# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
-def test_has_tool_type(tool_registry_full):
- assert tool_registry_full.has_tool_type("data_preprocess")
- assert not tool_registry_full.has_tool_type("NonexistentType")
+def test_has_tool_tag(tool_registry):
+ tool_registry.register_tool(
+ "TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"]
+ )
+ assert tool_registry.has_tool_tag("test")
+ assert not tool_registry.has_tool_tag("Non-existent tag")
-def test_get_tool_type(tool_registry_full):
- retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
- assert retrieved_type is not None
- assert retrieved_type.name == "data_preprocess"
-
-
-def test_get_tools_by_type(tool_registry):
- tool_type_name = "TestType"
+def test_get_tools_by_tag(tool_registry):
+ tool_tag_name = "Test Tag"
tool_name = "TestTool"
tool_path = "/path/to/tool"
- tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
+ tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
- tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
- assert tools_by_type is not None
- assert tool_name in tools_by_type
+ tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
+ assert tools_by_tag is not None
+ assert tool_name in tools_by_tag
-
-# Test case for when the tool type does not exist
-def test_get_tools_by_nonexistent_type(tool_registry):
- tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
- assert not tools_by_type
+ tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
+ assert not tools_by_tag_non_existent
diff --git a/tests/metagpt/tools/test_web_browser_engine.py b/tests/metagpt/tools/test_web_browser_engine.py
index ceebd67fc..7a344e0ad 100644
--- a/tests/metagpt/tools/test_web_browser_engine.py
+++ b/tests/metagpt/tools/test_web_browser_engine.py
@@ -9,14 +9,16 @@ from metagpt.utils.parse_html import WebPage
@pytest.mark.asyncio
@pytest.mark.parametrize(
- "browser_type, url, urls",
+ "browser_type",
[
- (WebBrowserEngineType.PLAYWRIGHT, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
- (WebBrowserEngineType.SELENIUM, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
+ WebBrowserEngineType.PLAYWRIGHT,
+ WebBrowserEngineType.SELENIUM,
],
ids=["playwright", "selenium"],
)
-async def test_scrape_web_page(browser_type, url, urls):
+async def test_scrape_web_page(browser_type, http_server):
+ server, url = await http_server()
+ urls = [url, url, url]
browser = web_browser_engine.WebBrowserEngine(engine=browser_type)
result = await browser.run(url)
assert isinstance(result, WebPage)
@@ -27,6 +29,7 @@ async def test_scrape_web_page(browser_type, url, urls):
assert isinstance(results, list)
assert len(results) == len(urls) + 1
assert all(("MetaGPT" in i.inner_text) for i in results)
+ await server.stop()
if __name__ == "__main__":
diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py
index f35848cf4..12ea96d7b 100644
--- a/tests/metagpt/tools/test_web_browser_engine_playwright.py
+++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py
@@ -9,18 +9,28 @@ from metagpt.utils.parse_html import WebPage
@pytest.mark.asyncio
@pytest.mark.parametrize(
- "browser_type, use_proxy, kwagrs, url, urls",
+ "browser_type, use_proxy, kwagrs,",
[
- ("chromium", {"proxy": True}, {}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
- ("firefox", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
- ("webkit", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
+ ("chromium", {"proxy": True}, {}),
+ (
+ "firefox",
+ {},
+ {"ignore_https_errors": True},
+ ),
+ (
+ "webkit",
+ {},
+ {"ignore_https_errors": True},
+ ),
],
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
)
-async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
+async def test_scrape_web_page(browser_type, use_proxy, kwagrs, proxy, capfd, http_server):
+ server, url = await http_server()
+ urls = [url, url, url]
proxy_url = None
if use_proxy:
- server, proxy_url = await proxy()
+ proxy_server, proxy_url = await proxy()
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
result = await browser.run(url)
assert isinstance(result, WebPage)
@@ -32,8 +42,10 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
assert len(results) == len(urls) + 1
assert all(("MetaGPT" in i.inner_text) for i in results)
if use_proxy:
- server.close()
+ proxy_server.close()
+ await proxy_server.wait_closed()
assert "Proxy:" in capfd.readouterr().out
+ await server.stop()
if __name__ == "__main__":
diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py
index a88a5d0f4..a214748bd 100644
--- a/tests/metagpt/tools/test_web_browser_engine_selenium.py
+++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+
import browsers
import pytest
@@ -10,51 +11,48 @@ from metagpt.utils.parse_html import WebPage
@pytest.mark.asyncio
@pytest.mark.parametrize(
- "browser_type, use_proxy, url, urls",
+ "browser_type, use_proxy,",
[
pytest.param(
"chrome",
- True,
- "https://deepwisdom.ai",
- ("https://deepwisdom.ai",),
+ False,
marks=pytest.mark.skipif(not browsers.get("chrome"), reason="chrome browser not found"),
),
pytest.param(
"firefox",
False,
- "https://deepwisdom.ai",
- ("https://deepwisdom.ai",),
marks=pytest.mark.skipif(not browsers.get("firefox"), reason="firefox browser not found"),
),
pytest.param(
"edge",
False,
- "https://deepwisdom.ai",
- ("https://deepwisdom.ai",),
marks=pytest.mark.skipif(not browsers.get("msedge"), reason="edge browser not found"),
),
],
ids=["chrome-normal", "firefox-normal", "edge-normal"],
)
-async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
+async def test_scrape_web_page(browser_type, use_proxy, proxy, capfd, http_server):
# Prerequisites
# firefox, chrome, Microsoft Edge
+ server, url = await http_server()
+ urls = [url, url, url]
proxy_url = None
if use_proxy:
- server, proxy_url = await proxy()
+ proxy_server, proxy_url = await proxy()
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
result = await browser.run(url)
assert isinstance(result, WebPage)
assert "MetaGPT" in result.inner_text
- if urls:
- results = await browser.run(url, *urls)
- assert isinstance(results, list)
- assert len(results) == len(urls) + 1
- assert all(("MetaGPT" in i.inner_text) for i in results)
+ results = await browser.run(url, *urls)
+ assert isinstance(results, list)
+ assert len(results) == len(urls) + 1
+ assert all(("MetaGPT" in i.inner_text) for i in results)
if use_proxy:
- server.close()
- assert "Proxy:" in capfd.readouterr().out
+ proxy_server.close()
+ await proxy_server.wait_closed()
+ assert "Proxy: localhost" in capfd.readouterr().out
+ await server.stop()
if __name__ == "__main__":
diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py
index b365f424f..75e8ef4ad 100644
--- a/tests/metagpt/utils/test_common.py
+++ b/tests/metagpt/utils/test_common.py
@@ -13,7 +13,6 @@ import uuid
from pathlib import Path
from typing import Any, Set
-import aiofiles
import pytest
from pydantic import BaseModel
@@ -125,9 +124,7 @@ class TestGetProjectRoot:
async def test_parse_data_exception(self, filename, want):
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
assert pathname.exists()
- async with aiofiles.open(str(pathname), mode="r") as reader:
- data = await reader.read()
-
+ data = await aread(filename=pathname)
result = OutputParser.parse_data(data=data)
assert want in result
@@ -198,12 +195,25 @@ class TestGetProjectRoot:
@pytest.mark.asyncio
async def test_read_write(self):
- pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
+ pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
await awrite(pathname, "ABC")
data = await aread(pathname)
assert data == "ABC"
pathname.unlink(missing_ok=True)
+ @pytest.mark.asyncio
+ async def test_read_write_error_charset(self):
+ pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
+ content = "中国abc123\u27f6"
+ await awrite(filename=pathname, data=content)
+ data = await aread(filename=pathname)
+ assert data == content
+
+ content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
+ await awrite(filename=pathname, data=content, encoding="gb2312")
+ data = await aread(filename=pathname, encoding="utf-8")
+ assert data == content
+
if __name__ == "__main__":
pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/utils/test_git_repository.py b/tests/metagpt/utils/test_git_repository.py
index ea28b8f0b..480a22e24 100644
--- a/tests/metagpt/utils/test_git_repository.py
+++ b/tests/metagpt/utils/test_git_repository.py
@@ -10,15 +10,14 @@
import shutil
from pathlib import Path
-import aiofiles
import pytest
+from metagpt.utils.common import awrite
from metagpt.utils.git_repository import GitRepository
async def mock_file(filename, content=""):
- async with aiofiles.open(str(filename), mode="w") as file:
- await file.write(content)
+ await awrite(filename=filename, data=content)
async def mock_repo(local_path) -> (GitRepository, Path):
diff --git a/tests/metagpt/utils/test_repo_to_markdown.py b/tests/metagpt/utils/test_repo_to_markdown.py
new file mode 100644
index 000000000..914c50dd7
--- /dev/null
+++ b/tests/metagpt/utils/test_repo_to_markdown.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import uuid
+from pathlib import Path
+
+import pytest
+
+from metagpt.utils.repo_to_markdown import repo_to_markdown
+
+
+@pytest.mark.parametrize(
+ ["repo_path", "output"],
+ [(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
+)
+@pytest.mark.asyncio
+async def test_repo_to_markdown(repo_path: Path, output: Path):
+ markdown = await repo_to_markdown(repo_path=repo_path, output=output)
+ assert output.exists()
+ assert markdown
+
+ output.unlink(missing_ok=True)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py
index 34c21612c..ef13c2325 100644
--- a/tests/metagpt/utils/test_s3.py
+++ b/tests/metagpt/utils/test_s3.py
@@ -9,7 +9,6 @@ import uuid
from pathlib import Path
import aioboto3
-import aiofiles
import pytest
from metagpt.config2 import Config
@@ -46,7 +45,7 @@ async def test_s3(mocker):
conn = S3(s3)
object_name = "unittest.bak"
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
- pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
+ pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
pathname.unlink(missing_ok=True)
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
assert pathname.exists()
@@ -54,8 +53,7 @@ async def test_s3(mocker):
assert url
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
assert bin_data
- async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
- data = await reader.read()
+ data = await aread(filename=__file__)
res = await conn.cache(data, ".bak", "script")
assert "http" in res
@@ -69,8 +67,6 @@ async def test_s3(mocker):
except Exception:
pass
- await reader.close()
-
if __name__ == "__main__":
pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/utils/test_text.py b/tests/metagpt/utils/test_text.py
index c9a9753be..319a4ebe0 100644
--- a/tests/metagpt/utils/test_text.py
+++ b/tests/metagpt/utils/test_text.py
@@ -22,7 +22,7 @@ def _paragraphs(n):
@pytest.mark.parametrize(
"msgs, model_name, system_text, reserved, expected",
[
- (_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
+ (_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
(_msgs(), "gpt-4", "System", 2000, 3),
@@ -32,22 +32,23 @@ def _paragraphs(n):
],
)
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
- assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
+ length = len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000
+ assert length == expected
@pytest.mark.parametrize(
"text, prompt_template, model_name, system_text, reserved, expected",
[
- (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
+ (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
- (" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1000, 8),
+ (" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1000, 8),
],
)
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
- ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
- assert len(ret) == expected
+ chunk = len(list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved)))
+ assert chunk == expected
@pytest.mark.parametrize(
diff --git a/tests/metagpt/utils/test_tree.py b/tests/metagpt/utils/test_tree.py
new file mode 100644
index 000000000..03a2a5606
--- /dev/null
+++ b/tests/metagpt/utils/test_tree.py
@@ -0,0 +1,64 @@
+from pathlib import Path
+from typing import List
+
+import pytest
+
+from metagpt.utils.tree import _print_tree, tree
+
+
+@pytest.mark.parametrize(
+ ("root", "rules"),
+ [
+ (str(Path(__file__).parent / "../.."), None),
+ (str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
+ ],
+)
+def test_tree(root: str, rules: str):
+ v = tree(root=root, gitignore=rules)
+ assert v
+
+
+@pytest.mark.parametrize(
+ ("root", "rules"),
+ [
+ (str(Path(__file__).parent / "../.."), None),
+ (str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
+ ],
+)
+def test_tree_command(root: str, rules: str):
+ v = tree(root=root, gitignore=rules, run_command=True)
+ assert v
+
+
+@pytest.mark.parametrize(
+ ("tree", "want"),
+ [
+ ({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
+ ({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
+ (
+ {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
+ ["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
+ ),
+ (
+ {"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
+ [
+ "h",
+ "+-- a",
+ "| +-- b",
+ "| | +-- e",
+ "| | +-- f",
+ "| | +-- g",
+ "| +-- c",
+ "| +-- d",
+ "+-- i",
+ ],
+ ),
+ ],
+)
+def test__print_tree(tree: dict, want: List[str]):
+ v = _print_tree(tree)
+ assert v == want
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-s"])
diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py
index b2052e2b3..c4262e080 100644
--- a/tests/mock/mock_llm.py
+++ b/tests/mock/mock_llm.py
@@ -5,6 +5,7 @@ from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.logs import logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
+from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
@@ -32,14 +33,13 @@ class MockLLM(OriginalLLM):
async def original_aask(
self,
- msg: str,
+ msg: Union[str, list[dict[str, str]]],
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
- ):
- """A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
+ ) -> str:
if system_msgs:
message = self._system_msgs(system_msgs)
else:
@@ -48,7 +48,11 @@ class MockLLM(OriginalLLM):
message = []
if format_msgs:
message.extend(format_msgs)
- message.append(self._user_msg(msg, images=images))
+ if isinstance(msg, str):
+ message.append(self._user_msg(msg, images=images))
+ else:
+ message.extend(msg)
+ logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
@@ -67,19 +71,27 @@ class MockLLM(OriginalLLM):
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
"""
+ if "tools" not in kwargs:
+ configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
+ kwargs.update(configs)
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask(
self,
- msg: str,
+ msg: Union[str, list[dict[str, str]]],
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
) -> str:
- msg_key = msg # used to identify it a message has been called before
+ # used to identify it a message has been called before
+ if isinstance(msg, list):
+ msg_key = "#MSG_SEP#".join([m["content"] for m in msg])
+ else:
+ msg_key = msg
+
if system_msgs:
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
msg_key = joined_system_msg + msg_key
@@ -92,8 +104,7 @@ class MockLLM(OriginalLLM):
return rsp
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
- messages = self._process_message(messages)
- msg_key = json.dumps(messages, ensure_ascii=False)
+ msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
return rsp