diff --git a/.github/workflows/fulltest.yaml b/.github/workflows/fulltest.yaml
index 2ab6444fa..32eb3da00 100644
--- a/.github/workflows/fulltest.yaml
+++ b/.github/workflows/fulltest.yaml
@@ -79,8 +79,8 @@ jobs:
./tests/data/rsp_cache_new.json
retention-days: 3
if: ${{ always() }}
- - name: Upload coverage reports to Codecov
- uses: codecov/codecov-action@v3
- env:
- CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ always() }}
+ # - name: Upload coverage reports to Codecov
+ # uses: codecov/codecov-action@v3
+ # env:
+ # CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+ # if: ${{ always() }}
diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 25f82b1e6..1fd193b52 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -91,8 +91,8 @@ jobs:
./tests/data/rsp_cache_new.json
retention-days: 3
if: ${{ always() }}
- - name: Upload coverage reports to Codecov
- uses: codecov/codecov-action@v3
- env:
- CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ always() }}
+ # - name: Upload coverage reports to Codecov
+ # uses: codecov/codecov-action@v3
+ # env:
+ # CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+ # if: ${{ always() }}
diff --git a/.gitignore b/.gitignore
index 0d6be14ad..6130fcd45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -188,4 +188,5 @@ cov.xml
*-structure.json
*.dot
.python-version
+*.csv
metagpt/ext/sela/results/*
diff --git a/README.md b/README.md
index a151a1f0f..4b846795c 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
# MetaGPT: The Multi-Agent Framework
-
+
@@ -12,6 +12,7 @@ # MetaGPT: The Multi-Agent Framework
+
@@ -22,11 +23,13 @@ # MetaGPT: The Multi-Agent Framework
-
+
## News
-🚀 Mar. 29, 2024: [v0.8.0](https://github.com/geekan/MetaGPT/releases/tag/v0.8.0) released. Now you can use Data Interpreter ([arxiv](https://arxiv.org/abs/2402.18679), [example](https://docs.deepwisdom.ai/main/en/DataInterpreter/), [code](https://github.com/geekan/MetaGPT/tree/main/examples/di)) via pypi package import. Meanwhile, we integrated RAG module and supported multiple new LLMs.
+🚀 Oct. 29, 2024: We introduced three papers: [AFLOW](https://arxiv.org/abs/2410.10762), [FACT](https://arxiv.org/abs/2410.21012), and [SELA](https://arxiv.org/abs/2410.17238); check out the [code](examples)!
+
+🚀 Mar. 29, 2024: [v0.8.0](https://github.com/geekan/MetaGPT/releases/tag/v0.8.0) released. Now you can use Data Interpreter ([arxiv](https://arxiv.org/abs/2402.18679), [example](https://docs.deepwisdom.ai/main/en/DataInterpreter/), [code](https://github.com/geekan/MetaGPT/tree/main/examples/di)) via pypi package import. Meanwhile, we integrated the RAG module and supported multiple new LLMs.
🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), a powerful agent capable of solving a wide range of real-world problems.
@@ -120,7 +123,7 @@ ### Usage
### QuickStart & Demo Video
-- Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT)
+- Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT-SoftwareCompany)
- [Matthew Berman: How To Install MetaGPT - Build A Startup With One Prompt!!](https://youtu.be/uT75J_KG_aY)
- [Official Demo Video](https://github.com/geekan/MetaGPT/assets/2707039/5e8c1062-8c35-440f-bb20-2b0320f8d27d)
@@ -140,7 +143,7 @@ ## Tutorial
- [Data Interpreter](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html)
- [Debate](https://docs.deepwisdom.ai/main/en/guide/use_cases/multi_agent/debate.html)
- [Researcher](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/researcher.html)
- - [Recepit Assistant](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/receipt_assistant.html)
+ - [Receipt Assistant](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/receipt_assistant.html)
- ❓ [FAQs](https://docs.deepwisdom.ai/main/en/guide/faq.html)
## Support
@@ -184,4 +187,13 @@ ## Citation
archivePrefix={arXiv},
primaryClass={cs.AI}
}
+@misc{zhang2024aflow,
+ title={AFlow: Automating Agentic Workflow Generation},
+ author={Jiayi Zhang and Jinyu Xiang and Zhaoyang Yu and Fengwei Teng and Xionghui Chen and Jiaqi Chen and Mingchen Zhuge and Xin Cheng and Sirui Hong and Jinlin Wang and Bingnan Zheng and Bang Liu and Yuyu Luo and Chenglin Wu},
+ year={2024},
+ eprint={2410.10762},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI},
+ url={https://arxiv.org/abs/2410.10762},
+}
```
diff --git a/docs/README_CN.md b/docs/README_CN.md
index 4e7866d83..88583cf24 100644
--- a/docs/README_CN.md
+++ b/docs/README_CN.md
@@ -9,19 +9,20 @@ # MetaGPT: 多智能体框架
-
-
-
+
+
+
+
-
+
-
+
1. MetaGPT输入**一句话的老板需求**,输出**用户故事 / 竞品分析 / 需求 / 数据结构 / APIs / 文件等**
@@ -76,7 +77,7 @@ # 步骤2: 使用容器运行metagpt演示
详细的安装请参考 [docker_install](https://docs.deepwisdom.ai/main/zh/guide/get_started/installation.html#%E4%BD%BF%E7%94%A8docker%E5%AE%89%E8%A3%85)
### 快速开始的演示视频
-- 在 [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT) 上进行体验
+- 在 [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT-SoftwareCompany) 上进行体验
- [Matthew Berman: How To Install MetaGPT - Build A Startup With One Prompt!!](https://youtu.be/uT75J_KG_aY)
- [官方演示视频](https://github.com/geekan/MetaGPT/assets/2707039/5e8c1062-8c35-440f-bb20-2b0320f8d27d)
diff --git a/docs/README_FR.md b/docs/README_FR.md
new file mode 100644
index 000000000..4bb02e0d4
--- /dev/null
+++ b/docs/README_FR.md
@@ -0,0 +1,194 @@
+
+# MetaGPT: Architecture Multi-Agent
+
+
+
+
+
+
+Assigner différents rôles aux GPTs pour former une entité collaborative capable de gérer des tâches complexes.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## Nouveautés
+🚀 29 mars 2024: La version [v0.8.0](https://github.com/geekan/MetaGPT/releases/tag/v0.8.0) a été publiée. Vous pouvez désormais utiliser le Data Interpreter ([arxiv](https://arxiv.org/abs/2402.18679), [example](https://docs.deepwisdom.ai/main/en/DataInterpreter/), [code](https://github.com/geekan/MetaGPT/tree/main/examples/di)) via l'importation du package PyPI. De plus, le module RAG (Génération Augmentée par Récupération) a été intégré, et plusieurs nouveaux modèles de LLMs sont désormais pris en charge.
+
+🚀 08 février 2024: La version [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) a été publiée, permettant l'attribution de différents modèles de langage (LLMs) à différents Rôles. Nous avons également introduit le [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), un agent puissant capable de résoudre une grande variété de problèmes du monde réel.
+
+🚀 16 janvier 2024: Notre article intitulé [MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework](https://openreview.net/forum?id=VtmBAGCN7o) a été accepté pour une **présentation orale (top 1,2%)** à la conférence ICLR 2024, se **classant n°1** dans la catégorie des agents basés sur les modèles de langage (LLM).
+
+🚀 3 janvier 2024 : La version [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0) a été publiée avec de nouvelles fonctionnalités, notamment la sérialisation, la mise à niveau du package OpenAI et la prise en charge de plusieurs modèles de langage (LLM). Un [exemple minimal pour le débat](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py) a également été ajouté pour illustrer ces capacités.
+
+🚀 15 décembre 2023 : La version [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) a été publiée, introduisant des fonctionnalités expérimentales telles que le développement incrémental, la prise en charge du multilingue et la compatibilité avec plusieurs langages de programmation, etc.
+
+
+🔥 8 novembre 2023 : MetaGPT a été sélectionné parmi les [Open100: Top 100 des réalisations open source](https://www.benchcouncil.org/evaluation/opencs/annual.html), une reconnaissance qui met en avant les meilleures innovations et contributions dans le domaine des projets open source.
+
+🔥 1er septembre 2023 : MetaGPT a dominé le classement **GitHub Trending Monthly** pour la **17ème fois** en août 2023, consolidant ainsi sa position en tant que projet open source de premier plan.
+
+🌟 30 juin 2023 : MetaGPT est désormais open source, permettant à la communauté de contribuer et d'enrichir le projet.
+
+🌟 24 avril 2023 : Le premier commit de MetaGPT a été effectué, marquant le début de ce projet innovant.
+
+
+### Système multi-agents dans une entreprise de logiciels
+
+1. **Exigence unique** : MetaGPT prend en entrée une **exigence formulée en une ligne** et produit des résultats variés, tels que des **user stories, des analyses concurrentielles, des exigences, des structures de données, des API, des documents, etc.**.
+
+2. **Structure interne** : MetaGPT intègre divers rôles présents dans une entreprise de logiciels, notamment **des chefs de produits, des architectes, des chefs de projet et des ingénieurs**. Ce système propose un processus complet de **développement logiciel**, soutenu par des **procédures opérationnelles standardisées (SOP) soigneusement orchestrées**.
+
+ 1. La philosophie centrale du système est exprimée par l'énoncé : `Code = SOP(Équipe)`. Cela signifie que les SOP sont concrétisées et appliquées à des équipes composées de modèles de langage (LLMs), permettant ainsi une meilleure gestion et un meilleur déroulement des projets.
+
+
+
+
+Schéma multi-agent d'une entreprise de logiciels (Mise en œuvre progressive)
+
+
+## Commençons !
+
+### Installation
+
+> Assurez-vous que Python 3.9 ou supérieur, mais inférieur à 3.12, est installé sur votre système. Vous pouvez le vérifier en utilisant : `python --version`.
+> Vous pouvez utiliser conda comme suit : `conda create -n metagpt python=3.9 && conda activate metagpt`
+
+```bash
+pip install --upgrade metagpt
+# or `pip install --upgrade git+https://github.com/geekan/MetaGPT.git`
+# or `git clone https://github.com/geekan/MetaGPT && cd MetaGPT && pip install --upgrade -e .`
+```
+
+Pour des conseils d'installation détaillés, veuillez vous référer à [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version)
+ ou [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker)
+
+### Configuration
+
+Vous pouvez initialiser la configuration de MetaGPT en lançant la commande suivante, ou en créant manuellement le fichier `~/.metagpt/config2.yaml` :
+```bash
+# Visitez https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html pour plus de détails
+metagpt --init-config # il créera ~/.metagpt/config2.yaml, il suffit de le modifier selon vos besoins
+```
+
+Vous pouvez configurer `~/.metagpt/config2.yaml` selon l'[exemple](https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml) et le [doc](https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html) :
+
+```yaml
+llm:
+ api_type: "openai" # ou azure / ollama / groq etc. Consultez LLMType pour plus d'options
+ model: "gpt-4-turbo" # ou gpt-3.5-turbo
+ base_url: "https://api.openai.com/v1" # ou URL de transfert / URL d'autre LLM.
+ api_key: "VOTRE_CLE_API"
+```
+
+### Utilisation
+
+Après l'installation, vous pouvez utiliser MetaGPT en CLI
+
+```bash
+metagpt "Create a 2048 game" # ceci créera un repo dans ./workspace
+```
+
+ou l'utiliser comme bibliothèque
+
+```python
+from metagpt.software_company import generate_repo, ProjectRepo
+repo: ProjectRepo = generate_repo("Create a 2048 game") # ou ProjectRepo("<path>")
+print(repo) # il affichera la structure du repo avec les fichiers
+```
+
+Vous pouvez aussi utiliser [Data Interpreter](https://github.com/geekan/MetaGPT/tree/main/examples/di) pour écrire du code:
+
+```python
+import asyncio
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+async def main():
+ di = DataInterpreter()
+ await di.run("Exécuter une analyse de données sur le jeu de données sklearn Iris et y inclure un graphique")
+
+asyncio.run(main()) # ou attendre main() dans une configuration de notebook jupyter
+```
+
+
+### Vidéo de démonstration et de démarrage rapide (en Anglais) :
+- Essayez-le sur [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT-SoftwareCompany)
+- [Matthew Berman : Comment installer MetaGPT - Construire une startup avec une seule invite](https://youtu.be/uT75J_KG_aY)
+- [Vidéo de démonstration officielle](https://github.com/geekan/MetaGPT/assets/2707039/5e8c1062-8c35-440f-bb20-2b0320f8d27d)
+
+https://github.com/geekan/MetaGPT/assets/34952977/34345016-5d13-489d-b9f9-b82ace413419
+
+## Tutoriel (en Anglais)
+
+- 🗒 [Document en ligne](https://docs.deepwisdom.ai/main/en/)
+- 💻 [Utilisation](https://docs.deepwisdom.ai/main/en/guide/get_started/quickstart.html)
+- 🔎 [Que peut faire MetaGPT](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html)
+- 🛠 Comment créer ses propres agents ?
+ - [MetaGPT Guide d'utilisation et de développement | Agent 101](https://docs.deepwisdom.ai/main/en/guide/tutorials/agent_101.html)
+ - [MetaGPT Guide d'utilisation et de développement | MultiAgent 101](https://docs.deepwisdom.ai/main/en/guide/tutorials/multi_agent_101.html)
+- 🧑💻 Contribution
+ - [Élaborer une feuille de route](docs/ROADMAP.md)
+- 🔖 Cas d'usage
+ - [Interprète des données](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html)
+ - [Débat](https://docs.deepwisdom.ai/main/en/guide/use_cases/multi_agent/debate.html)
+ - [Chercheur](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/researcher.html)
+ - [Assistant de reçus (Receipt Assistant)](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/receipt_assistant.html)
+- ❓ [FAQs](https://docs.deepwisdom.ai/main/en/guide/faq.html)
+
+## Support
+
+### Rejoignez-nous sur Discord
+
+📢 Rejoignez-nous sur [Discord Channel](https://discord.gg/ZRHeExS6xv)! Au plaisir de vous y voir ! 🎉
+
+### Formulaire de contribution
+
+📝 [Remplissez le formulaire](https://airtable.com/appInfdG0eJ9J4NNL/pagK3Fh1sGclBvVkV/form) pour devenir contributeur. Nous nous réjouissons de votre participation !
+
+### Information de contact
+
+Si vous avez des questions ou des commentaires sur ce projet, n'hésitez pas à nous contacter. Nous apprécions grandement vos suggestions !
+
+- **Email:** alexanderwu@deepwisdom.ai
+- **GitHub Issues:** Pour des questions plus techniques, vous pouvez également créer un nouveau problème dans notre [dépôt Github](https://github.com/geekan/metagpt/issues).
+
+Nous répondrons à toutes les questions dans un délai de 2 à 3 jours ouvrables.
+
+## Citation
+
+Pour rester informé des dernières recherches et développements, suivez [@MetaGPT_](https://twitter.com/MetaGPT_) sur Twitter.
+
+Pour citer [MetaGPT](https://openreview.net/forum?id=VtmBAGCN7o) ou [Data Interpreter](https://arxiv.org/abs/2402.18679) dans des publications, veuillez utiliser les entrées BibTeX suivantes.
+
+```bibtex
+@inproceedings{hong2024metagpt,
+ title={Meta{GPT}: Meta Programming for A Multi-Agent Collaborative Framework},
+ author={Sirui Hong and Mingchen Zhuge and Jonathan Chen and Xiawu Zheng and Yuheng Cheng and Jinlin Wang and Ceyao Zhang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu and J{\"u}rgen Schmidhuber},
+ booktitle={The Twelfth International Conference on Learning Representations},
+ year={2024},
+ url={https://openreview.net/forum?id=VtmBAGCN7o}
+}
+@misc{hong2024data,
+ title={Data Interpreter: An LLM Agent For Data Science},
+ author={Sirui Hong and Yizhang Lin and Bang Liu and Bangbang Liu and Binhao Wu and Danyang Li and Jiaqi Chen and Jiayi Zhang and Jinlin Wang and Li Zhang and Lingyao Zhang and Min Yang and Mingchen Zhuge and Taicheng Guo and Tuo Zhou and Wei Tao and Wenyi Wang and Xiangru Tang and Xiangtao Lu and Xiawu Zheng and Xinbing Liang and Yaying Fei and Yuheng Cheng and Zongze Xu and Chenglin Wu},
+ year={2024},
+ eprint={2402.18679},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI}
+}
+```
diff --git a/docs/README_JA.md b/docs/README_JA.md
index 8981361a8..fd96602b5 100644
--- a/docs/README_JA.md
+++ b/docs/README_JA.md
@@ -9,9 +9,10 @@ # MetaGPT: マルチエージェントフレームワーク
-
-
-
+
+
+
+
@@ -21,7 +22,7 @@ # MetaGPT: マルチエージェントフレームワーク
-
+
1. MetaGPT は、**1 行の要件** を入力とし、**ユーザーストーリー / 競合分析 / 要件 / データ構造 / API / 文書など** を出力します。
@@ -291,7 +292,7 @@ ## クイックスタート
- [MetaGPT クイックスタート](https://deepwisdom.feishu.cn/wiki/CyY9wdJc4iNqArku3Lncl4v8n2b)
Hugging Face Space で試す
-- https://huggingface.co/spaces/deepwisdom/MetaGPT
+- https://huggingface.co/spaces/deepwisdom/MetaGPT-SoftwareCompany
## 引用
diff --git a/docs/resources/aflow/AFLOW-experiment.jpg b/docs/resources/aflow/AFLOW-experiment.jpg
new file mode 100644
index 000000000..dc7266c1e
Binary files /dev/null and b/docs/resources/aflow/AFLOW-experiment.jpg differ
diff --git a/docs/resources/aflow/AFLOW-method.jpg b/docs/resources/aflow/AFLOW-method.jpg
new file mode 100644
index 000000000..14ae60f49
Binary files /dev/null and b/docs/resources/aflow/AFLOW-method.jpg differ
diff --git a/docs/resources/aflow/AFLOW-performance.jpg b/docs/resources/aflow/AFLOW-performance.jpg
new file mode 100644
index 000000000..3866c40b9
Binary files /dev/null and b/docs/resources/aflow/AFLOW-performance.jpg differ
diff --git a/examples/aflow/README.md b/examples/aflow/README.md
new file mode 100644
index 000000000..332cc4b3d
--- /dev/null
+++ b/examples/aflow/README.md
@@ -0,0 +1,88 @@
+# AFlow: Automating Agentic Workflow Generation
+
+AFlow is a framework for automatically generating and optimizing Agentic Workflows. It uses Monte Carlo tree search in a code-represented workflow space to find effective workflows, replacing manual development with machine effort. Our approach shows potential to outperform handcrafted workflows on various tasks.
+
+[Read our paper on arXiv](https://arxiv.org/abs/2410.10762)
+
+
+
+
+
+## Framework Components
+
+- **Node**: Basic unit of LLM invocation. See `metagpt/actions/action_node.py` for a flexible interface to control LLM, temperature, format, and prompt.
+- **Operator**: Predefined combinations of Nodes to enhance search efficiency. Encapsulates common operations like Generate, Format, Review, Revise, Ensemble, Test, and Programmer. See `metagpt/ext/aflow/operator.py` for details. You can customize your own Operator by referencing the implementations in this code (see the illustrative sketch after this list).
+- **Workflow**: A sequence of LLM-invoking nodes connected by edges. Can be represented as graphs, neural networks, or code to express various execution structures. See `metagpt/ext/aflow/workflow.py` for our implementation.
+- **Optimizer**: Uses LLMs within a Monte Carlo Tree Search variant to explore and refine workflows. Iteratively selects, expands, evaluates, and updates workflows based on performance. See `metagpt/ext/aflow/scripts/optimizer.py` for details.
+- **Evaluator**: Assesses workflow performance on given tasks. Provides feedback to guide the optimization process towards more effective workflows. See `metagpt/ext/aflow/scripts/evaluator.py` for details.
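+
+Below is a minimal sketch of what a custom Operator might look like. It is illustrative only: the `Operator` base class, its constructor, and the import path are assumptions patterned on the operators listed above, so check `metagpt/ext/aflow/operator.py` for the actual interface before reusing it.
+
+```python
+from metagpt.ext.aflow.operator import Operator  # assumed import path; see the file referenced above
+
+
+class KeywordExtract(Operator):  # hypothetical custom operator
+    """Ask the LLM to pull key entities and constraints out of a problem statement."""
+
+    def __init__(self, llm, name: str = "KeywordExtract"):
+        super().__init__(llm, name)  # assumed base-class signature
+
+    async def __call__(self, problem: str) -> str:
+        prompt = f"List the key entities and constraints in this problem:\n{problem}"
+        # Assumes the base class keeps the configured LLM on `self.llm`,
+        # mirroring how the built-in operators invoke it via `aask`.
+        return await self.llm.aask(prompt)
+```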
+
+
+
+
+
+## Datasets
+
+### Experimental Datasets
+We conducted experiments on six datasets (HumanEval, MBPP, GSM8K, MATH, HotpotQA, DROP) and provide their evaluation code. The data is available at this [datasets](https://drive.google.com/uc?export=download&id=1DNoegtZiUhWtvkd2xoIuElmIi4ah7k8e) link, or you can download it with `metagpt/ext/aflow/data/download_data.py`, as sketched below.
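+
+A minimal sketch of the download call, following the usage in `examples/aflow/optimize.py` added in this change:
+
+```python
+from metagpt.ext.aflow.data.download_data import download
+
+# Fetch the benchmark datasets (and the initial-round templates) once before optimizing.
+download(["datasets", "initial_rounds"], if_first_download=True)
+```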
+
+
+
+
+
+### Custom Datasets
+For custom tasks, you can reference the code in the `metagpt/ext/aflow/benchmark` folder. Inherit the `BaseBenchmark` class and implement `evaluate_problem`, `calculate_score`, and `get_result_columns` to add your custom dataset benchmark. Then, add your benchmark name in `metagpt/ext/aflow/scripts/evaluator.py` and `metagpt/ext/aflow/scripts/optimizer.py` to find effective workflows for your custom dataset.
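+
+As a rough illustration of the pattern described above, here is a skeleton of a custom benchmark. The method signatures are guesses based on the names mentioned; consult `BaseBenchmark` under `metagpt/ext/aflow/benchmark` for the real ones before implementing:
+
+```python
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark  # assumed module layout
+
+
+class MyBenchmark(BaseBenchmark):
+    """Hypothetical benchmark for a custom QA-style dataset."""
+
+    def calculate_score(self, expected: str, prediction: str) -> float:
+        # Exact-match scoring; swap in whatever metric fits your task.
+        return float(expected.strip().lower() == prediction.strip().lower())
+
+    async def evaluate_problem(self, problem: dict, graph) -> tuple:
+        # Run the candidate workflow (`graph`) on one problem and score its answer.
+        prediction = await graph(problem["question"])
+        score = self.calculate_score(problem["answer"], prediction)
+        return problem["question"], prediction, problem["answer"], score
+
+    def get_result_columns(self) -> list:
+        # Column names for the per-problem results written by the evaluator.
+        return ["question", "prediction", "answer", "score"]
+```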
+
+## Quick Start
+
+1. Configure optimization parameters:
+ - Use command line arguments or modify default parameters in `examples/aflow/optimize.py`:
+ ```python
+ --dataset # (Required) Dataset type (HumanEval/MBPP/GSM8K/MATH/HotpotQA/DROP)
+ --sample 4 # Sample count - number of workflows to be resampled
+ --optimized_path PATH # Optimized result save path
+ --initial_round 1 # Initial round
+ --max_rounds 20 # Max iteration rounds for AFLOW
+ --check_convergence # Whether to enable early stop
+ --validation_rounds 5 # Validation rounds for AFLOW
+ --if_first_optimize # Set True for first optimization, False afterwards
+ ```
+
+2. Configure LLM parameters in `config/config2.yaml` (see `examples/aflow/config2.example.yaml` for reference)
+
+3. Set up operators in `optimize.py` and in `optimized_path/template/operator.py`, `optimized_path/template/operator.json`. You can reference our implementation to add operators for specific datasets
+
+4. For first-time use, download datasets and initial rounds by setting `download(["datasets", "initial_rounds"])` in `examples/aflow/optimize.py`
+
+5. (Optional) Add your custom dataset and corresponding evaluation function following the [Custom Datasets](#custom-datasets) section
+
+6. (Optional) If you want to use a portion of the validation data, you can set `va_list` in `examples/aflow/evaluator.py`
+
+7. Run the optimization:
+ ```bash
+ # Using default parameters
+ python -m examples.aflow.optimize --dataset MATH
+
+ # Or with custom parameters
+ python -m examples.aflow.optimize --dataset MATH --sample n --optimized_path xxx ...
+ ```
+
+## Reproduce the Results in the Paper
+1. We provide the raw data obtained from our experiments in this [link](https://drive.google.com/uc?export=download&id=1Sr5wjgKf3bN8OC7G6cO3ynzJqD4w6_Dv), including the workflows and prompts generated in each iteration, as well as their trajectories on the validation dataset. We also provide the optimal workflow for each dataset and the corresponding data on the test dataset. You can download this data using `metagpt/ext/aflow/data/download_data.py`.
+2. You can directly reproduce our experimental results by using a different `ExperimentConfig` in `examples/aflow/optimize.py`.
+
+
+## Citation
+
+If you use AFlow in your research, please cite our paper:
+
+```
+@misc{zhang2024aflow,
+ title={AFlow: Automating Agentic Workflow Generation},
+ author={Jiayi Zhang and Jinyu Xiang and Zhaoyang Yu and Fengwei Teng and Xionghui Chen and Jiaqi Chen and Mingchen Zhuge and Xin Cheng and Sirui Hong and Jinlin Wang and Bingnan Zheng and Bang Liu and Yuyu Luo and Chenglin Wu},
+ year={2024},
+ eprint={2410.10762},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI},
+ url={https://arxiv.org/abs/2410.10762},
+}
+```
\ No newline at end of file
diff --git a/examples/aflow/config2.example.yaml b/examples/aflow/config2.example.yaml
new file mode 100644
index 000000000..ebaef33e2
--- /dev/null
+++ b/examples/aflow/config2.example.yaml
@@ -0,0 +1,12 @@
+models:
+ "": # model: "gpt-4-turbo" # or gpt-3.5-turbo
+ api_type: "openai" # or azure / ollama / groq etc.
+ base_url: ""
+ api_key: ""
+ temperature: 0
+ "":
+ api_type: "openai"
+ base_url: ""
+ api_key: ""
+ temperature: 0
+CALC_USAGE: True
diff --git a/examples/aflow/optimize.py b/examples/aflow/optimize.py
new file mode 100644
index 000000000..d07eab993
--- /dev/null
+++ b/examples/aflow/optimize.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# @Date : 8/23/2024 20:00 PM
+# @Author : didi
+# @Desc : Entrance of AFlow.
+
+import argparse
+from typing import Dict, List
+
+from metagpt.configs.models_config import ModelsConfig
+from metagpt.ext.aflow.data.download_data import download
+from metagpt.ext.aflow.scripts.optimizer import Optimizer
+
+
+class ExperimentConfig:
+ def __init__(self, dataset: str, question_type: str, operators: List[str]):
+ self.dataset = dataset
+ self.question_type = question_type
+ self.operators = operators
+
+
+EXPERIMENT_CONFIGS: Dict[str, ExperimentConfig] = {
+ "DROP": ExperimentConfig(
+ dataset="DROP",
+ question_type="qa",
+ operators=["Custom", "AnswerGenerate", "ScEnsemble"],
+ ),
+ "HotpotQA": ExperimentConfig(
+ dataset="HotpotQA",
+ question_type="qa",
+ operators=["Custom", "AnswerGenerate", "ScEnsemble"],
+ ),
+ "MATH": ExperimentConfig(
+ dataset="MATH",
+ question_type="math",
+ operators=["Custom", "ScEnsemble", "Programmer"],
+ ),
+ "GSM8K": ExperimentConfig(
+ dataset="GSM8K",
+ question_type="math",
+ operators=["Custom", "ScEnsemble", "Programmer"],
+ ),
+ "MBPP": ExperimentConfig(
+ dataset="MBPP",
+ question_type="code",
+ operators=["Custom", "CustomCodeGenerate", "ScEnsemble", "Test"],
+ ),
+ "HumanEval": ExperimentConfig(
+ dataset="HumanEval",
+ question_type="code",
+ operators=["Custom", "CustomCodeGenerate", "ScEnsemble", "Test"],
+ ),
+}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="AFlow Optimizer")
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ choices=list(EXPERIMENT_CONFIGS.keys()),
+ required=True,
+ help="Dataset type",
+ )
+ parser.add_argument("--sample", type=int, default=4, help="Sample count")
+ parser.add_argument(
+ "--optimized_path",
+ type=str,
+ default="metagpt/ext/aflow/scripts/optimized",
+ help="Optimized result save path",
+ )
+ parser.add_argument("--initial_round", type=int, default=1, help="Initial round")
+ parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
+ parser.add_argument("--check_convergence", type=bool, default=True, help="Whether to enable early stop")
+ parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
+ parser.add_argument(
+ "--if_first_optimize",
+ type=lambda x: x.lower() == "true",
+ default=True,
+ help="Whether to download dataset for the first time",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize)
+ config = EXPERIMENT_CONFIGS[args.dataset]
+
+ mini_llm_config = ModelsConfig.default().get("gpt-4o-mini")
+ claude_llm_config = ModelsConfig.default().get("claude-3-5-sonnet-20240620")
+
+ optimizer = Optimizer(
+ dataset=config.dataset,
+ question_type=config.question_type,
+ opt_llm_config=claude_llm_config,
+ exec_llm_config=mini_llm_config,
+ check_convergence=args.check_convergence,
+ operators=config.operators,
+ optimized_path=args.optimized_path,
+ sample=args.sample,
+ initial_round=args.initial_round,
+ max_rounds=args.max_rounds,
+ validation_rounds=args.validation_rounds,
+ )
+
+ # Optimize workflow via setting the optimizer's mode to 'Graph'
+ optimizer.optimize("Graph")
+
+ # Test workflow via setting the optimizer's mode to 'Test'
+ # optimizer.optimize("Test")
diff --git a/examples/di/InfiAgent-DABench/DABench.py b/examples/di/InfiAgent-DABench/DABench.py
new file mode 100644
index 000000000..50ec04b29
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/DABench.py
@@ -0,0 +1,487 @@
+import asyncio
+import json
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import nest_asyncio
+
+from examples.di.requirements_prompt import DABENCH
+from metagpt.const import DABENCH_PATH
+from metagpt.logs import logger
+from metagpt.utils.exceptions import handle_exception
+
+
+def evaluate_accuracy_by_question(results: list) -> float:
+ """
+ Calculate the accuracy of results based on complete correctness of each question.
+ This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
+ This function checks whether each result is entirely correct, meaning all sub-questions
+ within that result are answered correctly. It computes the proportion of correct results
+ by dividing the number of fully correct results by the total number of results.
+
+ Args:
+ results (list): A list of result dicts, each of which may contain a 'correctness' field.
+
+ Returns:
+ float: The proportion of correct results, rounded to four decimal places.
+ Returns 0 if there are no results.
+ """
+ correct = sum("correctness" in result and all(result["correctness"].values()) for result in results)
+ total = len(results)
+ return round(correct / total, 4) if total > 0 else 0
+
+
+def evaluate_accuracy_by_sub_question(results: list) -> float:
+ """
+ Evaluate the correctness of all sub-questions across the results.
+ This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
+ This function calculates the total number of correct sub-questions and the overall
+ number of sub-questions present in all results. It returns the ratio of correct
+ sub-questions to the total number of sub-questions.
+
+ Args:
+ results (list): A list of result dicts, each of which may contain a 'correctness' field.
+
+ Returns:
+ float: The ratio of correct sub-questions, rounded to four decimal places.
+ Returns 0 if there are no sub-questions.
+ """
+ correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result)
+ total = sum(len(result["correctness"]) for result in results if "correctness" in result)
+ return round(correct / total, 4) if total > 0 else 0
+
+
+def evaluate_accuracy_proportional_by_sub_question_adjusted(results: list) -> float:
+ """
+ Adjust the score based on the number of sub-questions in each result.
+ This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
+ This function calculates a score for each result by considering the number of sub-questions
+ it contains. Each sub-question is assigned a score of 1 divided by the number of sub-questions.
+ The total score for each result is computed as the sum of all correct sub-questions multiplied
+ by the score per sub-question. Finally, it returns the average score across all results.
+
+ Args:
+ results (list): A list of result dicts, each of which may contain a 'correctness' field.
+
+ Returns:
+ float: The average score across all results, rounded to four decimal places.
+ Returns 0 if there are no results.
+ """
+ total_score = 0
+ for result in results:
+ if "correctness" in result:
+ sub_question_count = len(result["correctness"])
+ score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0
+ question_score = sum(result["correctness"].values()) * score_per_sub_question
+ total_score += question_score
+ return round(total_score / len(results), 4) if results else 0
+
+
+async def reformat(question: str, format: str, response: str) -> str:
+ """
+ Asynchronously reformats a given response based on specified formatting requirements.
+ This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
+ This function constructs a prompt for the LLM (Large Language Model) to reformat
+ the provided response according to the specified format. It includes a system prompt
+ to guide the LLM's behavior and a template that outlines the expected output structure.
+
+ Args:
+ question (str): The original question posed by the user.
+ format (str): The specific formatting requirements that the response must adhere to.
+ response (str): The initial response from the LLM that needs to be reformatted.
+
+ Returns:
+ str: The reformatted response generated by the LLM based on the provided question
+ and formatting requirements.
+ """
+ system_prompt = "You are a helpful assistant."
+ demons = """\Format{{
+ @shapiro_wilk_statistic[test_statistic]
+ @shapiro_wilk_p_value[p_value]
+ where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places.
+ where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places.
+ }}
+ \Answer{{
+ @shapiro_wilk_statistic[0.56]
+ @shapiro_wilk_p_value[0.0002]
+ }}
+
+ \Format{{
+ @total_votes_outliers_num[outlier_num]
+ where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column.
+ }}
+ \Answer{{
+ @total_votes_outliers[10]
+ }}
+ """
+ reformat_template = """You should strictly follow the output requirements in the Format part. Here're some examples: {demons}.
+ Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes.
+ The format requirements of this question is:
+ {format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
+ messages = [
+ {"role": "user", "content": question},
+ {"role": "assistant", "content": response},
+ {"role": "user", "content": reformat_template.format(demons=demons, format=format)},
+ ]
+ rsp = await ask(messages, system_prompt)
+ return rsp
+
+
+def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]:
+ """
+ Load data from a JSONL file into a list of dictionaries.
+
+ Args:
+ file_path (Union[Path, str]): The path to the JSONL file to be loaded.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries containing the data from the JSONL file.
+ """
+ # Convert file_path to Path if it's a string
+ if isinstance(file_path, str):
+ file_path = Path(file_path)
+
+ data = []
+ with open(file_path, "r", encoding="utf-8") as file:
+ for line in file:
+ data.append(json.loads(line))
+ return data
+
+
+def compare_predictions(pred_dict: dict, true_label: list) -> bool:
+ """
+ Compares each prediction against the corresponding true label.
+
+ This function checks whether the predicted values match the true values for each
+ metric. It sorts the true labels to ensure the comparison is made in the correct
+ order. The function returns True if all predictions are accurate within a small
+ tolerance for numerical values, or if string values match case-insensitively.
+
+ Args:
+ pred_dict (dict): A dictionary of predicted metrics and their values.
+ true_label (list): A list of tuples containing true metrics and their values.
+
+ Returns:
+ bool: True if all predictions match the true labels, False otherwise.
+ """
+ sorted_true_label = sorted(true_label, key=lambda x: x[0]) # Sort true labels by metric name
+
+ for metric, true_value in sorted_true_label:
+ try:
+ true_value = float(true_value) # Attempt to convert the true value to float
+ except ValueError:
+ true_value = true_value.replace(",", "") # Clean the true value if conversion fails
+
+ # Check if the true value is numeric and compare with the prediction
+ if isinstance(true_value, (int, float)) and (
+ metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6
+ ):
+ return False # Return False if the prediction is inaccurate
+
+ # Check if the true value is a string and compare with the prediction
+ if isinstance(true_value, str) and (
+ metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
+ ):
+ return False # Return False if the string prediction does not match
+
+ return True # Return True if all predictions are accurate
+
+
+async def ask(question: str, system_prompt: str) -> str:
+ """
+ Asynchronously sends a question to the LLM (Large Language Model) and retrieves the response.
+
+ This function initializes an instance of the LLM and uses it to ask a question
+ along with a system prompt. The response from the LLM is awaited and returned.
+
+ Args:
+ question (str): The question to be asked to the LLM.
+ system_prompt (str): A prompt that provides context or instructions to the LLM.
+
+ Returns:
+ str: The response from the LLM based on the provided question and system prompt.
+ """
+ from metagpt.llm import LLM # Importing the LLM class from the metagpt module
+
+ llm = LLM() # Create an instance of the LLM
+ rsp = await llm.aask(question, system_msgs=[system_prompt]) # Await the response from the LLM
+ return rsp # Return the response
+
+
+def parse_prediction(prediction: str) -> dict:
+ """
+ Parses a prediction string into a dictionary of metric-value pairs.
+
+ This function takes a formatted string containing metrics and their corresponding
+ values, separated by the "@" symbol. Each metric may be enclosed in brackets and
+ may include commas. The function processes the input to extract and clean the
+ metrics and their values, returning them in a structured dictionary format.
+
+ Args:
+ prediction (str): A string representation of metrics and their values.
+
+ Returns:
+ dict: A dictionary where each key is a metric name and each value is the
+ corresponding value, either as a float or a string.
+ """
+ pred_dict = {}
+ for pred in prediction.split("@"):
+ if pred == "":
+ continue # Skip any empty segments resulting from the split
+ temp = re.split(r"[\[\]]", pred.strip()) # Split the string by brackets
+ temp = [s.replace(",", "") for s in temp] # Remove commas from the segments
+ parts = [s for s in temp if s] # Filter out any empty strings
+ metric = parts[0].strip().replace(",", "") # Extract and clean the metric name
+ value = parts[-1].replace(",", "").replace(":", "") # Extract and clean the value
+
+ try:
+ value = float(value) # Attempt to convert the value to a float
+ except ValueError:
+ pass # If conversion fails, retain the value as a string
+
+ pred_dict[metric] = value # Store the metric-value pair in the dictionary
+ return pred_dict
+
+
+class DABench:
+ def __init__(
+ self,
+ questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl",
+ answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl",
+ template: str = "",
+ ):
+ """
+ Initializes the DABench instance with questions and answers.
+
+ This constructor loads questions and answers from specified JSONL files.
+ It also sets a template for formatting prompts. If no template is provided,
+ a default template is used.
+
+ Args:
+ questions_file (Path): The path to the JSONL file containing questions.
+ answers_file (Path): The path to the JSONL file containing answers.
+ template (str): A string template for formatting prompts.
+ """
+
+ self.questions = {
+ int(line["id"]): line for line in load_jsonl(questions_file)
+ } # Load questions from the specified file
+ self.answers = {
+ int(line["id"]): line for line in load_jsonl(answers_file)
+ } # Load answers from the specified file
+ self.template = template if template else DABENCH # Set the template, defaulting if necessary
+
+ def get_question(self, question_id: str) -> dict:
+ """
+ Retrieve the question associated with the given ID.
+
+ This method looks up a question by its unique identifier. If the question
+ is found, it returns the question data; otherwise, it returns a message
+ indicating that the question was not found.
+
+ Args:
+ question_id (str): The unique identifier for the question.
+
+ Returns:
+ dict: The question data if found, otherwise a "Question not found." message.
+ """
+ return self.questions.get(question_id, "Question not found.") # Return the question or an error message
+
+ def generate_formatted_prompt(self, question_id: str) -> str:
+ """
+ Generate a formatted prompt for the specified question ID.
+
+ This method retrieves the question data and formats it using the specified
+ template. The formatted prompt includes the question, constraints, format,
+ file name, and level, allowing for a structured output.
+
+ Args:
+ question_id (str): The unique identifier for the question.
+
+ Returns:
+ str: A formatted prompt string based on the question data.
+ """
+ temp = self.get_question(question_id) # Retrieve the question data
+ return self.template.format(
+ question=temp["question"],
+ constraints=temp["constraints"],
+ format=temp["format"],
+ file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
+ level=temp["level"],
+ ) # Format and return the prompt
+
+ def get_answer(self, answer_id: str) -> list:
+ """
+ Retrieve the answer list associated with the given ID.
+
+ This method looks up an answer by its unique identifier. If the answer
+ is found, it returns the answer data; otherwise, it returns a message
+ indicating that the answer was not found.
+
+ Args:
+ answer_id (str): The unique identifier for the answer.
+
+ Returns:
+ list: The answer data if found, otherwise an "Answer not found." message.
+ """
+ return self.answers.get(answer_id, "Answer not found.") # Return the answer or an error message
+
+ @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False))
+ def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]:
+ """
+ Parse the cleaned prediction and compare it with the true label.
+
+ Args:
+ cleaned_prediction (str): The cleaned prediction string.
+ true_label (Any): The true label to compare against.
+
+ Returns:
+ Tuple[str, bool]: A tuple containing the cleaned prediction and a boolean indicating
+ whether it matches the true label.
+ """
+ if cleaned_prediction: # Ensure the cleaned prediction is not empty
+ pred_dict = parse_prediction(cleaned_prediction) # Parse the prediction
+ if pred_dict is not None and compare_predictions(pred_dict, true_label):
+ return cleaned_prediction, True # Return if the prediction matches the true label
+ return cleaned_prediction, False # Return the cleaned prediction with a False match
+
+ @handle_exception(exception_msg="Error during async reformat", default_return=(None, False))
+ def async_reformat_prediction(self, id: str, result: str) -> str:
+ """
+ Reformat the prediction asynchronously and extract the answer.
+
+ Args:
+ id (str): The identifier for the question.
+ result (str): The original prediction result.
+
+ Returns:
+ str: The reformatted prediction or the original prediction if extraction fails.
+ """
+ question = self.get_question(id)["question"] # Retrieve the question based on the ID
+ question_format = self.get_question(id)["format"] # Get the format of the question
+ prediction = asyncio.run(reformat(question, question_format, result)) # Asynchronously reformat the prediction
+
+ # Attempt to extract the answer from the reformatted prediction
+ answer_part = prediction.split("Answer{{") if "Answer{{" in prediction else []
+ if len(answer_part) > 1:
+ return answer_part[1].split("}}")[0].strip() # Return the extracted answer
+
+ return prediction # If extraction fails, return the original prediction
+
+ def eval(self, id: str, result: str) -> Tuple[str, bool]:
+ """
+ Evaluate the prediction against the true label.
+
+ Args:
+ id (str): The identifier for the question.
+ result (str): The original prediction result.
+
+ Returns:
+ Tuple[str, bool]: A tuple containing the final prediction and a boolean indicating
+ whether it matches the true label.
+ """
+ true_label = self.get_answer(id)["common_answers"] # Retrieve the true label for comparison
+ nest_asyncio.apply() # Apply nested asyncio to allow for async calls
+ result = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"].strip()
+ cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "") # Clean the prediction string
+
+ # Use the decorated function to handle exceptions while parsing the cleaned prediction
+ parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label)
+ if parsed_result[1]: # If the parsed prediction is valid
+ return parsed_result # Return the valid prediction
+
+ # If the cleaned prediction is not valid, attempt to asynchronously reformat it
+ prediction = self.async_reformat_prediction(id, result)
+
+ pred_dict = parse_prediction(prediction) # Parse the reformatted prediction
+ if pred_dict is not None and compare_predictions(pred_dict, true_label):
+ return prediction, True # Return if the reformatted prediction matches the true label
+
+ return prediction, False # Return the final prediction with a False match
+
+ @handle_exception(exception_msg="Error evaluating single prediction", default_return={})
+ def single_eval(self, id: str, prediction: str) -> dict:
+ """
+ Evaluate the prediction against the true label for a single question.
+ just using in eval_all
+
+ Args:
+ id (str): The identifier for the question.
+ prediction (str): The prediction string to evaluate.
+
+ Returns:
+ dict: A dictionary indicating the correctness of each metric.
+ """
+ true_label = self.get_answer(id)["common_answers"] # Retrieve the true label for the question
+ prediction = prediction.replace("{", "").replace("}", "").replace("'", "") # Clean the prediction string
+ pred_dict = parse_prediction(prediction) # Parse the prediction into a dictionary
+
+ # Initialize the correctness dictionary with False values for each metric
+ correctness = {metric: False for metric, _ in true_label}
+
+ # Check each metric's prediction against the true label
+ for metric, true_value in true_label:
+ try:
+ true_value = float(true_value) # Attempt to convert the true value to float
+ except ValueError:
+ true_value = true_value.replace(",", "") # Handle non-numeric values
+
+ if metric in pred_dict:
+ # Consider the prediction correct if it's within a small tolerance
+ if (
+ isinstance(true_value, (int, float))
+ and isinstance(pred_dict[metric], (int, float))
+ and abs(pred_dict[metric] - true_value) < 1e-6
+ ):
+ correctness[metric] = True # Mark as correct if within tolerance
+
+ if (
+ isinstance(true_value, str)
+ and metric in pred_dict
+ and str(pred_dict[metric]).lower() == str(true_value).lower()
+ ):
+ correctness[metric] = True # Mark as correct when the strings match (case-insensitive)
+
+ return correctness # Return the correctness dictionary
+
+ def eval_all(self, id_list: list, predictions: list) -> dict:
+ """
+ Evaluate all predictions and calculate accuracy rates.
+
+ Args:
+ id_list (list): A list of question identifiers.
+ predictions (list): A list of prediction strings corresponding to the questions.
+
+ Returns:
+ dict: A dictionary containing accuracy rates by question and sub-question.
+ """
+ results = [] # Initialize a list to store results for each question
+
+ # Evaluate each prediction against its corresponding question ID
+ for id, prediction in zip(id_list, predictions):
+ correct = self.single_eval(id, prediction) # Evaluate the single prediction
+ results.append({"id": id, "correctness": correct}) # Append the result to the list
+
+ # Calculate the three accuracy rates based on the results
+ accuracy_by_question = evaluate_accuracy_by_question(results)
+ accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
+ proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)
+
+ return {
+ "accuracy_by_question": accuracy_by_question,
+ "accuracy_by_sub_question": accuracy_by_sub_question,
+ "proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
+ }
+
+
+if __name__ == "__main__":
+ bench = DABench()
+ id = 0
+ prediction = "@mean_fare[34.65]"
+ logger.info(bench.eval(id, prediction))
+ ids = [0, 5, 6]
+ predictions = [
+ "@mean_fare[34.89]",
+ "@correlation_coefficient[0.21]",
+ "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
+ ]
+ logger.info(bench.eval_all(ids, predictions))
diff --git a/examples/di/InfiAgent-DABench/README.md b/examples/di/InfiAgent-DABench/README.md
new file mode 100644
index 000000000..74783c9d1
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/README.md
@@ -0,0 +1,45 @@
+# InfiAgent-DABench
+This example solves the InfiAgent-DABench benchmark using the Data Interpreter (DI) and obtains 94.93% accuracy with gpt-4o.
+
+## Dataset download
+```
+cd examples/di/InfiAgent-DABench
+git clone https://github.com/InfiAgent/InfiAgent.git
+mv InfiAgent/examples/DA-Agent/data ./
+```
+## Special note
+When running the DABench evaluation, you need to change the `ExecuteNbCode` `__init__` to:
+```
+class ExecuteNbCode(Action):
+ """execute notebook code block, return result to llm, and display it."""
+
+ nb: NotebookNode
+ nb_client: NotebookClient
+ console: Console
+ interaction: str
+ timeout: int = 600
+
+ def __init__(
+ self,
+ nb=nbformat.v4.new_notebook(),
+ timeout=600,
+ ):
+ super().__init__(
+ nb=nbformat.v4.new_notebook(), # always start from a fresh notebook instead of the passed-in nb
+ nb_client=NotebookClient(nb, timeout=timeout),
+ timeout=timeout,
+ console=Console(),
+ interaction=("ipython" if self.is_ipython() else "terminal"),
+ )
+```
+The path of `ExecuteNbCode` is:
+```
+metagpt.actions.di.execute_nb_code
+```
+This change makes each run start from a fresh notebook instead of reusing the original `nb` initialization.
+## How to run
+```
+python run_InfiAgent-DABench_single.py --id x # run a task, x represents the id of the question you want to test
+python run_InfiAgent-DABench_all.py # Run all tasks serially
+python run_InfiAgent-DABench.py --k x # Run all tasks in parallel, x represents the number of parallel tasks at a time
+```
\ No newline at end of file
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
new file mode 100644
index 000000000..7e1fbad8b
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
@@ -0,0 +1,77 @@
+import asyncio
+import json
+
+from DABench import DABench
+
+from metagpt.logs import logger
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def get_prediction(agent, requirement):
+ """Helper function to obtain a prediction from a new instance of the agent.
+
+ This function runs the agent with the provided requirement and extracts the prediction
+ from the result. If an error occurs during processing, it logs the error and returns None.
+
+ Args:
+ agent: The agent instance used to generate predictions.
+ requirement: The input requirement for which the prediction is to be made.
+
+ Returns:
+ The predicted result if successful, otherwise None.
+ """
+ try:
+ # Run the agent with the given requirement and await the result
+ result = await agent.run(requirement)
+
+ # Parse the result to extract the prediction from the JSON response
+ prediction_json = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])
+ prediction = prediction_json[-1]["result"] # Extract the last result from the parsed JSON
+
+ return prediction # Return the extracted prediction
+ except Exception as e:
+ # Log an error message if an exception occurs during processing
+ logger.info(f"Error processing requirement: {requirement}. Error: {e}")
+ return None # Return None in case of an error
+
+
+async def evaluate_all(agent, k):
+ """Evaluate all tasks in DABench using the specified baseline agent.
+
+ Tasks are divided into groups of size k and processed in parallel.
+
+ Args:
+ agent: The baseline agent used for making predictions.
+ k (int): The number of tasks to process in each group concurrently.
+ """
+ bench = DABench() # Create an instance of DABench to access its methods and data
+ id_list, predictions = [], [] # Initialize lists to store IDs and predictions
+ tasks = [] # Initialize a list to hold the tasks
+
+ # Iterate over the answers in DABench to generate tasks
+ for key, value in bench.answers.items():
+ requirement = bench.generate_formatted_prompt(key) # Generate a formatted prompt for the current key
+ tasks.append(get_prediction(agent, requirement)) # Append the prediction task to the tasks list
+ id_list.append(key) # Append the current key to the ID list
+
+ # Process tasks in groups of size k and execute them concurrently
+ for i in range(0, len(tasks), k):
+ # Get the current group of tasks
+ current_group = tasks[i : i + k]
+ # Execute the current group of tasks in parallel
+ group_predictions = await asyncio.gather(*current_group)
+ # Filter out any None values from the predictions and extend the predictions list
+ predictions.extend(pred for pred in group_predictions if pred is not None)
+
+ # Evaluate the results using all valid predictions and logger.info the evaluation
+ logger.info(bench.eval_all(id_list, predictions))
+
+
+def main(k=5):
+ """Main function to run the evaluation process."""
+ agent = DataInterpreter() # Create an instance of the DataInterpreter agent
+ asyncio.run(evaluate_all(agent, k)) # Run the evaluate_all function asynchronously
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
new file mode 100644
index 000000000..5cd1ef4b0
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_all.py
@@ -0,0 +1,35 @@
+import fire
+import pandas as pd
+from DABench import DABench
+
+from metagpt.logs import logger
+from metagpt.roles.di.data_interpreter import DataInterpreter
+from metagpt.utils.recovery_util import save_history
+
+
+async def main():
+ """Evaluate all"""
+ bench = DABench()
+ id_list, predictions, labels, is_true = [], [], [], []
+ for key, value in bench.answers.items():
+ id_list.append(key)
+ labels.append(str(bench.get_answer(key)))
+ try:
+ requirement = bench.generate_formatted_prompt(key)
+ di = DataInterpreter()
+ result = await di.run(requirement)
+ logger.info(result)
+ save_history(role=di)
+ temp_prediction, temp_istrue = bench.eval(key, str(result))
+ is_true.append(str(temp_istrue))
+ predictions.append(str(temp_prediction))
+ except Exception as e:
+ logger.error(f"Error processing question {key}: {e}")
+ is_true.append("False")
+ predictions.append("")
+ df = pd.DataFrame({"Label": labels, "Prediction": predictions, "T/F": is_true})
+ df.to_excel("DABench_output.xlsx", index=False)
+ logger.info(bench.eval_all(id_list, predictions))
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_single.py b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_single.py
new file mode 100644
index 000000000..470f12fc8
--- /dev/null
+++ b/examples/di/InfiAgent-DABench/run_InfiAgent-DABench_single.py
@@ -0,0 +1,22 @@
+import fire
+from DABench import DABench
+
+from metagpt.logs import logger
+from metagpt.roles.di.data_interpreter import DataInterpreter
+from metagpt.utils.recovery_util import save_history
+
+
+async def main(id=0):
+ """Evaluate one task"""
+ bench = DABench()
+ requirement = bench.generate_formatted_prompt(id)
+ di = DataInterpreter()
+ result = await di.run(requirement)
+ logger.info(result)
+ save_history(role=di)
+ _, is_correct = bench.eval(id, str(result))
+ logger.info(f"Prediction is {'correct' if is_correct else 'incorrect'}.")
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/examples/di/requirements_prompt.py b/examples/di/requirements_prompt.py
index 04a0414b1..34102c134 100644
--- a/examples/di/requirements_prompt.py
+++ b/examples/di/requirements_prompt.py
@@ -1,3 +1,5 @@
+# InfiAgent-DABench requirements
+DABENCH = "You are required to {question} from a CSV file named {file_name}. **Constraints**: Ensure that {constraints}, which must be strictly followed throughout the task. The output format should be {format}. This task is categorized as {level}."
# ML-Benchmark requirements
IRIS_REQ = "Run data analysis on sklearn Iris dataset, include a plot"
WINES_RECOGNITION_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class with 20% as test set, and show prediction accuracy"
diff --git a/examples/llm_vision.py b/examples/llm_vision.py
index 276decd59..eea6550f6 100644
--- a/examples/llm_vision.py
+++ b/examples/llm_vision.py
@@ -15,8 +15,8 @@ async def main():
# check if the configured llm supports llm-vision capacity. If not, it will throw an error
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
- res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
- assert "true" in res.lower()
+ res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64])
+ assert ("true" in res.lower()) or ("invoice" in res.lower())
if __name__ == "__main__":
diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py
index ad3f0a1d1..42166dbd1 100644
--- a/metagpt/actions/action_node.py
+++ b/metagpt/actions/action_node.py
@@ -9,6 +9,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
"""
import json
+import re
import typing
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union
@@ -23,6 +24,7 @@ from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from metagpt.utils.common import OutputParser, general_after_log
from metagpt.utils.human_interaction import HumanInteraction
+from metagpt.utils.sanitize import sanitize
class ReviewMode(Enum):
@@ -38,9 +40,17 @@ class ReviseMode(Enum):
TAG = "CONTENT"
+
+class FillMode(Enum):
+ CODE_FILL = "code_fill"
+ XML_FILL = "xml_fill"
+ SINGLE_FILL = "single_fill"
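+ # Usage note (illustrative): pass one of these values as `mode` when filling a node, e.g.
+ # `await ActionNode.from_pydantic(SomeOp).fill(context=prompt, llm=llm, mode=FillMode.XML_FILL.value)`,
+ # where `SomeOp` is a hypothetical pydantic model defining the expected output fields.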
+
+
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
+
SIMPLE_TEMPLATE = """
## context
{context}
@@ -471,6 +481,116 @@ class ActionNode:
return self
+ def get_field_name(self):
+ """
+ Get the field name from the Pydantic model associated with this ActionNode.
+ """
+ model_class = self.create_class()
+ fields = model_class.model_fields
+
+ # Assuming there's only one field in the model
+ if len(fields) == 1:
+ return next(iter(fields))
+
+ # If there are multiple fields, we might want to use self.key to find the right one
+ return self.key
+
+ def get_field_names(self):
+ """
+ Get the field names associated with this ActionNode's Pydantic model.
+ """
+ model_class = self.create_class()
+ return model_class.model_fields.keys()
+
+ def get_field_types(self):
+ """
+ Get the field types associated with this ActionNode's Pydantic model.
+ """
+ model_class = self.create_class()
+ return {field_name: field.annotation for field_name, field in model_class.model_fields.items()}
+
+ def xml_compile(self, context):
+ """
+ Compile the prompt to make it easier for the model to understand the xml format.
+ """
+ field_names = self.get_field_names()
+ # Construct the example using the field names
+ examples = []
+ for field_name in field_names:
+ examples.append(f"<{field_name}>content</{field_name}>")
+
+ # Join all examples into a single string
+ example_str = "\n".join(examples)
+ # Add the example to the context
+ context += f"""
+### Response format (must be strictly followed): All content must be enclosed in the given XML tags, ensuring each opening <tag> has a corresponding closing </tag>, with no incomplete or self-closing tags allowed.\n
+{example_str}
+"""
+ return context
+
+ async def code_fill(
+ self, context: str, function_name: Optional[str] = None, timeout: int = USE_CONFIG_TIMEOUT
+ ) -> Dict[str, str]:
+ """
+ Fill CodeBlock Using ``` ```
+ """
+ field_name = self.get_field_name()
+ prompt = context
+ content = await self.llm.aask(prompt, timeout=timeout)
+ extracted_code = sanitize(code=content, entrypoint=function_name)
+ result = {field_name: extracted_code}
+ return result
+
+ async def single_fill(self, context: str) -> Dict[str, str]:
+ field_name = self.get_field_name()
+ prompt = context
+ content = await self.llm.aask(prompt)
+ result = {field_name: content}
+ return result
+
+ async def xml_fill(self, context: str) -> Dict[str, Any]:
+ """
+ Fill context with XML tags and convert according to field types, including string, integer, boolean, list and dict types
+ """
+ field_names = self.get_field_names()
+ field_types = self.get_field_types()
+
+ extracted_data: Dict[str, Any] = {}
+ content = await self.llm.aask(context)
+
+ for field_name in field_names:
+ pattern = rf"<{field_name}>(.*?)</{field_name}>"
+ match = re.search(pattern, content, re.DOTALL)
+ if match:
+ raw_value = match.group(1).strip()
+ field_type = field_types.get(field_name)
+
+ if field_type == str:
+ extracted_data[field_name] = raw_value
+ elif field_type == int:
+ try:
+ extracted_data[field_name] = int(raw_value)
+ except ValueError:
+ extracted_data[field_name] = 0 # or another default value
+ elif field_type == bool:
+ extracted_data[field_name] = raw_value.lower() in ("true", "yes", "1", "on", "True")
+ elif field_type == list:
+ try:
+ extracted_data[field_name] = eval(raw_value)
+ if not isinstance(extracted_data[field_name], list):
+ raise ValueError
+ except:
+ extracted_data[field_name] = [] # default to an empty list
+ elif field_type == dict:
+ try:
+ extracted_data[field_name] = eval(raw_value)
+ if not isinstance(extracted_data[field_name], dict):
+ raise ValueError
+ except:
+ extracted_data[field_name] = {} # default to an empty dict
+
+ return extracted_data
+
async def fill(
self,
context,
@@ -481,6 +601,7 @@ class ActionNode:
images: Optional[Union[str, list[str]]] = None,
timeout=USE_CONFIG_TIMEOUT,
exclude=[],
+ function_name: str = None,
):
"""Fill the node(s) with mode.
@@ -507,6 +628,22 @@ class ActionNode:
if self.schema:
schema = self.schema
+ if mode == FillMode.CODE_FILL.value:
+ result = await self.code_fill(context, function_name, timeout)
+ self.instruct_content = self.create_class()(**result)
+ return self
+
+ elif mode == FillMode.XML_FILL.value:
+ context = self.xml_compile(context=self.context)
+ result = await self.xml_fill(context)
+ self.instruct_content = self.create_class()(**result)
+ return self
+
+ elif mode == FillMode.SINGLE_FILL.value:
+ result = await self.single_fill(context)
+ self.instruct_content = self.create_class()(**result)
+ return self
+
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
elif strgy == "complex":
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 7388063aa..ef034ca49 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -5,6 +5,7 @@
@Author : alexanderwu
@File : llm_config.py
"""
+
from enum import Enum
from typing import Optional
@@ -25,7 +26,10 @@ class LLMType(Enum):
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE = "azure"
- OLLAMA = "ollama"
+ OLLAMA = "ollama" # /chat at ollama api
+ OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api
+ OLLAMA_EMBEDDINGS = "ollama.embeddings" # /embeddings at ollama api
+ OLLAMA_EMBED = "ollama.embed" # /embed at ollama api
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
MOONSHOT = "moonshot"
@@ -57,6 +61,7 @@ class LLMConfig(YamlModel):
# For Cloud Service Provider like Baidu/ Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
+ session_token: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark(Xunfei), maybe remove later
@@ -76,10 +81,12 @@ class LLMConfig(YamlModel):
best_of: Optional[int] = None
n: Optional[int] = None
stream: bool = True
+ seed: Optional[int] = None
# https://cookbook.openai.com/examples/using_logprobs
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
timeout: int = 600
+ context_length: Optional[int] = None # Max input tokens
# For Amazon Bedrock
region_name: str = None
@@ -101,7 +108,8 @@ class LLMConfig(YamlModel):
root_config_path = CONFIG_ROOT / "config2.yaml"
if root_config_path.exists():
raise ValueError(
- f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
+ f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n"
+ f"the former will overwrite the latter. This may cause unexpected result.\n"
)
elif repo_config_path.exists():
raise ValueError(f"Please set your API key in {repo_config_path}")
diff --git a/metagpt/const.py b/metagpt/const.py
index f33b46b68..9497fdd1e 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -43,6 +43,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
EXAMPLE_PATH = METAGPT_ROOT / "examples"
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
DATA_PATH = METAGPT_ROOT / "data"
+DABENCH_PATH = EXAMPLE_PATH / "di/InfiAgent-DABench/data"
EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"
diff --git a/metagpt/ext/aflow/benchmark/README.md b/metagpt/ext/aflow/benchmark/README.md
new file mode 100644
index 000000000..4a2464fd1
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/README.md
@@ -0,0 +1,29 @@
+# Custom Evaluation Function via Benchmark Class
+
+## How to Use
+
+To create a benchmark for a new dataset, follow these steps:
+
+1. Create a new Python file, e.g., `my_dataset_benchmark.py`
+2. Import the base class:
+ ```python
+ from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+ ```
+3. Create a new class that inherits from `BaseBenchmark`:
+ ```python
+ class MyDatasetBenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+ ```
+4. Implement the required abstract methods (a minimal sketch follows these steps):
+ - `evaluate_problem`: Evaluate a single problem
+ - `calculate_score`: Calculate the score for a prediction
+ - `get_result_columns`: Define column names for the results CSV file
+
+5. Override other methods as needed, such as `load_data` or `save_results_to_csv`
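+
+A minimal sketch of steps 3 and 4 is shown below; the class name `MyDatasetBenchmark`, the `question`/`answer` record keys, and the exact-match scoring are illustrative assumptions, not part of the framework:
+
+```python
+from typing import Any, Callable, List, Tuple
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+
+
+class MyDatasetBenchmark(BaseBenchmark):
+    def calculate_score(self, expected_output: Any, prediction: Any) -> Tuple[float, Any]:
+        # Illustrative exact-match scoring; most datasets need richer comparison logic
+        return (1.0 if str(prediction).strip() == str(expected_output).strip() else 0.0), prediction
+
+    async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[Any, ...]:
+        # Assumes each JSONL record carries "question" and "answer" fields;
+        # graph(...) is expected to return (output, cost) as in the built-in benchmarks
+        output, cost = await graph(problem["question"])
+        score, _ = self.calculate_score(problem["answer"], output)
+        return problem["question"], output, problem["answer"], score, cost
+
+    def get_result_columns(self) -> List[str]:
+        # Must include "score" and "cost" so save_results_to_csv can aggregate them
+        return ["question", "prediction", "expected_output", "score", "cost"]
+```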
+
+## Example
+
+Refer to the `DROPBenchmark` class in the `drop.py` file for an example of how to implement a benchmark for a specific dataset.
+
+By following these guidelines, you can easily create custom benchmark evaluations for new datasets.
diff --git a/metagpt/ext/aflow/benchmark/benchmark.py b/metagpt/ext/aflow/benchmark/benchmark.py
new file mode 100644
index 000000000..b5692f01e
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/benchmark.py
@@ -0,0 +1,104 @@
+import asyncio
+import json
+import os
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, List, Tuple
+
+import aiofiles
+import pandas as pd
+from tqdm.asyncio import tqdm_asyncio
+
+from metagpt.logs import logger
+from metagpt.utils.common import write_json_file
+
+
+class BaseBenchmark(ABC):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ self.name = name
+ self.file_path = file_path
+ self.log_path = log_path
+
+ PASS = "PASS"
+ FAIL = "FAIL"
+
+ async def load_data(self, specific_indices: List[int] = None) -> List[dict]:
+ data = []
+ async with aiofiles.open(self.file_path, mode="r", encoding="utf-8") as file:
+ async for line in file:
+ data.append(json.loads(line))
+ if specific_indices is not None:
+ filtered_data = [data[i] for i in specific_indices if i < len(data)]
+ return filtered_data
+ return data
+
+ def save_results_to_csv(self, results: List[Tuple[Any, ...]], columns: List[str]):
+ df = pd.DataFrame(results, columns=columns)
+ avg_score = df["score"].mean()
+ t_cost = df["cost"].max()
+ a_cost = t_cost / len(df) if len(df) > 0 else 0
+ current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+ filename = f"{avg_score:.5f}_{current_time}.csv"
+ output_file = os.path.join(self.log_path, filename)
+ df.to_csv(output_file, index=False)
+ logger.info(f"Results saved to {output_file}")
+ return avg_score, a_cost, t_cost
+
+ def log_mismatch(
+ self,
+ problem: str,
+ expected_output: Any,
+ prediction: str,
+ extracted_output: Any,
+ extract_answer_code: str = "None",
+ ):
+ log_data = {
+ "question": problem,
+ "right_answer": expected_output,
+ "model_output": prediction,
+ "extracted_output": extracted_output,
+ "extract_answer_code": extract_answer_code,
+ }
+ log_file = Path(self.log_path) / "log.json"
+ if log_file.exists():
+ with log_file.open("r", encoding="utf-8") as f:
+ try:
+ data = json.load(f)
+ except json.JSONDecodeError:
+ data = []
+ else:
+ data = []
+ data.append(log_data)
+ write_json_file(log_file, data, encoding="utf-8", indent=4)
+
+ @abstractmethod
+ async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[Any, ...]:
+ pass
+
+ @abstractmethod
+ def calculate_score(self, expected_output: Any, prediction: Any) -> Tuple[float, Any]:
+ pass
+
+ @abstractmethod
+ def get_result_columns(self) -> List[str]:
+ pass
+
+ async def evaluate_all_problems(self, data: List[dict], graph: Callable, max_concurrent_tasks: int = 50):
+ semaphore = asyncio.Semaphore(max_concurrent_tasks)
+
+ async def sem_evaluate(problem):
+ async with semaphore:
+ return await self.evaluate_problem(problem, graph)
+
+ tasks = [sem_evaluate(problem) for problem in data]
+ return await tqdm_asyncio.gather(*tasks, desc=f"Evaluating {self.name} problems", total=len(data))
+
+ async def run_evaluation(self, graph: Callable, va_list: List[int], max_concurrent_tasks: int = 50):
+ data = await self.load_data(va_list)
+ results = await self.evaluate_all_problems(data, graph, max_concurrent_tasks)
+ columns = self.get_result_columns()
+ average_score, average_cost, total_cost = self.save_results_to_csv(results, columns)
+ logger.info(f"Average score on {self.name} dataset: {average_score:.5f}")
+ logger.info(f"Total Cost: {total_cost:.5f}")
+ return average_score, average_cost, total_cost
diff --git a/metagpt/ext/aflow/benchmark/drop.py b/metagpt/ext/aflow/benchmark/drop.py
new file mode 100644
index 000000000..3cec5795f
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/drop.py
@@ -0,0 +1,83 @@
+import re
+import string
+from collections import Counter
+from typing import Callable, List, Tuple
+
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.logs import logger
+
+
+class DROPBenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+
+ def normalize_answer(self, s: str) -> List[str]:
+ """
+ Normalize answers for evaluation.
+ """
+
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+ def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]:
+ """
+ Compute the F1 score between prediction and ground truth answers.
+ """
+ prediction_tokens = self.normalize_answer(prediction).split()
+ ground_truth_tokens = self.normalize_answer(ground_truth).split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0, prediction
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1, prediction
+
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True)
+ async def _generate_output(self, graph, input_text):
+ return await graph(input_text)
+
+ async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, float, float]:
+ input_text = problem["context"]
+ expected_output = problem["ref_text"]
+ answers = expected_output.split("|")
+
+ try:
+ output, cost = await self._generate_output(graph, input_text)
+ f1_scores = []
+
+ for answer in answers:
+ if answer.strip() != "":
+ output_parts = output.split("|")
+ for output_part in output_parts:
+ f1_score, _ = self.calculate_score(answer, output_part)
+ f1_scores.append(f1_score)
+
+ uni_score = max(f1_scores)
+
+ if uni_score < 0.3:
+ self.log_mismatch(input_text, expected_output, output, output)
+
+ return input_text, output, expected_output, uni_score, cost
+
+ except Exception as e:
+ logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}")
+ return input_text, str(e), expected_output, 0.0, 0.0
+
+ def get_result_columns(self) -> List[str]:
+ return ["inputs", "prediction", "expected_output", "score", "cost"]
diff --git a/metagpt/ext/aflow/benchmark/gsm8k.py b/metagpt/ext/aflow/benchmark/gsm8k.py
new file mode 100644
index 000000000..51979c0c5
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/gsm8k.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# @Date :
+# @Author : all
+# @Desc : test on gsm8k
+import re
+from typing import Callable, List, Optional, Tuple
+
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.logs import logger
+
+
+class GSM8KBenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+
+ def extract_number(self, text: str) -> Optional[float]:
+ matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
+ if matches:
+ last_number = matches[-1].replace(",", "")
+ try:
+ return float(last_number)
+ except ValueError:
+ return None
+ else:
+ return None
+
+ def calculate_score(self, expected_output: float, prediction: float) -> Tuple[float, float]:
+ if prediction is None:
+ return 0.0, prediction
+ return 1.0 if abs(expected_output - prediction) <= 1e-6 else 0.0, prediction
+
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True)
+ async def _generate_output(self, graph, input_text):
+ return await graph(input_text)
+
+ async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, float, float, float]:
+ input_text = problem["question"]
+ expected_output = self.extract_number(problem["answer"])
+
+ try:
+ output, cost = await self._generate_output(graph, input_text)
+ predicted_number = self.extract_number(output)
+ score, extracted_output = self.calculate_score(expected_output, predicted_number)
+
+ if score == 0:
+ self.log_mismatch(input_text, expected_output, output, extracted_output)
+
+ return input_text, output, expected_output, score, cost
+
+ except Exception as e:
+ logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}")
+ return input_text, str(e), expected_output, 0.0, 0.0
+
+ def get_result_columns(self) -> List[str]:
+ return ["question", "prediction", "expected_output", "score", "cost"]
diff --git a/metagpt/ext/aflow/benchmark/hotpotqa.py b/metagpt/ext/aflow/benchmark/hotpotqa.py
new file mode 100644
index 000000000..b3bafe22b
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/hotpotqa.py
@@ -0,0 +1,71 @@
+import re
+import string
+from collections import Counter
+from typing import Callable, List, Tuple
+
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.logs import logger
+
+
+class HotpotQABenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+
+ def normalize_answer(self, s: str) -> str:
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+ def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]:
+ prediction_tokens = self.normalize_answer(prediction).split()
+ ground_truth_tokens = self.normalize_answer(ground_truth).split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0, prediction
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1, prediction
+
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True)
+ async def _generate_output(self, graph, input_text):
+ return await graph(input_text)
+
+ async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, str, float, float]:
+ input_text = problem["question"]
+ expected_output = problem["answer"]
+ paragraphs = [item[1] for item in problem["context"] if isinstance(item[1], list)]
+ context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)
+ inputs = f"Context: {context_str}\n\nQuestion: {input_text}\n\nAnswer:"
+
+ try:
+ output, cost = await self._generate_output(graph, inputs)
+ score, extracted_output = self.calculate_score(expected_output, output)
+
+ if (
+ score < 0.3
+ ): # We set the threshold for collecting incorrect questions to 0.3, since F1 is a continuous score rather than a simple 0/1 judgment
+ self.log_mismatch(input_text, expected_output, output, extracted_output)
+
+ return input_text, context_str, output, expected_output, score, cost
+
+ except Exception as e:
+ logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}")
+ return input_text, context_str, str(e), expected_output, 0.0, 0.0
+
+ def get_result_columns(self) -> List[str]:
+ return ["question", "context", "prediction", "expected_output", "score", "cost"]
diff --git a/metagpt/ext/aflow/benchmark/humaneval.py b/metagpt/ext/aflow/benchmark/humaneval.py
new file mode 100644
index 000000000..b54add260
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/humaneval.py
@@ -0,0 +1,151 @@
+import asyncio
+import threading
+import time
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.logs import logger
+from metagpt.utils.sanitize import sanitize
+
+
+class HumanEvalBenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+
+ class TimeoutError(Exception):
+ pass
+
+ def run_with_timeout(self, func, args, timeout):
+ result = []
+ stop_event = threading.Event()
+
+ def target():
+ try:
+ result.append(func(*args))
+ except Exception as e:
+ result.append(e)
+ finally:
+ stop_event.set()
+
+ thread = threading.Thread(target=target)
+ thread.start()
+ is_timeout = not stop_event.wait(timeout)
+
+ if is_timeout:
+ raise self.TimeoutError("Function execution timed out")
+
+ if not result:
+ return None
+ if isinstance(result[0], Exception):
+ raise result[0]
+ return result[0]
+
+ def check_solution(self, solution, test, entry_point):
+ solution = sanitize(code=solution, entrypoint=entry_point)
+ try:
+ global_dict = {
+ "math": __import__("math"),
+ "hashlib": __import__("hashlib"),
+ "re": __import__("re"),
+ "List": List,
+ "Dict": Dict,
+ "Tuple": Tuple,
+ "Optional": Optional,
+ "Any": Any,
+ }
+
+ # Add handling for special cases
+ if entry_point == "decode_cyclic":
+ solution = (
+ '\n\ndef encode_cyclic(s: str):\n """\n returns encoded string by cycling groups of three characters.\n """\n # split string to groups. Each of length 3.\n groups = [s[(3 * i):min((3 * i + 3), len(s))] for i in range((len(s) + 2) // 3)]\n # cycle elements in each group. Unless group has fewer elements than 3.\n groups = [(group[1:] + group[0]) if len(group) == 3 else group for group in groups]\n return "".join(groups)'
+ + "\n\n"
+ + solution
+ )
+ elif entry_point == "decode_shift":
+ solution = (
+ '\n\ndef encode_shift(s: str):\n """\n returns encoded string by shifting every character by 5 in the alphabet.\n """\n return "".join([chr(((ord(ch) + 5 - ord("a")) % 26) + ord("a")) for ch in s])\n\n\n'
+ + solution
+ )
+ elif entry_point == "find_zero":
+ solution = (
+ "\n\ndef poly(xs: list, x: float):\n return sum(coeff * (x ** i) for i, coeff in enumerate(xs))\n\n"
+ + solution
+ )
+
+ exec(solution, global_dict)
+
+ if entry_point not in global_dict:
+ raise ValueError(f"Function {entry_point} is not defined in the solution.")
+
+ exec(test, global_dict)
+
+ check = global_dict["check"]
+
+ result = self.run_with_timeout(check, (global_dict[entry_point],), 15)
+
+ if result is None:
+ result = (self.PASS, "The solution passed all test cases.")
+
+ except self.TimeoutError:
+ result = (
+ self.FAIL,
+ "Execution timed out. Please check if your solution contains infinite loops or overly time-consuming operations.",
+ )
+ except Exception as e:
+ error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}"
+ result = (self.FAIL, error_message)
+
+ with open("error.log", "a", encoding="utf-8") as log_file:
+ log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")
+
+ return result
+
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True)
+ async def _generate_output(self, graph, prompt, entry_point):
+ # Generate output with a timeout of 60 seconds
+ return await asyncio.wait_for(graph(prompt, entry_point), timeout=60)
+
+ async def evaluate_problem(self, data: dict, graph: Callable) -> Tuple[str, str, str, float, float]:
+ input_text = data["prompt"]
+ expected_output = (
+ "\nCorrect Solution:\ndef "
+ + data["entry_point"]
+ + "(params you should put here):"
+ + "\n\n"
+ + data["canonical_solution"]
+ )
+
+ try:
+ # Generate prediction using the graph function
+ prediction, cost = await self._generate_output(graph, input_text, data["entry_point"])
+
+ # Check the solution
+ ret = self.check_solution(prediction, data["test"], data["entry_point"])
+ test_case_details = ret[1]
+ expected_output = test_case_details + expected_output
+
+ # Calculate score based on the check result
+ score = 1.0 if ret[0] == self.PASS else 0.0
+
+ # Log mismatch if the score is 0
+ if score == 0:
+ self.log_mismatch(input_text, expected_output, prediction, score)
+
+ return input_text, prediction, expected_output, score, cost
+
+ except asyncio.TimeoutError:
+ logger.info("Timeout error. Skipping this sample.")
+ return input_text, "Timeout", expected_output, 0.0, 0.0
+
+ except Exception as e:
+ logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}")
+ return input_text, str(e), expected_output, 0.0, 0.0
+
+ def calculate_score(self, expected_output: str, prediction: str) -> Tuple[float, str]:
+ # The scoring logic for HumanEval is already implemented in evaluate_problem, this is just to conform to the interface
+ return 0.0, prediction
+
+ def get_result_columns(self) -> List[str]:
+ return ["inputs", "prediction", "expected_output", "score", "cost"]
diff --git a/metagpt/ext/aflow/benchmark/math.py b/metagpt/ext/aflow/benchmark/math.py
new file mode 100644
index 000000000..07b0612d0
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/math.py
@@ -0,0 +1,137 @@
+import inspect
+import re
+from math import isclose
+from typing import Any, Callable, List, Tuple
+
+import regex
+from sympy import N, simplify
+from sympy.parsing.latex import parse_latex
+from sympy.parsing.sympy_parser import parse_expr
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.logs import logger
+
+
+class MATHBenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+
+ def extract_model_answer(self, text: str) -> str:
+ pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}"
+ boxed_matches = re.findall(pattern, text, re.DOTALL)
+ if boxed_matches:
+ return boxed_matches[-1].strip()
+
+ sentence_end_pattern = r"(?<!\d)\.\s"
+ sentences = re.split(sentence_end_pattern, text)
+ sentences = [s.strip() for s in sentences if s.strip()]
+ return sentences[-1] if sentences else ""
+
+ def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]:
+ expected_answer = self.extract_model_answer(expected_output)
+ predicted_answer = self.extract_model_answer(prediction)
+
+ if self.math_equal(predicted_answer, expected_answer):
+ return 1, predicted_answer
+ else:
+ return 0, predicted_answer
+
+ def math_equal(self, prediction: Any, reference: Any) -> bool:
+ if str(prediction) == str(reference):
+ return True
+
+ try:
+ if self.is_digit(prediction) and self.is_digit(reference):
+ prediction = self.parse_digits(prediction)
+ reference = self.parse_digits(reference)
+ return isclose(prediction, reference, abs_tol=1e-3)
+ except:
+ pass
+
+ try:
+ return self.symbolic_equal(prediction, reference)
+ except:
+ pass
+
+ return False
+
+ def is_digit(self, num):
+ return self.parse_digits(num) is not None
+
+ def parse_digits(self, num):
+ num = regex.sub(",", "", str(num))
+ try:
+ return float(num)
+ except:
+ if num.endswith("%"):
+ num = num[:-1]
+ if num.endswith("\\"):
+ num = num[:-1]
+ try:
+ return float(num) / 100
+ except:
+ pass
+ return None
+
+ def symbolic_equal(self, a, b):
+ def _parse(s):
+ for f in [parse_latex, parse_expr]:
+ try:
+ return f(s)
+ except:
+ pass
+ return s
+
+ a = _parse(a)
+ b = _parse(b)
+
+ try:
+ if simplify(a - b) == 0:
+ return True
+ except:
+ pass
+
+ try:
+ if isclose(N(a), N(b), abs_tol=1e-3):
+ return True
+ except:
+ pass
+ return False
+
+ def get_function_code(self, func):
+ try:
+ source_code = inspect.getsource(func)
+ return source_code
+ except OSError:
+ return "no code"
+
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True)
+ async def _generate_output(self, graph, input_text):
+ return await graph(input_text)
+
+ async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, int, float]:
+ input_text = problem["problem"]
+ expected_output = problem["solution"]
+
+ try:
+ output, cost = await self._generate_output(graph, input_text)
+ uni_score, extracted_output = self.calculate_score(expected_output, output)
+
+ if uni_score == 0:
+ self.log_mismatch(
+ input_text,
+ expected_output,
+ output,
+ extracted_output,
+ extract_answer_code=self.get_function_code(self.extract_model_answer),
+ )
+
+ return input_text, output, expected_output, uni_score, cost
+
+ except Exception as e:
+ logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}")
+ return input_text, str(e), expected_output, 0.0, 0.0
+
+ def get_result_columns(self) -> List[str]:
+ return ["question", "prediction", "expected_output", "score", "cost"]
diff --git a/metagpt/ext/aflow/benchmark/mbpp.py b/metagpt/ext/aflow/benchmark/mbpp.py
new file mode 100644
index 000000000..c3628b024
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/mbpp.py
@@ -0,0 +1,121 @@
+import threading
+import time
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.logs import logger
+from metagpt.utils.sanitize import sanitize
+
+
+class MBPPBenchmark(BaseBenchmark):
+ def __init__(self, name: str, file_path: str, log_path: str):
+ super().__init__(name, file_path, log_path)
+
+ class TimeoutError(Exception):
+ pass
+
+ def run_with_timeout(self, func, timeout):
+ result = []
+ stop_event = threading.Event()
+
+ def target():
+ try:
+ result.append(func())
+ except Exception as e:
+ result.append(e)
+ finally:
+ stop_event.set()
+
+ thread = threading.Thread(target=target)
+ thread.start()
+ is_timeout = not stop_event.wait(timeout)
+
+ if is_timeout:
+ raise self.TimeoutError("Function execution timed out")
+
+ if not result:
+ return None
+ if isinstance(result[0], Exception):
+ raise result[0]
+ return result[0]
+
+ def check_solution(self, solution, test, entry_point):
+ solution = sanitize(code=solution, entrypoint=entry_point)
+ try:
+ global_dict = {
+ "math": __import__("math"),
+ "hashlib": __import__("hashlib"),
+ "re": __import__("re"),
+ "List": List,
+ "Dict": Dict,
+ "Tuple": Tuple,
+ "Optional": Optional,
+ "Any": Any,
+ }
+
+ exec(solution, global_dict)
+
+ if entry_point not in global_dict:
+ raise ValueError(f"Function {entry_point} is not defined in the solution.")
+
+ exec(test, global_dict)
+
+ check = global_dict["check"]
+
+ result = self.run_with_timeout(check, 15)
+
+ if result is None:
+ result = (self.PASS, "The solution passed all test cases.")
+
+ except self.TimeoutError:
+ result = (
+ self.FAIL,
+ "Execution timed out. Please check if your solution contains infinite loops or overly time-consuming operations.",
+ )
+ except Exception as e:
+ error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}"
+ result = (self.FAIL, error_message)
+
+ with open("error.log", "a", encoding="utf-8") as log_file:
+ log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")
+
+ return result
+
+ @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True)
+ async def _generate_output(self, graph, prompt, entry_point):
+ return await graph(prompt, entry_point)
+
+ async def evaluate_problem(self, data: dict, graph: Callable) -> Tuple[str, str, str, float, float]:
+ input_text = data["prompt"]
+ expected_output = "\nCorrect Solution:\ndef " + data["code"]
+
+ try:
+ # Generate prediction using the graph function
+ prediction, cost = await self._generate_output(graph, input_text, data["entry_point"])
+
+ # Check the solution
+ ret = self.check_solution(prediction, data["test"], data["entry_point"])
+ test_case_details = ret[1]
+ expected_output = test_case_details + "\nCorrect Solution:" + data["code"]
+
+ # Calculate score based on the check result
+ score = 1.0 if ret[0] == self.PASS else 0.0
+
+ # Log mismatch if the score is 0
+ if score == 0:
+ self.log_mismatch(input_text, expected_output, prediction, score)
+
+ return input_text, prediction, expected_output, score, cost
+
+ except Exception as e:
+ logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}")
+ return input_text, str(e), expected_output, 0.0, 0.0
+
+ def calculate_score(self, expected_output: str, prediction: str) -> Tuple[float, str]:
+ # The scoring logic for MBPP is already implemented in evaluate_problem, this is just to conform to the interface
+ return 0.0, prediction
+
+ def get_result_columns(self) -> List[str]:
+ return ["inputs", "prediction", "expected_output", "score", "cost"]
diff --git a/metagpt/ext/aflow/benchmark/utils.py b/metagpt/ext/aflow/benchmark/utils.py
new file mode 100644
index 000000000..846101bc0
--- /dev/null
+++ b/metagpt/ext/aflow/benchmark/utils.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/7/24 16:37
+@Author : didi
+@File : utils.py
+"""
+
+import json
+import os
+
+import numpy as np
+
+from metagpt.utils.common import read_json_file, write_json_file
+
+
+def generate_random_indices(n, n_samples, test=False):
+ """
+ Generate random indices
+ """
+
+ def _set_seed(seed=42):
+ np.random.seed(seed)
+
+ _set_seed()
+ indices = np.arange(n)
+ np.random.shuffle(indices)
+ if test:
+ return indices[n_samples:]
+ else:
+ return indices[:n_samples]
+
+
+def split_data_set(file_path, samples, test=False):
+ data = []
+
+ with open(file_path, "r") as file:
+ for line in file:
+ data.append(json.loads(line))
+ random_indices = generate_random_indices(len(data), samples, test)
+ data = [data[i] for i in random_indices]
+ return data
+
+
+def log_mismatch(problem, expected_output, prediction, predicted_number, path):
+ log_data = {
+ "question": problem,
+ "right_answer": expected_output,
+ "model_output": prediction,
+ "extracted_output": predicted_number,
+ }
+
+ log_file = os.path.join(path, "log.json")
+
+ # Check if the log file already exists
+ if os.path.exists(log_file):
+ # If it exists, load the existing log data
+ data = read_json_file(log_file)
+ else:
+ # If it does not exist, create a new log list
+ data = []
+
+ # Add the new log entry
+ data.append(log_data)
+
+ # Write the data back to log.json file
+ write_json_file(log_file, data, encoding="utf-8", indent=4)
diff --git a/metagpt/ext/aflow/data/download_data.py b/metagpt/ext/aflow/data/download_data.py
new file mode 100644
index 000000000..a3aa2774c
--- /dev/null
+++ b/metagpt/ext/aflow/data/download_data.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# @Date : 2024-10-20
+# @Author : MoshiQAQ & didi
+# @Desc : Download and extract dataset files
+
+import os
+import tarfile
+from typing import Dict
+
+import requests
+from tqdm import tqdm
+
+from metagpt.logs import logger
+
+
+def download_file(url: str, filename: str) -> None:
+ """Download a file from the given URL and show progress."""
+ response = requests.get(url, stream=True)
+ total_size = int(response.headers.get("content-length", 0))
+ block_size = 1024
+ progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
+
+ with open(filename, "wb") as file:
+ for data in response.iter_content(block_size):
+ size = file.write(data)
+ progress_bar.update(size)
+ progress_bar.close()
+
+
+def extract_tar_gz(filename: str, extract_path: str) -> None:
+ """Extract a tar.gz file to the specified path."""
+ with tarfile.open(filename, "r:gz") as tar:
+ tar.extractall(path=extract_path)
+
+
+def process_dataset(url: str, filename: str, extract_path: str) -> None:
+ """Download, extract, and clean up a dataset."""
+ logger.info(f"Downloading {filename}...")
+ download_file(url, filename)
+
+ logger.info(f"Extracting {filename}...")
+ extract_tar_gz(filename, extract_path)
+
+ logger.info(f"{filename} download and extraction completed.")
+
+ os.remove(filename)
+ logger.info(f"Removed {filename}")
+
+
+# Define the datasets to be downloaded
+# Users can modify this list to choose which datasets to download
+datasets_to_download: Dict[str, Dict[str, str]] = {
+ "datasets": {
+ "url": "https://drive.google.com/uc?export=download&id=1DNoegtZiUhWtvkd2xoIuElmIi4ah7k8e",
+ "filename": "aflow_data.tar.gz",
+ "extract_path": "metagpt/ext/aflow/data",
+ },
+ "results": {
+ "url": "https://drive.google.com/uc?export=download&id=1Sr5wjgKf3bN8OC7G6cO3ynzJqD4w6_Dv",
+ "filename": "result.tar.gz",
+ "extract_path": "metagpt/ext/aflow/data/results",
+ },
+ "initial_rounds": {
+ "url": "https://drive.google.com/uc?export=download&id=1UBoW4WBWjX2gs4I_jq3ALdXeLdwDJMdP",
+ "filename": "initial_rounds.tar.gz",
+ "extract_path": "metagpt/ext/aflow/scripts/optimized",
+ },
+}
+
+
+def download(required_datasets, if_first_download: bool = True):
+ """Main function to process all selected datasets"""
+ if if_first_download:
+ for dataset_name in required_datasets:
+ dataset = datasets_to_download[dataset_name]
+ extract_path = dataset["extract_path"]
+ process_dataset(dataset["url"], dataset["filename"], extract_path)
+ else:
+ logger.info("Skip downloading datasets")
diff --git a/metagpt/ext/aflow/scripts/evaluator.py b/metagpt/ext/aflow/scripts/evaluator.py
new file mode 100644
index 000000000..34bdcd9fc
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/evaluator.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+# @Date : 8/23/2024 10:00 AM
+# @Author : all
+# @Desc : Evaluation for different datasets
+
+from typing import Dict, Literal, Tuple
+
+from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark
+from metagpt.ext.aflow.benchmark.drop import DROPBenchmark
+from metagpt.ext.aflow.benchmark.gsm8k import GSM8KBenchmark
+from metagpt.ext.aflow.benchmark.hotpotqa import HotpotQABenchmark
+from metagpt.ext.aflow.benchmark.humaneval import HumanEvalBenchmark
+from metagpt.ext.aflow.benchmark.math import MATHBenchmark
+from metagpt.ext.aflow.benchmark.mbpp import MBPPBenchmark
+
+# If you want to customize tasks, add task types here and provide evaluation functions, just like the ones given above
+DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"]
+
+
+class Evaluator:
+ """
+ Complete the evaluation for different datasets here
+ """
+
+ def __init__(self, eval_path: str):
+ self.eval_path = eval_path
+ self.dataset_configs: Dict[DatasetType, BaseBenchmark] = {
+ "GSM8K": GSM8KBenchmark,
+ "MATH": MATHBenchmark,
+ "HumanEval": HumanEvalBenchmark,
+ "HotpotQA": HotpotQABenchmark,
+ "MBPP": MBPPBenchmark,
+ "DROP": DROPBenchmark,
+ }
+
+ async def graph_evaluate(
+ self, dataset: DatasetType, graph, params: dict, path: str, is_test: bool = False
+ ) -> Tuple[float, float, float]:
+ if dataset not in self.dataset_configs:
+ raise ValueError(f"Unsupported dataset: {dataset}")
+
+ data_path = self._get_data_path(dataset, is_test)
+ benchmark_class = self.dataset_configs[dataset]
+ benchmark = benchmark_class(name=dataset, file_path=data_path, log_path=path)
+
+ # Use params to configure the graph and benchmark
+ configured_graph = await self._configure_graph(dataset, graph, params)
+ if is_test:
+ va_list = None # For test data, generally use None to test all
+ else:
+ va_list = None # Use None to test all Validation data, or set va_list (e.g., [1, 2, 3]) to use partial data
+ return await benchmark.run_evaluation(configured_graph, va_list)
+
+ async def _configure_graph(self, dataset, graph, params: dict):
+ # Here you can configure the graph based on params
+ # For example: set LLM configuration, dataset configuration, etc.
+ dataset_config = params.get("dataset", {})
+ llm_config = params.get("llm_config", {})
+ return graph(name=dataset, llm_config=llm_config, dataset=dataset_config)
+
+ def _get_data_path(self, dataset: DatasetType, test: bool) -> str:
+ base_path = f"metagpt/ext/aflow/data/{dataset.lower()}"
+ return f"{base_path}_test.jsonl" if test else f"{base_path}_validate.jsonl"
diff --git a/metagpt/ext/aflow/scripts/operator.py b/metagpt/ext/aflow/scripts/operator.py
new file mode 100644
index 000000000..903a962e0
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/operator.py
@@ -0,0 +1,360 @@
+# -*- coding: utf-8 -*-
+# @Date : 6/27/2024 17:36 PM
+# @Author : didi
+# @Desc : operator demo of aflow
+import asyncio
+import concurrent.futures
+import random
+import sys
+import traceback
+from collections import Counter
+from typing import Dict, List, Tuple
+
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+from metagpt.actions.action_node import ActionNode
+from metagpt.ext.aflow.scripts.operator_an import (
+ AnswerGenerateOp,
+ CodeGenerateOp,
+ FormatOp,
+ GenerateOp,
+ MdEnsembleOp,
+ ReflectionTestOp,
+ ReviewOp,
+ ReviseOp,
+ ScEnsembleOp,
+)
+from metagpt.ext.aflow.scripts.prompts.prompt import (
+ ANSWER_GENERATION_PROMPT,
+ FORMAT_PROMPT,
+ MD_ENSEMBLE_PROMPT,
+ PYTHON_CODE_VERIFIER_PROMPT,
+ REFLECTION_ON_PUBLIC_TEST_PROMPT,
+ REVIEW_PROMPT,
+ REVISE_PROMPT,
+ SC_ENSEMBLE_PROMPT,
+)
+from metagpt.ext.aflow.scripts.utils import (
+ extract_test_cases_from_jsonl,
+ test_case_2_test_function,
+)
+from metagpt.llm import LLM
+from metagpt.logs import logger
+
+
+class Operator:
+ def __init__(self, llm: LLM, name: str):
+ self.name = name
+ self.llm = llm
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError
+
+ async def _fill_node(self, op_class, prompt, mode=None, **extra_kwargs):
+ fill_kwargs = {"context": prompt, "llm": self.llm}
+ if mode:
+ fill_kwargs["mode"] = mode
+ fill_kwargs.update(extra_kwargs)
+ node = await ActionNode.from_pydantic(op_class).fill(**fill_kwargs)
+ return node.instruct_content.model_dump()
+
+
+class Custom(Operator):
+ def __init__(self, llm: LLM, name: str = "Custom"):
+ super().__init__(llm, name)
+
+ async def __call__(self, input, instruction):
+ prompt = instruction + input
+ response = await self._fill_node(GenerateOp, prompt, mode="single_fill")
+ return response
+
+
+class AnswerGenerate(Operator):
+ def __init__(self, llm: LLM, name: str = "AnswerGenerate"):
+ super().__init__(llm, name)
+
+ async def __call__(self, input: str, mode: str = None) -> Tuple[str, str]:
+ prompt = ANSWER_GENERATION_PROMPT.format(input=input)
+ response = await self._fill_node(AnswerGenerateOp, prompt, mode="xml_fill")
+ return response
+
+
+class CustomCodeGenerate(Operator):
+ def __init__(self, llm: LLM, name: str = "CustomCodeGenerate"):
+ super().__init__(llm, name)
+
+ async def __call__(self, problem, entry_point, instruction):
+ prompt = instruction + problem
+ response = await self._fill_node(GenerateOp, prompt, mode="code_fill", function_name=entry_point)
+ return response
+
+
+class ScEnsemble(Operator):
+ """
+ Paper: Self-Consistency Improves Chain of Thought Reasoning in Language Models
+ Link: https://arxiv.org/abs/2203.11171
+ Paper: Universal Self-Consistency for Large Language Model Generation
+ Link: https://arxiv.org/abs/2311.17311
+ """
+
+ def __init__(self, llm: LLM, name: str = "ScEnsemble"):
+ super().__init__(llm, name)
+
+ async def __call__(self, solutions: List[str], problem: str):
+ answer_mapping = {}
+ solution_text = ""
+ for index, solution in enumerate(solutions):
+ answer_mapping[chr(65 + index)] = index
+ solution_text += f"{chr(65 + index)}: \n{str(solution)}\n\n\n"
+
+ prompt = SC_ENSEMBLE_PROMPT.format(question=problem, solutions=solution_text)
+ response = await self._fill_node(ScEnsembleOp, prompt, mode="xml_fill")
+
+ answer = response.get("solution_letter", "")
+ answer = answer.strip().upper()
+
+ return {"response": solutions[answer_mapping[answer]]}
+
+
+def run_code(code):
+ try:
+ # Create a new global namespace
+ global_namespace = {}
+
+ disallowed_imports = [
+ "os",
+ "sys",
+ "subprocess",
+ "multiprocessing",
+ "matplotlib",
+ "seaborn",
+ "plotly",
+ "bokeh",
+ "ggplot",
+ "pylab",
+ "tkinter",
+ "PyQt5",
+ "wx",
+ "pyglet",
+ ]
+
+ # Check for prohibited imports
+ for lib in disallowed_imports:
+ if f"import {lib}" in code or f"from {lib}" in code:
+ logger.info("Detected prohibited import: %s", lib)
+ return "Error", f"Prohibited import: {lib} and graphing functionalities"
+
+ # Use exec to execute the code
+ exec(code, global_namespace)
+ # Assume the code defines a function named 'solve'
+ if "solve" in global_namespace and callable(global_namespace["solve"]):
+ result = global_namespace["solve"]()
+ return "Success", str(result)
+ else:
+ return "Error", "Function 'solve' not found"
+ except Exception as e:
+ exc_type, exc_value, exc_traceback = sys.exc_info()
+ tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback)
+ return "Error", f"Execution error: {str(e)}\n{''.join(tb_str)}"
+
+
+class Programmer(Operator):
+ def __init__(self, llm: LLM, name: str = "Programmer"):
+ super().__init__(llm, name)
+
+ async def exec_code(self, code, timeout=30):
+ """
+ Asynchronously execute code and return an error if timeout occurs.
+ """
+ loop = asyncio.get_running_loop()
+ with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
+ try:
+ # Submit run_code task to the process pool
+ future = loop.run_in_executor(executor, run_code, code)
+ # Wait for the task to complete or timeout
+ result = await asyncio.wait_for(future, timeout=timeout)
+ return result
+ except asyncio.TimeoutError:
+ # Timeout, attempt to shut down the process pool
+ executor.shutdown(wait=False, cancel_futures=True)
+ return "Error", "Code execution timed out"
+ except Exception as e:
+ return "Error", f"Unknown error: {str(e)}"
+
+ async def code_generate(self, problem, analysis, feedback, mode):
+ """
+ Asynchronous method to generate code.
+ """
+ prompt = PYTHON_CODE_VERIFIER_PROMPT.format(problem=problem, analysis=analysis, feedback=feedback)
+ response = await self._fill_node(CodeGenerateOp, prompt, mode, function_name="solve")
+ return response
+
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
+ async def __call__(self, problem: str, analysis: str = "None"):
+ """
+ Call method, generate code and execute, retry up to 3 times.
+ """
+ code = None
+ output = None
+ feedback = ""
+ for i in range(3):
+ code_response = await self.code_generate(problem, analysis, feedback, mode="code_fill")
+ code = code_response.get("code")
+ if not code:
+ return {"code": code, "output": "No code generated"}
+ status, output = await self.exec_code(code)
+ if status == "Success":
+ return {"code": code, "output": output}
+ else:
+ logger.info(f"Execution error on attempt {i + 1}, error message: {output}")
+ feedback = (
+ f"\nThe result of the error from the code you wrote in the previous round:\n"
+ f"Code: {code}\n\nStatus: {status}, {output}"
+ )
+ return {"code": code, "output": output}
+
+
+class Test(Operator):
+ def __init__(self, llm: LLM, name: str = "Test"):
+ super().__init__(llm, name)
+
+ def exec_code(self, solution, entry_point):
+ test_cases = extract_test_cases_from_jsonl(entry_point)
+
+ fail_cases = []
+ for test_case in test_cases:
+ test_code = test_case_2_test_function(solution, test_case, entry_point)
+ try:
+ exec(test_code, globals())
+ except AssertionError as e:
+ exc_type, exc_value, exc_traceback = sys.exc_info()
+ tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback)
+ with open("tester.txt", "a") as f:
+ f.write("test_error of " + entry_point + "\n")
+ error_infomation = {
+ "test_fail_case": {
+ "test_case": test_case,
+ "error_type": "AssertionError",
+ "error_message": str(e),
+ "traceback": tb_str,
+ }
+ }
+ fail_cases.append(error_infomation)
+ except Exception as e:
+ with open("tester.txt", "a") as f:
+ f.write(entry_point + " " + str(e) + "\n")
+ return {"exec_fail_case": str(e)}
+ if fail_cases != []:
+ return fail_cases
+ else:
+ return "no error"
+
+ async def __call__(self, problem, solution, entry_point, test_loop: int = 3):
+ """
+ "Test": {
+ "description": "Test the solution with test cases, if the solution is correct, return 'no error', if the solution is incorrect, return reflect on the soluion and the error information",
+ "interface": "test(problem: str, solution: str, entry_point: str) -> str"
+ }
+ """
+ for _ in range(test_loop):
+ result = self.exec_code(solution, entry_point)
+ if result == "no error":
+ return {"result": True, "solution": solution}
+ elif "exec_fail_case" in result:
+ result = result["exec_fail_case"]
+ prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format(
+ problem=problem,
+ solution=solution,
+ exec_pass=f"executed unsuccessfully, error: \n {result}",
+ test_fail="executed unsucessfully",
+ )
+ response = await self._fill_node(ReflectionTestOp, prompt, mode="code_fill")
+ solution = response["reflection_and_solution"]
+ else:
+ prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format(
+ problem=problem,
+ solution=solution,
+ exec_pass="executed successfully",
+ test_fail=result,
+ )
+ response = await self._fill_node(ReflectionTestOp, prompt, mode="code_fill")
+ solution = response["reflection_and_solution"]
+
+ result = self.exec_code(solution, entry_point)
+ if result == "no error":
+ return {"result": True, "solution": solution}
+ else:
+ return {"result": False, "solution": solution}
+
+
+class Format(Operator):
+ def __init__(self, llm: LLM, name: str = "Format"):
+ super().__init__(llm, name)
+
+ async def __call__(self, problem, solution, mode: str = None):
+ prompt = FORMAT_PROMPT.format(problem_description=problem, solution=solution)
+ response = await self._fill_node(FormatOp, prompt, mode)
+ return response
+
+
+class Review(Operator):
+ def __init__(self, llm: LLM, name: str = "Review"):
+ super().__init__(llm, name)
+
+ async def __call__(self, problem, solution, mode: str = None):
+ prompt = REVIEW_PROMPT.format(problem=problem, solution=solution)
+ response = await self._fill_node(ReviewOp, prompt, mode="xml_fill")
+ return response
+
+
+class Revise(Operator):
+ def __init__(self, llm: LLM, name: str = "Revise"):
+ super().__init__(llm, name)
+
+ async def __call__(self, problem, solution, feedback, mode: str = None):
+ prompt = REVISE_PROMPT.format(problem=problem, solution=solution, feedback=feedback)
+ response = await self._fill_node(ReviseOp, prompt, mode="xml_fill")
+ return response
+
+
+class MdEnsemble(Operator):
+ """
+ Paper: Can Generalist Foundation Models Outcompete Special-Purpose Tuning? Case Study in Medicine
+ Link: https://arxiv.org/abs/2311.16452
+ """
+
+ def __init__(self, llm: LLM, name: str = "MdEnsemble", vote_count: int = 5):
+ super().__init__(llm, name)
+ self.vote_count = vote_count
+
+ @staticmethod
+ def shuffle_answers(solutions: List[str]) -> Tuple[List[str], Dict[str, str]]:
+ shuffled_solutions = solutions.copy()
+ random.shuffle(shuffled_solutions)
+ answer_mapping = {chr(65 + i): solutions.index(solution) for i, solution in enumerate(shuffled_solutions)}
+ return shuffled_solutions, answer_mapping
+
+ async def __call__(self, solutions: List[str], problem: str, mode: str = None):
+ logger.info(f"solution count: {len(solutions)}")
+ all_responses = []
+
+ for _ in range(self.vote_count):
+ shuffled_solutions, answer_mapping = self.shuffle_answers(solutions)
+
+ solution_text = ""
+ for index, solution in enumerate(shuffled_solutions):
+ solution_text += f"{chr(65 + index)}: \n{str(solution)}\n\n\n"
+
+ prompt = MD_ENSEMBLE_PROMPT.format(solutions=solution_text, question=problem)
+ response = await self._fill_node(MdEnsembleOp, prompt, mode="xml_fill")
+
+ answer = response.get("solution_letter", "A")
+ answer = answer.strip().upper()
+
+ if answer in answer_mapping:
+ original_index = answer_mapping[answer]
+ all_responses.append(original_index)
+
+ most_frequent_index = Counter(all_responses).most_common(1)[0][0]
+ final_answer = solutions[most_frequent_index]
+ return {"solution": final_answer}
diff --git a/metagpt/ext/aflow/scripts/operator_an.py b/metagpt/ext/aflow/scripts/operator_an.py
new file mode 100644
index 000000000..d0201dea2
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/operator_an.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# @Date : 6/27/2024 19:46 PM
+# @Author : didi
+# @Desc : action nodes for operator
+
+from pydantic import BaseModel, Field
+
+
+class GenerateOp(BaseModel):
+ response: str = Field(default="", description="Your solution for this problem")
+
+
+class CodeGenerateOp(BaseModel):
+ code: str = Field(default="", description="Your complete code solution for this problem")
+
+
+class AnswerGenerateOp(BaseModel):
+ thought: str = Field(default="", description="The step by step thinking process")
+ answer: str = Field(default="", description="The final answer to the question")
+
+
+class FormatOp(BaseModel):
+ solution: str = Field(default="", description="Your formatted answer for this problem")
+
+
+class ScEnsembleOp(BaseModel):
+ thought: str = Field(default="", description="The thought of the most consistent solution.")
+ solution_letter: str = Field(default="", description="The letter of most consistent solution.")
+
+
+class ReflectionTestOp(BaseModel):
+ reflection_and_solution: str = Field(
+ default="", description="Corrective solution for code execution errors or test case failures"
+ )
+
+
+class MdEnsembleOp(BaseModel):
+ thought: str = Field(default="", description="Step-by-step analysis of the solutions to determine the best one.")
+ solution_letter: str = Field(default="", description="The letter of the chosen best solution (only one letter).")
+
+
+class ReviewOp(BaseModel):
+ review_result: bool = Field(
+ default=False,
+ description="The Review Result (Bool). If you think this solution looks good for you, return 'true'; If not, return 'false'",
+ )
+ feedback: str = Field(
+ default="",
+ description="Your FeedBack for this problem based on the criteria. If the review result is true, you can put it 'nothing here'.",
+ )
+
+
+class ReviseOp(BaseModel):
+ solution: str = Field(default="", description="Based on the feedback, revised solution for this problem")
diff --git a/metagpt/ext/aflow/scripts/optimized/__init__.py b/metagpt/ext/aflow/scripts/optimized/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/metagpt/ext/aflow/scripts/optimizer.py b/metagpt/ext/aflow/scripts/optimizer.py
new file mode 100644
index 000000000..0ac4827e7
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+# @Date : 8/12/2024 22:00 PM
+# @Author : issac
+# @Desc : optimizer for graph
+
+import asyncio
+import time
+from typing import List, Literal
+
+from pydantic import BaseModel, Field
+
+from metagpt.actions.action_node import ActionNode
+from metagpt.ext.aflow.scripts.evaluator import DatasetType
+from metagpt.ext.aflow.scripts.optimizer_utils.convergence_utils import ConvergenceUtils
+from metagpt.ext.aflow.scripts.optimizer_utils.data_utils import DataUtils
+from metagpt.ext.aflow.scripts.optimizer_utils.evaluation_utils import EvaluationUtils
+from metagpt.ext.aflow.scripts.optimizer_utils.experience_utils import ExperienceUtils
+from metagpt.ext.aflow.scripts.optimizer_utils.graph_utils import GraphUtils
+from metagpt.logs import logger
+from metagpt.provider.llm_provider_registry import create_llm_instance
+
+QuestionType = Literal["math", "code", "qa"]
+OptimizerType = Literal["Graph", "Test"]
+
+
+class GraphOptimize(BaseModel):
+ modification: str = Field(default="", description="modification")
+ graph: str = Field(default="", description="graph")
+ prompt: str = Field(default="", description="prompt")
+
+
+class Optimizer:
+ def __init__(
+ self,
+ dataset: DatasetType,
+ question_type: QuestionType,
+ opt_llm_config,
+ exec_llm_config,
+ operators: List,
+ sample: int,
+ check_convergence: bool = False,
+ optimized_path: str = None,
+ initial_round: int = 1,
+ max_rounds: int = 20,
+ validation_rounds: int = 5,
+ ) -> None:
+ self.optimize_llm_config = opt_llm_config
+ self.optimize_llm = create_llm_instance(self.optimize_llm_config)
+ self.execute_llm_config = exec_llm_config
+
+ self.dataset = dataset
+ self.type = question_type
+ self.check_convergence = check_convergence
+
+ self.graph = None
+ self.operators = operators
+
+ self.root_path = f"{optimized_path}/{self.dataset}"
+ self.sample = sample
+ self.top_scores = []
+ self.round = initial_round
+ self.max_rounds = max_rounds
+ self.validation_rounds = validation_rounds
+
+ self.graph_utils = GraphUtils(self.root_path)
+ self.data_utils = DataUtils(self.root_path)
+ self.experience_utils = ExperienceUtils(self.root_path)
+ self.evaluation_utils = EvaluationUtils(self.root_path)
+ self.convergence_utils = ConvergenceUtils(self.root_path)
+
+ def optimize(self, mode: OptimizerType = "Graph"):
+ if mode == "Test":
+            test_n = 3  # number of evaluation runs on the validation dataset
+ for i in range(test_n):
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ score = loop.run_until_complete(self.test())
+ return None
+
+ for opt_round in range(self.max_rounds):
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ retry_count = 0
+ max_retries = 1
+
+ while retry_count < max_retries:
+ try:
+ score = loop.run_until_complete(self._optimize_graph())
+ break
+ except Exception as e:
+ retry_count += 1
+ logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
+ if retry_count == max_retries:
+ logger.info("Max retries reached. Moving to next round.")
+ score = None
+
+ wait_time = 5 * retry_count
+ time.sleep(wait_time)
+
+ if retry_count < max_retries:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ self.round += 1
+ logger.info(f"Score for round {self.round}: {score}")
+
+ converged, convergence_round, final_round = self.convergence_utils.check_convergence(top_k=3)
+
+ if converged and self.check_convergence:
+ logger.info(
+ f"Convergence detected, occurred in round {convergence_round}, final round is {final_round}"
+ )
+ # Print average scores and standard deviations for each round
+ self.convergence_utils.print_results()
+ break
+
+ time.sleep(5)
+
+ async def _optimize_graph(self):
+        validation_n = self.validation_rounds  # number of evaluation runs on the validation dataset
+ graph_path = f"{self.root_path}/workflows"
+ data = self.data_utils.load_results(graph_path)
+
+ if self.round == 1:
+ directory = self.graph_utils.create_round_directory(graph_path, self.round)
+ # Load graph using graph_utils
+ self.graph = self.graph_utils.load_graph(self.round, graph_path)
+ avg_score = await self.evaluation_utils.evaluate_graph(self, directory, validation_n, data, initial=True)
+
+ # Create a loop until the generated graph meets the check conditions
+ while True:
+ directory = self.graph_utils.create_round_directory(graph_path, self.round + 1)
+
+ top_rounds = self.data_utils.get_top_rounds(self.sample)
+ sample = self.data_utils.select_round(top_rounds)
+
+ prompt, graph_load = self.graph_utils.read_graph_files(sample["round"], graph_path)
+ graph = self.graph_utils.extract_solve_graph(graph_load)
+
+ processed_experience = self.experience_utils.load_experience()
+ experience = self.experience_utils.format_experience(processed_experience, sample["round"])
+
+ operator_description = self.graph_utils.load_operators_description(self.operators)
+ log_data = self.data_utils.load_log(sample["round"])
+
+ graph_optimize_prompt = self.graph_utils.create_graph_optimize_prompt(
+ experience, sample["score"], graph[0], prompt, operator_description, self.type, log_data
+ )
+
+ graph_optimize_node = await ActionNode.from_pydantic(GraphOptimize).fill(
+ context=graph_optimize_prompt, mode="xml_fill", llm=self.optimize_llm
+ )
+
+ response = await self.graph_utils.get_graph_optimize_response(graph_optimize_node)
+
+ # Check if the modification meets the conditions
+ check = self.experience_utils.check_modification(
+ processed_experience, response["modification"], sample["round"]
+ )
+
+ # If `check` is True, break the loop; otherwise, regenerate the graph
+ if check:
+ break
+
+ # Save the graph and evaluate
+ self.graph_utils.write_graph_files(directory, response, self.round + 1, self.dataset)
+
+ experience = self.experience_utils.create_experience_data(sample, response["modification"])
+
+ self.graph = self.graph_utils.load_graph(self.round + 1, graph_path)
+
+ logger.info(directory)
+
+ avg_score = await self.evaluation_utils.evaluate_graph(self, directory, validation_n, data, initial=False)
+
+ self.experience_utils.update_experience(directory, experience, avg_score)
+
+ return avg_score
+
+ async def test(self):
+ rounds = [5] # You can choose the rounds you want to test here.
+ data = []
+
+ graph_path = f"{self.root_path}/workflows_test"
+ json_file_path = self.data_utils.get_results_file_path(graph_path)
+
+ data = self.data_utils.load_results(graph_path)
+
+ for round in rounds:
+ directory = self.graph_utils.create_round_directory(graph_path, round)
+ self.graph = self.graph_utils.load_graph(round, graph_path)
+
+ score, avg_cost, total_cost = await self.evaluation_utils.evaluate_graph_test(self, directory, is_test=True)
+
+ new_data = self.data_utils.create_result_data(round, score, avg_cost, total_cost)
+ data.append(new_data)
+
+ self.data_utils.save_results(json_file_path, data)
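The `optimize` loop above runs each round's async `_optimize_graph` coroutine in a fresh event loop and retries on failure. Below is a minimal, self-contained sketch of that retry pattern; the placeholder coroutine and retry budget are illustrative, not the repository's real values.

```python
import asyncio
import time
from typing import Optional


async def run_one_round() -> float:
    # Placeholder for Optimizer._optimize_graph(); substitute the real coroutine.
    return 0.5


def run_round_with_retries(max_retries: int = 1) -> Optional[float]:
    """Run one round in a fresh event loop, retrying with a linear back-off on failure."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    score = None
    for attempt in range(1, max_retries + 1):
        try:
            score = loop.run_until_complete(run_one_round())
            break
        except Exception as exc:  # broad catch, mirroring the loop above
            print(f"Error occurred: {exc}. Retrying... (Attempt {attempt}/{max_retries})")
            time.sleep(5 * attempt)
    loop.close()
    return score


if __name__ == "__main__":
    print(run_round_with_retries())
```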
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/convergence_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/convergence_utils.py
new file mode 100644
index 000000000..0e275f496
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/convergence_utils.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+# @Date : 9/23/2024 10:00 AM
+# @Author : Issac
+# @Desc :
+
+import json
+import os
+
+import numpy as np
+
+from metagpt.logs import logger
+
+
+class ConvergenceUtils:
+ def __init__(self, root_path):
+ self.root_path = root_path
+ self.data = None
+ self.rounds = None
+ self.avg_scores, self.stds = None, None
+
+ def load_data(self, root_path):
+ """
+ Read JSON file, create a new file if it doesn't exist, then return the data.
+ """
+ rounds_dir = os.path.join(root_path, "workflows")
+ result_file = os.path.join(rounds_dir, "results.json")
+
+ # Ensure directory exists
+ os.makedirs(rounds_dir, exist_ok=True)
+
+ # If file doesn't exist, create a new one with an empty list
+ if not os.path.exists(result_file):
+ with open(result_file, "w") as file:
+ json.dump([], file)
+
+ # Read file and return data
+ with open(result_file, "r") as file:
+ return json.load(file)
+
+ def process_rounds(self):
+ """
+ Organize data by round, return a dictionary of scores by round.
+ """
+ self.data = self.load_data(root_path=self.root_path)
+ rounds = {}
+ for entry in self.data:
+ round_number = entry["round"]
+ score = entry["score"]
+ if round_number not in rounds:
+ rounds[round_number] = []
+ rounds[round_number].append(score)
+ return rounds
+
+ def calculate_avg_and_std(self):
+ """
+ Calculate average score and standard deviation for each round, return two lists: average scores and standard deviations.
+ """
+ self.rounds = self.process_rounds()
+
+ sorted_rounds = sorted(self.rounds.items(), key=lambda x: x[0])
+ avg_scores = []
+ stds = []
+ for round_number, scores in sorted_rounds:
+ avg_scores.append(np.mean(scores))
+ stds.append(np.std(scores))
+ return avg_scores, stds
+
+ def check_convergence(self, top_k=3, z=0, consecutive_rounds=5):
+ """
+ Check for convergence. z is the z-score corresponding to the confidence level.
+ consecutive_rounds is the number of consecutive rounds that must meet the stop condition.
+ """
+ # Calculate average score and standard deviation for each round
+ self.avg_scores, self.stds = self.calculate_avg_and_std()
+ # If total rounds are not enough to calculate top_k+1 rounds, return not converged
+ if len(self.avg_scores) < top_k + 1:
+ return False, None, None
+ convergence_count = 0 # Convergence counter
+ previous_y = None # Y value of the previous round (average of top_k scores)
+ sigma_y_previous = None # Standard error of Y value from previous round
+ for i in range(len(self.avg_scores)):
+ # Dynamically select top_k from current round and all previous rounds
+ top_k_indices = np.argsort(self.avg_scores[: i + 1])[::-1][
+ :top_k
+ ] # Select top k indices by descending average score
+ top_k_scores = [self.avg_scores[j] for j in top_k_indices] # Get list of top k scores
+ top_k_stds = [
+ self.stds[j] for j in top_k_indices
+ ] # Get list of standard deviations corresponding to top k scores
+ # Calculate mean of top k scores for current round, i.e., y_current
+ y_current = np.mean(top_k_scores)
+ # Calculate standard error of y_current (sigma_y_current), representing score dispersion
+ sigma_y_current = np.sqrt(np.sum([s**2 for s in top_k_stds]) / (top_k**2))
+ # If not the first round, calculate change in Y (Delta_Y) and corresponding standard error
+ if previous_y is not None:
+ # Calculate Y difference between current round and previous round
+ delta_y = y_current - previous_y
+ # Calculate standard error of Y difference (sigma_Delta_Y)
+ sigma_delta_y = np.sqrt(sigma_y_current**2 + sigma_y_previous**2)
+ # Check if Y change is within acceptable confidence interval, i.e., convergence condition
+ if abs(delta_y) <= z * sigma_delta_y:
+ convergence_count += 1
+ # If consecutive converged rounds reach set value, return convergence information
+ if convergence_count >= consecutive_rounds:
+ return True, i - consecutive_rounds + 1, i
+ else:
+ # If change is large, reset convergence counter
+ convergence_count = 0
+ # Update Y value and standard error for previous round
+ previous_y = y_current
+ sigma_y_previous = sigma_y_current
+ # If convergence condition not met, return not converged
+ return False, None, None
+
+ def print_results(self):
+ """
+ Print average score and standard deviation for all rounds.
+ """
+ self.avg_scores, self.stds = self.calculate_avg_and_std()
+ for i, (avg_score, std) in enumerate(zip(self.avg_scores, self.stds), 1):
+ logger.info(f"Round {i}: Average Score = {avg_score:.4f}, Standard Deviation = {std:.4f}")
+
+
+if __name__ == "__main__":
+    # Instantiate the checker with the experiment root path; top_k is passed to check_convergence (e.g., top_k=5)
+    checker = ConvergenceUtils("path")
+ converged, convergence_round, final_round = checker.check_convergence()
+
+ if converged:
+ logger.info(f"Convergence detected, occurred at round {convergence_round}, final round is {final_round}")
+ else:
+ logger.info("No convergence detected within all rounds")
+
+ # Print average score and standard deviation for each round
+ checker.print_results()
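For reference, `ConvergenceUtils` reads `workflows/results.json` as a flat list of per-evaluation records. The sketch below uses invented scores and shows only the two fields this class consumes (`round`, `score`), aggregated the same way `calculate_avg_and_std` does.

```python
import numpy as np

# Illustrative records; the real file also carries cost and timestamp fields.
results = [
    {"round": 1, "score": 0.60}, {"round": 1, "score": 0.62},
    {"round": 2, "score": 0.64}, {"round": 2, "score": 0.66},
    {"round": 3, "score": 0.66}, {"round": 3, "score": 0.65},
]

rounds = {}
for entry in results:
    rounds.setdefault(entry["round"], []).append(entry["score"])

# Mean and standard deviation of the scores within each round.
for round_number, scores in sorted(rounds.items()):
    print(f"Round {round_number}: avg={np.mean(scores):.4f}, std={np.std(scores):.4f}")
```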
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
new file mode 100644
index 000000000..2a09e0820
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
@@ -0,0 +1,149 @@
+import datetime
+import json
+import os
+import random
+
+import numpy as np
+import pandas as pd
+
+from metagpt.logs import logger
+from metagpt.utils.common import read_json_file, write_json_file
+
+
+class DataUtils:
+ def __init__(self, root_path: str):
+ self.root_path = root_path
+ self.top_scores = []
+
+ def load_results(self, path: str) -> list:
+ result_path = os.path.join(path, "results.json")
+ if os.path.exists(result_path):
+ with open(result_path, "r") as json_file:
+ try:
+ return json.load(json_file)
+ except json.JSONDecodeError:
+ return []
+ return []
+
+ def get_top_rounds(self, sample: int, path=None, mode="Graph"):
+ self._load_scores(path, mode)
+ unique_rounds = set()
+ unique_top_scores = []
+
+ first_round = next((item for item in self.top_scores if item["round"] == 1), None)
+ if first_round:
+ unique_top_scores.append(first_round)
+ unique_rounds.add(1)
+
+ for item in self.top_scores:
+ if item["round"] not in unique_rounds:
+ unique_top_scores.append(item)
+ unique_rounds.add(item["round"])
+
+ if len(unique_top_scores) >= sample:
+ break
+
+ return unique_top_scores
+
+ def select_round(self, items):
+ if not items:
+ raise ValueError("Item list is empty.")
+
+ sorted_items = sorted(items, key=lambda x: x["score"], reverse=True)
+ scores = [item["score"] * 100 for item in sorted_items]
+
+ probabilities = self._compute_probabilities(scores)
+ logger.info(f"\nMixed probability distribution: {probabilities}")
+ logger.info(f"\nSorted rounds: {sorted_items}")
+
+ selected_index = np.random.choice(len(sorted_items), p=probabilities)
+ logger.info(f"\nSelected index: {selected_index}, Selected item: {sorted_items[selected_index]}")
+
+ return sorted_items[selected_index]
+
+ def _compute_probabilities(self, scores, alpha=0.2, lambda_=0.3):
+ scores = np.array(scores, dtype=np.float64)
+ n = len(scores)
+
+ if n == 0:
+ raise ValueError("Score list is empty.")
+
+ uniform_prob = np.full(n, 1.0 / n, dtype=np.float64)
+
+ max_score = np.max(scores)
+ shifted_scores = scores - max_score
+ exp_weights = np.exp(alpha * shifted_scores)
+
+ sum_exp_weights = np.sum(exp_weights)
+ if sum_exp_weights == 0:
+ raise ValueError("Sum of exponential weights is 0, cannot normalize.")
+
+ score_prob = exp_weights / sum_exp_weights
+
+ mixed_prob = lambda_ * uniform_prob + (1 - lambda_) * score_prob
+
+ total_prob = np.sum(mixed_prob)
+ if not np.isclose(total_prob, 1.0):
+ mixed_prob = mixed_prob / total_prob
+
+ return mixed_prob
+
+ def load_log(self, cur_round, path=None, mode: str = "Graph"):
+ if mode == "Graph":
+ log_dir = os.path.join(self.root_path, "workflows", f"round_{cur_round}", "log.json")
+ else:
+ log_dir = path
+
+        # Check whether the log file exists
+        if not os.path.exists(log_dir):
+            return ""  # If the file doesn't exist, return an empty string
+ logger.info(log_dir)
+ data = read_json_file(log_dir, encoding="utf-8")
+
+ if isinstance(data, dict):
+ data = [data]
+ elif not isinstance(data, list):
+ data = list(data)
+
+ if not data:
+ return ""
+
+ sample_size = min(3, len(data))
+ random_samples = random.sample(data, sample_size)
+
+ log = ""
+ for sample in random_samples:
+ log += json.dumps(sample, indent=4, ensure_ascii=False) + "\n\n"
+
+ return log
+
+ def get_results_file_path(self, graph_path: str) -> str:
+ return os.path.join(graph_path, "results.json")
+
+ def create_result_data(self, round: int, score: float, avg_cost: float, total_cost: float) -> dict:
+ now = datetime.datetime.now()
+ return {"round": round, "score": score, "avg_cost": avg_cost, "total_cost": total_cost, "time": now}
+
+ def save_results(self, json_file_path: str, data: list):
+ write_json_file(json_file_path, data, encoding="utf-8", indent=4)
+
+ def _load_scores(self, path=None, mode="Graph"):
+ if mode == "Graph":
+ rounds_dir = os.path.join(self.root_path, "workflows")
+ else:
+ rounds_dir = path
+
+ result_file = os.path.join(rounds_dir, "results.json")
+ self.top_scores = []
+
+ data = read_json_file(result_file, encoding="utf-8")
+ df = pd.DataFrame(data)
+
+ scores_per_round = df.groupby("round")["score"].mean().to_dict()
+
+ for round_number, average_score in scores_per_round.items():
+ self.top_scores.append({"round": round_number, "score": average_score})
+
+ self.top_scores.sort(key=lambda x: x["score"], reverse=True)
+
+ return self.top_scores
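`_compute_probabilities` above blends a uniform distribution with a softmax over max-shifted scores, so higher-scoring rounds are favored while every round keeps a floor probability. A standalone sketch of the same computation with made-up scores:

```python
import numpy as np


def mixed_probabilities(scores, alpha=0.2, lambda_=0.3):
    """Blend a uniform distribution with a softmax over shifted scores, as in _compute_probabilities."""
    scores = np.asarray(scores, dtype=np.float64)
    uniform = np.full(len(scores), 1.0 / len(scores))
    weights = np.exp(alpha * (scores - scores.max()))  # shift by the max for numerical stability
    softmax = weights / weights.sum()
    return lambda_ * uniform + (1 - lambda_) * softmax


# select_round passes scores multiplied by 100, e.g. 0.65 -> 65.0 (values here are invented).
probs = mixed_probabilities([65.0, 62.0, 58.0])
print(probs, probs.sum())  # favors the best round but keeps non-zero mass on the others
```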
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/evaluation_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/evaluation_utils.py
new file mode 100644
index 000000000..77683017e
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/evaluation_utils.py
@@ -0,0 +1,63 @@
+from metagpt.ext.aflow.scripts.evaluator import Evaluator
+
+
+class EvaluationUtils:
+ def __init__(self, root_path: str):
+ self.root_path = root_path
+
+ async def evaluate_initial_round(self, optimizer, graph_path, directory, validation_n, data):
+        # Use the optimizer's graph_utils to load the graph
+ optimizer.graph = optimizer.graph_utils.load_graph(optimizer.round, graph_path)
+ evaluator = Evaluator(eval_path=directory)
+
+ for i in range(validation_n):
+ score, avg_cost, total_cost = await evaluator.graph_evaluate(
+ optimizer.dataset,
+ optimizer.graph,
+ {"dataset": optimizer.dataset, "llm_config": optimizer.execute_llm_config},
+ directory,
+ is_test=False,
+ )
+
+ new_data = optimizer.data_utils.create_result_data(optimizer.round, score, avg_cost, total_cost)
+ data.append(new_data)
+
+ result_path = optimizer.data_utils.get_results_file_path(graph_path)
+ optimizer.data_utils.save_results(result_path, data)
+
+ return data
+
+ async def evaluate_graph(self, optimizer, directory, validation_n, data, initial=False):
+ evaluator = Evaluator(eval_path=directory)
+ sum_score = 0
+
+ for i in range(validation_n):
+ score, avg_cost, total_cost = await evaluator.graph_evaluate(
+ optimizer.dataset,
+ optimizer.graph,
+ {"dataset": optimizer.dataset, "llm_config": optimizer.execute_llm_config},
+ directory,
+ is_test=False,
+ )
+
+ cur_round = optimizer.round + 1 if initial is False else optimizer.round
+
+ new_data = optimizer.data_utils.create_result_data(cur_round, score, avg_cost, total_cost)
+ data.append(new_data)
+
+ result_path = optimizer.data_utils.get_results_file_path(f"{optimizer.root_path}/workflows")
+ optimizer.data_utils.save_results(result_path, data)
+
+ sum_score += score
+
+ return sum_score / validation_n
+
+ async def evaluate_graph_test(self, optimizer, directory, is_test=True):
+ evaluator = Evaluator(eval_path=directory)
+ return await evaluator.graph_evaluate(
+ optimizer.dataset,
+ optimizer.graph,
+ {"dataset": optimizer.dataset, "llm_config": optimizer.execute_llm_config},
+ directory,
+ is_test=is_test,
+ )
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/experience_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/experience_utils.py
new file mode 100644
index 000000000..43f9eb1d5
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/experience_utils.py
@@ -0,0 +1,96 @@
+import json
+import os
+from collections import defaultdict
+
+from metagpt.logs import logger
+from metagpt.utils.common import read_json_file, write_json_file
+
+
+class ExperienceUtils:
+ def __init__(self, root_path: str):
+ self.root_path = root_path
+
+ def load_experience(self, path=None, mode: str = "Graph"):
+ if mode == "Graph":
+ rounds_dir = os.path.join(self.root_path, "workflows")
+ else:
+ rounds_dir = path
+
+ experience_data = defaultdict(lambda: {"score": None, "success": {}, "failure": {}})
+
+ for round_dir in os.listdir(rounds_dir):
+ if os.path.isdir(os.path.join(rounds_dir, round_dir)) and round_dir.startswith("round_"):
+ round_path = os.path.join(rounds_dir, round_dir)
+ try:
+ round_number = int(round_dir.split("_")[1])
+ json_file_path = os.path.join(round_path, "experience.json")
+ if os.path.exists(json_file_path):
+ data = read_json_file(json_file_path, encoding="utf-8")
+ father_node = data["father node"]
+
+ if experience_data[father_node]["score"] is None:
+ experience_data[father_node]["score"] = data["before"]
+
+ if data["succeed"]:
+ experience_data[father_node]["success"][round_number] = {
+ "modification": data["modification"],
+ "score": data["after"],
+ }
+ else:
+ experience_data[father_node]["failure"][round_number] = {
+ "modification": data["modification"],
+ "score": data["after"],
+ }
+ except Exception as e:
+ logger.info(f"Error processing {round_dir}: {str(e)}")
+
+ experience_data = dict(experience_data)
+
+ output_path = os.path.join(rounds_dir, "processed_experience.json")
+ with open(output_path, "w", encoding="utf-8") as outfile:
+ json.dump(experience_data, outfile, indent=4, ensure_ascii=False)
+
+ logger.info(f"Processed experience data saved to {output_path}")
+ return experience_data
+
+ def format_experience(self, processed_experience, sample_round):
+ experience_data = processed_experience.get(sample_round)
+ if experience_data:
+ experience = f"Original Score: {experience_data['score']}\n"
+ experience += "These are some conclusions drawn from experience:\n\n"
+ for key, value in experience_data["failure"].items():
+ experience += f"-Absolutely prohibit {value['modification']} (Score: {value['score']})\n"
+ for key, value in experience_data["success"].items():
+ experience += f"-Absolutely prohibit {value['modification']} \n"
+ experience += "\n\nNote: Take into account past failures and avoid repeating the same mistakes, as these failures indicate that these approaches are ineffective. You must fundamentally change your way of thinking, rather than simply using more advanced Python syntax like for, if, else, etc., or modifying the prompt."
+ else:
+ experience = f"No experience data found for round {sample_round}."
+ return experience
+
+ def check_modification(self, processed_experience, modification, sample_round):
+ experience_data = processed_experience.get(sample_round)
+ if experience_data:
+ for key, value in experience_data["failure"].items():
+ if value["modification"] == modification:
+ return False
+ for key, value in experience_data["success"].items():
+ if value["modification"] == modification:
+ return False
+ return True
+ else:
+            return True  # If experience_data is empty, also return True
+
+ def create_experience_data(self, sample, modification):
+ return {
+ "father node": sample["round"],
+ "modification": modification,
+ "before": sample["score"],
+ "after": None,
+ "succeed": None,
+ }
+
+ def update_experience(self, directory, experience, avg_score):
+ experience["after"] = avg_score
+ experience["succeed"] = bool(avg_score > experience["before"])
+
+ write_json_file(os.path.join(directory, "experience.json"), experience, encoding="utf-8", indent=4)
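Taken together, `create_experience_data` and `update_experience` give each round an `experience.json` record whose `succeed` flag is filled in only after the new graph has been evaluated. A small sketch of that lifecycle; the modification text and scores are invented.

```python
# Invented sample and scores; field names follow ExperienceUtils.
sample = {"round": 3, "score": 0.65}

experience = {
    "father node": sample["round"],
    "modification": "Add a review step after the ensemble",
    "before": sample["score"],
    "after": None,
    "succeed": None,
}

avg_score = 0.68  # stand-in for the value returned by EvaluationUtils.evaluate_graph
experience["after"] = avg_score
experience["succeed"] = bool(avg_score > experience["before"])
print(experience)  # succeed=True because the new round scored higher than its father node
```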
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/graph_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/graph_utils.py
new file mode 100644
index 000000000..a0ebe9b26
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/graph_utils.py
@@ -0,0 +1,125 @@
+import json
+import os
+import re
+import time
+import traceback
+from typing import List
+
+from metagpt.ext.aflow.scripts.prompts.optimize_prompt import (
+ WORKFLOW_CUSTOM_USE,
+ WORKFLOW_INPUT,
+ WORKFLOW_OPTIMIZE_PROMPT,
+ WORKFLOW_TEMPLATE,
+)
+from metagpt.logs import logger
+
+
+class GraphUtils:
+ def __init__(self, root_path: str):
+ self.root_path = root_path
+
+ def create_round_directory(self, graph_path: str, round_number: int) -> str:
+ directory = os.path.join(graph_path, f"round_{round_number}")
+ os.makedirs(directory, exist_ok=True)
+ return directory
+
+ def load_graph(self, round_number: int, workflows_path: str):
+ workflows_path = workflows_path.replace("\\", ".").replace("/", ".")
+ graph_module_name = f"{workflows_path}.round_{round_number}.graph"
+
+ try:
+ graph_module = __import__(graph_module_name, fromlist=[""])
+ graph_class = getattr(graph_module, "Workflow")
+ return graph_class
+ except ImportError as e:
+ logger.info(f"Error loading graph for round {round_number}: {e}")
+ raise
+
+ def read_graph_files(self, round_number: int, workflows_path: str):
+ prompt_file_path = os.path.join(workflows_path, f"round_{round_number}", "prompt.py")
+ graph_file_path = os.path.join(workflows_path, f"round_{round_number}", "graph.py")
+
+ try:
+ with open(prompt_file_path, "r", encoding="utf-8") as file:
+ prompt_content = file.read()
+ with open(graph_file_path, "r", encoding="utf-8") as file:
+ graph_content = file.read()
+ except FileNotFoundError as e:
+ logger.info(f"Error: File not found for round {round_number}: {e}")
+ raise
+ except Exception as e:
+ logger.info(f"Error loading prompt for round {round_number}: {e}")
+ raise
+ return prompt_content, graph_content
+
+ def extract_solve_graph(self, graph_load: str) -> List[str]:
+ pattern = r"class Workflow:.+"
+ return re.findall(pattern, graph_load, re.DOTALL)
+
+ def load_operators_description(self, operators: List[str]) -> str:
+ path = f"{self.root_path}/workflows/template/operator.json"
+ operators_description = ""
+ for id, operator in enumerate(operators):
+ operator_description = self._load_operator_description(id + 1, operator, path)
+ operators_description += f"{operator_description}\n"
+ return operators_description
+
+ def _load_operator_description(self, id: int, operator_name: str, file_path: str) -> str:
+ with open(file_path, "r") as f:
+ operator_data = json.load(f)
+ matched_data = operator_data[operator_name]
+ desc = matched_data["description"]
+ interface = matched_data["interface"]
+ return f"{id}. {operator_name}: {desc}, with interface {interface})."
+
+ def create_graph_optimize_prompt(
+ self,
+ experience: str,
+ score: float,
+ graph: str,
+ prompt: str,
+ operator_description: str,
+ type: str,
+ log_data: str,
+ ) -> str:
+ graph_input = WORKFLOW_INPUT.format(
+ experience=experience,
+ score=score,
+ graph=graph,
+ prompt=prompt,
+ operator_description=operator_description,
+ type=type,
+ log=log_data,
+ )
+ graph_system = WORKFLOW_OPTIMIZE_PROMPT.format(type=type)
+ return graph_input + WORKFLOW_CUSTOM_USE + graph_system
+
+ async def get_graph_optimize_response(self, graph_optimize_node):
+ max_retries = 5
+ retries = 0
+
+ while retries < max_retries:
+ try:
+ response = graph_optimize_node.instruct_content.model_dump()
+ return response
+ except Exception as e:
+ retries += 1
+ logger.info(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")
+ if retries == max_retries:
+ logger.info("Maximum retries reached. Skipping this sample.")
+ break
+ traceback.print_exc()
+ time.sleep(5)
+ return None
+
+ def write_graph_files(self, directory: str, response: dict, round_number: int, dataset: str):
+ graph = WORKFLOW_TEMPLATE.format(graph=response["graph"], round=round_number, dataset=dataset)
+
+ with open(os.path.join(directory, "graph.py"), "w", encoding="utf-8") as file:
+ file.write(graph)
+
+ with open(os.path.join(directory, "prompt.py"), "w", encoding="utf-8") as file:
+ file.write(response["prompt"])
+
+ with open(os.path.join(directory, "__init__.py"), "w", encoding="utf-8") as file:
+ file.write("")
diff --git a/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py b/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py
new file mode 100644
index 000000000..231506a37
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py
@@ -0,0 +1,59 @@
+WORKFLOW_OPTIMIZE_PROMPT = """You are building a Graph and corresponding Prompt to jointly solve {type} problems.
+Referring to the given graph and prompt, which forms a basic example of a {type} solution approach,
+please reconstruct and optimize them. You can add, modify, or delete nodes, parameters, or prompts. Include your
+single modification in XML tags in your reply. Ensure they are complete and correct to avoid runtime failures. When
+optimizing, you can incorporate critical thinking methods like review, revise, ensemble (generating multiple answers through different/similar prompts, then voting/integrating/checking the majority to obtain a final answer), selfAsk, etc. Consider
+Python's loops (for, while, list comprehensions), conditional statements (if-elif-else, ternary operators),
+or machine learning techniques (e.g., linear regression, decision trees, neural networks, clustering). The graph
+complexity should not exceed 10. Use logical and control flow (IF-ELSE, loops) for an enhanced graphical
+representation. Ensure that all the prompts required by the current graph from prompt_custom are included. Exclude any other prompts.
+Output the modified graph and all the necessary Prompts in prompt_custom (if needed).
+The prompt you need to generate is only the one used in `prompt_custom.XXX` within Custom. Other methods already have built-in prompts and are prohibited from being generated. Only generate those needed for use in `prompt_custom`; please remove any unused prompts in prompt_custom.
+The generated prompt must not contain any placeholders.
+Considering information loss, complex graphs may yield better results, but insufficient information transmission can omit the solution. It's crucial to include necessary context during the process."""
+
+
+WORKFLOW_INPUT = """
+Here is a graph and the corresponding prompt (prompt only related to the custom method) that performed excellently in a previous iteration (maximum score is 1). You must make further optimizations and improvements based on this graph. The modified graph must differ from the provided example, and the specific differences should be noted within the xxx section.\n
+
+ {experience}
+ (such as:add a review step/delete a operator/modify a prompt)
+ {score}
+ {graph}
+ {prompt}(only prompt_custom)
+ {operator_description}
+
+Below are the logs of some results with the aforementioned Graph that performed well but encountered errors, which can be used as references for optimization:
+{log}
+
+First, provide optimization ideas. **Only one detail point can be modified at a time**, and no more than 5 lines of code may be changed per modification—extensive modifications are strictly prohibited to maintain project focus!
+When introducing new functionalities in the graph, please make sure to import the necessary libraries or modules yourself, except for operator, prompt_custom, create_llm_instance, and CostManage, which have already been automatically imported.
+**Under no circumstances should Graph output None for any field.**
+Use custom methods to restrict your output format, rather than using code (outside of the code, the system will extract answers based on certain rules and score them).
+It is very important to format the Graph output answers; you can refer to the standard answer format in the log.
+"""
+
+WORKFLOW_CUSTOM_USE = """\nHere's an example of using the `custom` method in graph:
+```
+# You can write your own prompt in prompt_custom and then use it in the Custom method in the graph
+response = await self.custom(input=problem, instruction=prompt_custom.XXX_PROMPT)
+# You can also concatenate previously generated string results in the input to provide more comprehensive contextual information.
+# response = await self.custom(input=problem+f"xxx:{xxx}, xxx:{xxx}", instruction=prompt_custom.XXX_PROMPT)
+# The output from the Custom method can be placed anywhere you need it, as shown in the example below
+solution = await self.generate(problem=f"question:{problem}, xxx:{response['response']}")
+```
+Note: In custom, the input and instruction are directly concatenated (instruction + input), and placeholders are not supported. Please ensure to add comments and handle the concatenation externally.\n
+
+**Introducing multiple operators at appropriate points can enhance performance. If you find that some provided operators are not yet used in the graph, try incorporating them.**
+"""
+
+WORKFLOW_TEMPLATE = """from typing import Literal
+import metagpt.ext.aflow.scripts.optimized.{dataset}.workflows.template.operator as operator
+import metagpt.ext.aflow.scripts.optimized.{dataset}.workflows.round_{round}.prompt as prompt_custom
+from metagpt.provider.llm_provider_registry import create_llm_instance
+from metagpt.utils.cost_manager import CostManager
+
+DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"]
+
+{graph}
+"""
diff --git a/metagpt/ext/aflow/scripts/prompts/prompt.py b/metagpt/ext/aflow/scripts/prompts/prompt.py
new file mode 100644
index 000000000..16bf78af8
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/prompts/prompt.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# @Date : 6/26/2024 17:07 PM
+# @Author : didi
+# @Desc : prompts of operators
+
+ANSWER_GENERATION_PROMPT = """
+Think step by step and solve the problem.
+1. In the "thought" field, explain your thinking process in detail.
+2. In the "answer" field, provide the final answer concisely and clearly. The answer should be a direct response to the question, without including explanations or reasoning.
+Your task: {input}
+"""
+
+FORMAT_PROMPT = """
+For the question described as {problem_description},
+please extract a short and concise answer that contains only one word or a few words from the following solution: {solution}.
+Make sure there are no additional comments or explanations in your response.
+"""
+
+SC_ENSEMBLE_PROMPT = """
+Given the question described as follows: {question}
+Several solutions have been generated to address the given question. They are as follows:
+{solutions}
+
+Carefully evaluate these solutions and identify the answer that appears most frequently across them. This consistency in answers is crucial for determining the most reliable solution.
+
+In the "thought" field, provide a detailed explanation of your thought process. In the "solution_letter" field, output only the single letter ID (A, B, C, etc.) corresponding to the most consistent solution. Do not include any additional text or explanation in the "solution_letter" field.
+"""
+
+PYTHON_CODE_VERIFIER_PROMPT = """
+You are a professional Python programmer. Your task is to write complete, self-contained code based on a given mathematical problem and output the answer. The code should include all necessary imports and dependencies, and be ready to run without additional setup or environment configuration.
+
+Problem description: {problem}
+Other analysis: {analysis}
+{feedback}
+
+Your code should:
+1. Implement the calculation steps described in the problem.
+2. Define a function named `solve` that performs the calculation and returns the result. The `solve` function should not require any input parameters; instead, it should obtain all necessary inputs from within the function or from globally defined variables.
+3. The `solve` function should return the final calculation result.
+
+Please ensure your code is efficient, well-commented, and follows Python best practices. The output should be limited to basic data types such as strings, integers, and floats. It is prohibited to transmit images or other file formats. The code output is intended for a text-based language model.
+"""
+
+
+REFLECTION_ON_PUBLIC_TEST_PROMPT = """
+Given a code problem and a Python code solution which failed to pass the tests or to execute, you need to analyze the reason for the failure and propose a better code solution:
+### problem
+{problem}
+
+### Code Solution
+{solution}
+
+### Execution Result
+{exec_pass}
+
+#### Failed Test Case
+{test_fail}
+
+Please provide a reflection on the failed test cases and code solution, followed by a better code solution without any additional text or test cases.
+"""
+
+MD_ENSEMBLE_PROMPT = """
+Given the question described as follows: {question}
+Several solutions have been generated to address the given question. They are as follows:
+{solutions}
+
+Carefully evaluate these solutions and identify the one that is most capable of solving the problem compared to the others, as this is crucial for problem-solving.
+
+In the "thought" field, provide a detailed explanation of your thought process. In the "solution_letter" field, output only the single letter ID (A, B, C, etc.) corresponding to the solution. Do not include any additional text or explanation in the "solution_letter" field.
+"""
+
+REVIEW_PROMPT = """
+Given a problem and a thoughtful solution, your task is to use critical thinking (questioning) to review the solution's correctness and provide a review result in boolean format.
+
+problem: {problem}
+solution: {solution}
+
+If you are more than 95 percent confident that the final answer is incorrect, please return False and give feedback on the error. Otherwise, please return True and give an explanation for the correctness.
+"""
+
+REVISE_PROMPT = """
+Given a problem and a thoughtful solution which has just been reviewed as incorrect, your task is to revise the solution so that it solves the question, and ensure the final code solution is wrapped with ```python```.
+
+problem: {problem}
+solution: {solution}
+feedback: {feedback}
+
+Ensure the output code is self-contained, without any additional text or test cases.
+"""
diff --git a/metagpt/ext/aflow/scripts/utils.py b/metagpt/ext/aflow/scripts/utils.py
new file mode 100644
index 000000000..5e6222dc4
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/utils.py
@@ -0,0 +1,125 @@
+"""
+@Time : 2024/7/24 16:37
+@Author : didi
+@File : utils.py
+"""
+
+import json
+import re
+from enum import Enum
+from typing import Any, List, Tuple
+
+
+class CodeDataset(Enum):
+ HUMAN_EVAL = "HumanEval"
+ MBPP = "MBPP"
+
+
+def extract_test_cases_from_jsonl(entry_point: str, dataset: CodeDataset = CodeDataset.HUMAN_EVAL):
+ if dataset == CodeDataset.HUMAN_EVAL.value:
+ file_path = "metagpt/ext/aflow/data/humaneval_public_test.jsonl"
+ # Retain the original hardcoded test cases
+ hardcoded_cases = {
+ "find_zero": "",
+ "decode_cyclic": "",
+ "decode_shift": "",
+ "by_length": "",
+ "add": "",
+ "triangle_area": "",
+ "correct_bracketing": "",
+ "solve": "",
+ "sum_squares": "",
+ "starts_one_ends": "",
+ }
+ elif dataset == CodeDataset.MBPP.value:
+ file_path = "metagpt/ext/aflow/data/mbpp_public_test.jsonl"
+ hardcoded_cases = {
+ "remove_odd": "",
+ "replace_spaces": "",
+ "snake_to_camel": "",
+ "Split": "",
+ "swap_List": "",
+ "square_Sum": "",
+ "sort_sublists": "",
+ "unique_sublists": "",
+ }
+ # Check if there are hardcoded test cases
+ if entry_point in hardcoded_cases:
+ return hardcoded_cases[entry_point]
+
+ # If there are no hardcoded test cases, read from the file
+ with open(file_path, "r") as file:
+ for line in file:
+ data = json.loads(line)
+ if data.get("entry_point") == entry_point:
+ return data.get("test")
+
+ return None
+
+
+def extract_test_cases(docstring: str) -> List[Tuple[str, List[Any], Any]]:
+ # Use regular expressions to match test cases, now capturing function names and any output
+ pattern = r">>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)"
+ matches = re.findall(pattern, docstring, re.DOTALL)
+
+ test_cases = []
+ for match in matches:
+ func_name, input_str, expected_output = match
+
+ # Process input
+ input_list = []
+ for item in input_str.split(","):
+ item = item.strip()
+ try:
+ # Try to convert input to numeric type
+ if "." in item:
+ input_list.append(float(item))
+ else:
+ input_list.append(int(item))
+ except ValueError:
+ # If unable to convert to numeric, keep as string
+ input_list.append(item.strip("'\""))
+
+ # Process output
+ try:
+ # Try to convert output to numeric or boolean value
+ if expected_output.lower() == "true":
+ expected_output = True
+ elif expected_output.lower() == "false":
+ expected_output = False
+ elif "." in expected_output:
+ expected_output = float(expected_output)
+ else:
+ expected_output = int(expected_output)
+ except ValueError:
+ # If unable to convert, keep as string
+ expected_output = expected_output.strip("'\"")
+
+ test_cases.append([func_name, input_list, expected_output])
+
+ return test_cases
+
+
+def test_cases_2_test_functions(solution: str, test_cases: str):
+ tester_function = f"""
+{solution}
+
+{test_cases}
+"""
+ return tester_function
+
+
+def test_case_2_test_function(solution: str, test_case: str, entry_point: str):
+ tester_function = f"""
+{solution}
+
+
+def check(candidate):
+ {test_case}
+
+def test_check():
+ check({entry_point})
+
+test_check()
+"""
+ return tester_function
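`extract_test_cases` relies on a single regex over the docstring's `>>>` examples before converting inputs and outputs to numeric types. A self-contained sketch of that pattern on a toy docstring; the function and values are invented.

```python
import re

DOCSTRING = '''
def add(a, b):
    """Return the sum of a and b.

    >>> add(2, 3)
    5
    >>> add(1.5, 2.5)
    4.0
    """
'''

# Same pattern as extract_test_cases: function name, raw argument string, expected output.
pattern = r">>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)"
print(re.findall(pattern, DOCSTRING, re.DOTALL))
# [('add', '2, 3', '5'), ('add', '1.5, 2.5', '4.0')]
```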
diff --git a/metagpt/ext/aflow/scripts/workflow.py b/metagpt/ext/aflow/scripts/workflow.py
new file mode 100644
index 000000000..47b54021b
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/workflow.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+# @Date : 6/27/2024 22:07 PM
+# @Author : didi
+# @Desc : Basic Graph Class
+
+
+from metagpt.ext.aflow.scripts.evaluator import DatasetType
+from metagpt.provider.llm_provider_registry import create_llm_instance
+from metagpt.utils.cost_manager import CostManager
+
+
+class Workflow:
+ def __init__(
+ self,
+ name: str,
+ llm_config,
+ dataset: DatasetType,
+ ) -> None:
+ self.name = name
+ self.dataset = dataset
+ self.llm = create_llm_instance(llm_config)
+ self.llm.cost_manager = CostManager()
+
+ async def __call__(self, problem: str):
+ """
+ Implementation of the workflow
+ """
+ raise NotImplementedError("This method should be implemented by the subclass")
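Generated `round_N/graph.py` files subclass this base class and implement `__call__`. Below is a minimal, hypothetical subclass sketch: it skips the operator calls a real generated workflow would make, and it assumes the MetaGPT package is importable.

```python
from metagpt.ext.aflow.scripts.workflow import Workflow


class EchoWorkflow(Workflow):
    """Toy workflow; a real generated graph would await LLM operators (e.g. Custom) here."""

    async def __call__(self, problem: str):
        # Placeholder logic only; no LLM call is made in this sketch.
        return f"echo: {problem}"
```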
diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py
index 1236bf56b..c5e3b7bd2 100644
--- a/metagpt/provider/bedrock/bedrock_provider.py
+++ b/metagpt/provider/bedrock/bedrock_provider.py
@@ -57,15 +57,34 @@ class AnthropicProvider(BaseBedrockProvider):
class CohereProvider(BaseBedrockProvider):
- # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
+ # For more information, see
+ # (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
+ # (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
+
+ def __init__(self, model_name: str) -> None:
+ self.model_name = model_name
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generations"][0]["text"]
+ def messages_to_prompt(self, messages: list[dict]) -> str:
+ if "command-r" in self.model_name:
+ role_map = {"user": "USER", "assistant": "CHATBOT", "system": "USER"}
+ messages = list(
+ map(lambda message: {"role": role_map[message["role"]], "message": message["content"]}, messages)
+ )
+ return messages
+ else:
+ """[{"role": "user", "content": msg}] to user: etc."""
+ return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
- body = json.dumps(
- {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
- )
+ prompt = self.messages_to_prompt(messages)
+ if "command-r" in self.model_name:
+ chat_history, message = prompt[:-1], prompt[-1]["message"]
+ body = json.dumps({"message": message, "chat_history": chat_history, **generate_kwargs})
+ else:
+ body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
return body
def get_choice_text_from_stream(self, event) -> str:
@@ -95,10 +114,37 @@ class MetaProvider(BaseBedrockProvider):
class Ai21Provider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
- max_tokens_field_name = "maxTokens"
+ def __init__(self, model_type: Literal["j2", "jamba"]) -> None:
+ self.model_type = model_type
+ if self.model_type == "j2":
+ self.max_tokens_field_name = "maxTokens"
+ else:
+ self.max_tokens_field_name = "max_tokens"
+
+ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
+ if self.model_type == "j2":
+ body = super().get_request_body(messages, generate_kwargs, *args, **kwargs)
+ else:
+ body = json.dumps(
+ {
+ "messages": messages,
+ **generate_kwargs,
+ }
+ )
+ return body
+
+ def get_choice_text_from_stream(self, event) -> str:
+ rsp_dict = json.loads(event["chunk"]["bytes"])
+ completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
+ return completions
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
- return rsp_dict["completions"][0]["data"]["text"]
+ if self.model_type == "j2":
+ # See https://docs.ai21.com/reference/j2-complete-ref
+ return rsp_dict["completions"][0]["data"]["text"]
+ else:
+ # See https://docs.ai21.com/reference/jamba-instruct-api
+ return rsp_dict["choices"][0]["message"]["content"]
class AmazonProvider(BaseBedrockProvider):
@@ -136,4 +182,10 @@ def get_provider(model_id: str):
if provider == "meta":
# distinguish llama2 and llama3
return PROVIDERS[provider](model_name[:6])
+ elif provider == "ai21":
+ # distinguish between j2 and jamba
+ return PROVIDERS[provider](model_name.split("-")[0])
+ elif provider == "cohere":
+ # distinguish between R/R+ and older models
+ return PROVIDERS[provider](model_name)
return PROVIDERS[provider]()
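The new `get_provider` branches dispatch on the model id's prefix and on the first token of the model name. A tiny sketch of that routing, assuming the usual `provider.model_name` Bedrock id format; the ids below are illustrative and taken from the table in `utils.py`.

```python
# Illustrative Bedrock model ids; the split mirrors how get_provider distinguishes sub-families.
for model_id in ("ai21.j2-ultra-v1", "ai21.jamba-instruct-v1:0", "cohere.command-r-plus-v1:0"):
    provider, model_name = model_id.split(".", 1)
    if provider == "ai21":
        print(model_id, "->", model_name.split("-")[0])      # "j2" vs "jamba" picks the request shape
    elif provider == "cohere":
        print(model_id, "-> command-r API:", "command-r" in model_name)  # True routes to the chat_history/message body
```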
diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py
index 46520d1d5..e67796362 100644
--- a/metagpt/provider/bedrock/utils.py
+++ b/metagpt/provider/bedrock/utils.py
@@ -1,52 +1,97 @@
from metagpt.logs import logger
# max_tokens for each model
-NOT_SUUPORT_STREAM_MODELS = {
- "ai21.j2-grande-instruct": 8000,
- "ai21.j2-jumbo-instruct": 8000,
- "ai21.j2-mid": 8000,
- "ai21.j2-mid-v1": 8000,
- "ai21.j2-ultra": 8000,
- "ai21.j2-ultra-v1": 8000,
+NOT_SUPPORT_STREAM_MODELS = {
+ # Jurassic-2 Mid-v1 and Ultra-v1
+ # + Legacy date: 2024-04-30 (us-west-2/Oregon)
+ # + EOL date: 2024-08-31 (us-west-2/Oregon)
+ "ai21.j2-mid-v1": 8191,
+ "ai21.j2-ultra-v1": 8191,
}
SUPPORT_STREAM_MODELS = {
- "amazon.titan-tg1-large": 8000,
- "amazon.titan-text-express-v1": 8000,
- "amazon.titan-text-express-v1:0:8k": 8000,
- "amazon.titan-text-lite-v1:0:4k": 4000,
- "amazon.titan-text-lite-v1": 4000,
- "anthropic.claude-instant-v1": 100000,
- "anthropic.claude-instant-v1:2:100k": 100000,
- "anthropic.claude-v1": 100000,
- "anthropic.claude-v2": 100000,
- "anthropic.claude-v2:1": 200000,
- "anthropic.claude-v2:0:18k": 18000,
- "anthropic.claude-v2:1:200k": 200000,
- "anthropic.claude-3-sonnet-20240229-v1:0": 200000,
- "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
- "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
- "anthropic.claude-3-haiku-20240307-v1:0": 200000,
- "anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
- "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
- "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
- # currently (2024-4-29) only available at US West (Oregon) AWS Region.
- "anthropic.claude-3-opus-20240229-v1:0": 200000,
- "cohere.command-text-v14": 4000,
- "cohere.command-text-v14:7:4k": 4000,
- "cohere.command-light-text-v14": 4000,
- "cohere.command-light-text-v14:7:4k": 4000,
- "meta.llama2-13b-chat-v1:0:4k": 4000,
- "meta.llama2-13b-chat-v1": 2000,
- "meta.llama2-70b-v1": 4000,
- "meta.llama2-70b-v1:0:4k": 4000,
- "meta.llama2-70b-chat-v1": 2000,
- "meta.llama2-70b-chat-v1:0:4k": 2000,
- "meta.llama3-8b-instruct-v1:0": 2000,
- "meta.llama3-70b-instruct-v1:0": 2000,
- "mistral.mistral-7b-instruct-v0:2": 32000,
- "mistral.mixtral-8x7b-instruct-v0:1": 32000,
- "mistral.mistral-large-2402-v1:0": 32000,
+ # Jamba-Instruct
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jamba.html
+ "ai21.jamba-instruct-v1:0": 4096,
+ # Titan Text G1 - Lite
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
+ "amazon.titan-text-lite-v1:0:4k": 4096,
+ "amazon.titan-text-lite-v1": 4096,
+ # Titan Text G1 - Express
+ "amazon.titan-text-express-v1": 8192,
+ "amazon.titan-text-express-v1:0:8k": 8192,
+ # Titan Text Premier
+ "amazon.titan-text-premier-v1:0": 3072,
+ "amazon.titan-text-premier-v1:0:32k": 3072,
+ # Claude Instant v1
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
+ # https://docs.anthropic.com/en/docs/about-claude/models#model-comparison
+ "anthropic.claude-instant-v1": 4096,
+ "anthropic.claude-instant-v1:2:100k": 4096,
+ # Claude v2
+ "anthropic.claude-v2": 4096,
+ "anthropic.claude-v2:0:18k": 4096,
+ "anthropic.claude-v2:0:100k": 4096,
+ # Claude v2.1
+ "anthropic.claude-v2:1": 4096,
+ "anthropic.claude-v2:1:18k": 4096,
+ "anthropic.claude-v2:1:200k": 4096,
+ # Claude 3 Sonnet
+ "anthropic.claude-3-sonnet-20240229-v1:0": 4096,
+ "anthropic.claude-3-sonnet-20240229-v1:0:28k": 4096,
+ "anthropic.claude-3-sonnet-20240229-v1:0:200k": 4096,
+ # Claude 3 Haiku
+ "anthropic.claude-3-haiku-20240307-v1:0": 4096,
+ "anthropic.claude-3-haiku-20240307-v1:0:48k": 4096,
+ "anthropic.claude-3-haiku-20240307-v1:0:200k": 4096,
+ # Claude 3 Opus
+ "anthropic.claude-3-opus-20240229-v1:0": 4096,
+ # Claude 3.5 Sonnet
+ "anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
+ # Command Text
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
+ "cohere.command-text-v14": 4096,
+ "cohere.command-text-v14:7:4k": 4096,
+ # Command Light Text
+ "cohere.command-light-text-v14": 4096,
+ "cohere.command-light-text-v14:7:4k": 4096,
+ # Command R
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
+ "cohere.command-r-v1:0": 4096,
+ # Command R+
+ "cohere.command-r-plus-v1:0": 4096,
+ # Llama 2 (--> Llama 3/3.1/3.2) !!!
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
+ # + Legacy: 2024-05-12
+ # + EOL: 2024-10-30
+ # "meta.llama2-13b-chat-v1": 2048,
+ # "meta.llama2-13b-chat-v1:0:4k": 2048,
+ # "meta.llama2-70b-v1": 2048,
+ # "meta.llama2-70b-v1:0:4k": 2048,
+ # "meta.llama2-70b-chat-v1": 2048,
+ # "meta.llama2-70b-chat-v1:0:4k": 2048,
+ # Llama 3 Instruct
+ # "meta.llama3-8b-instruct-v1:0": 2048,
+ "meta.llama3-70b-instruct-v1:0": 2048,
+ # Llama 3.1 Instruct
+ # "meta.llama3-1-8b-instruct-v1:0": 2048,
+ "meta.llama3-1-70b-instruct-v1:0": 2048,
+ "meta.llama3-1-405b-instruct-v1:0": 2048,
+ # Llama 3.2 Instruct
+ # "meta.llama3-2-3b-instruct-v1:0": 2048,
+ # "meta.llama3-2-11b-instruct-v1:0": 2048,
+ "meta.llama3-2-90b-instruct-v1:0": 2048,
+ # Mistral 7B Instruct
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html
+ # "mistral.mistral-7b-instruct-v0:2": 8192,
+ # Mixtral 8x7B Instruct
+ "mistral.mixtral-8x7b-instruct-v0:1": 4096,
+ # Mistral Small
+ "mistral.mistral-small-2402-v1:0": 8192,
+ # Mistral Large (24.02)
+ "mistral.mistral-large-2402-v1:0": 8192,
+ # Mistral Large 2 (24.07)
+ "mistral.mistral-large-2407-v1:0": 8192,
}
# TODO:use a more general function for constructing chat templates.
@@ -106,7 +151,7 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str:
def get_max_tokens(model_id: str) -> int:
try:
- max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
+ max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
except KeyError:
logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048")
max_tokens = 2048
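With the corrected `NOT_SUPPORT_STREAM_MODELS` name, `get_max_tokens` looks a model id up across both tables and falls back to 2048 with a warning when the id is unknown. A hedged usage sketch, assuming MetaGPT is importable; the expected values come from the table above.

```python
from metagpt.provider.bedrock.utils import get_max_tokens

print(get_max_tokens("anthropic.claude-3-5-sonnet-20240620-v1:0"))  # 8192 per the table above
print(get_max_tokens("some.unknown-model"))                         # 2048 fallback, with a warning log
```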
diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py
index 03954e5c2..4e783f579 100644
--- a/metagpt/provider/bedrock_api.py
+++ b/metagpt/provider/bedrock_api.py
@@ -1,5 +1,6 @@
import asyncio
import json
+import os
from functools import partial
from typing import List, Literal
@@ -11,7 +12,7 @@ from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.bedrock.bedrock_provider import get_provider
-from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
+from metagpt.provider.bedrock.utils import NOT_SUPPORT_STREAM_MODELS, get_max_tokens
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
@@ -24,18 +25,19 @@ class BedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
- if self.config.model in NOT_SUUPORT_STREAM_MODELS:
+ if self.config.model in NOT_SUPPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
"""initialize boto3 client"""
# access key and secret key from https://us-east-1.console.aws.amazon.com/iam
- self.__credentital_kwargs = {
- "aws_secret_access_key": self.config.secret_key,
- "aws_access_key_id": self.config.access_key,
- "region_name": self.config.region_name,
+ self.__credential_kwargs = {
+ "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY", self.config.secret_key),
+ "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID", self.config.access_key),
+ "aws_session_token": os.environ.get("AWS_SESSION_TOKEN", self.config.session_token),
+ "region_name": os.environ.get("AWS_DEFAULT_REGION", self.config.region_name),
}
- session = boto3.Session(**self.__credentital_kwargs)
+ session = boto3.Session(**self.__credential_kwargs)
client = session.client(service_name)
return client
@@ -111,7 +113,7 @@ class BedrockLLM(BaseLLM):
return await self.acompletion(messages)
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
- if self.config.model in NOT_SUUPORT_STREAM_MODELS:
+ if self.config.model in NOT_SUPPORT_STREAM_MODELS:
rsp = await self.acompletion(messages)
full_text = self.get_choice_text(rsp)
log_llm_stream(full_text)
diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py
index 8e5da8f16..34a39fe6c 100644
--- a/metagpt/provider/general_api_base.py
+++ b/metagpt/provider/general_api_base.py
@@ -13,6 +13,7 @@ import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import (
+ Any,
AsyncGenerator,
AsyncIterator,
Dict,
@@ -121,7 +122,7 @@ def logfmt(props):
class OpenAIResponse:
- def __init__(self, data, headers):
+ def __init__(self, data: Union[bytes, Any], headers: dict):
self._headers = headers
self.data = data
@@ -320,49 +321,6 @@ class APIRequestor:
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key
- @overload
- async def arequest(
- self,
- method,
- url,
- params,
- headers,
- files,
- stream: Literal[True],
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
- pass
-
- @overload
- async def arequest(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- *,
- stream: Literal[True],
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
- pass
-
- @overload
- async def arequest(
- self,
- method,
- url,
- params=...,
- headers=...,
- files=...,
- stream: Literal[False] = ...,
- request_id: Optional[str] = ...,
- request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
- ) -> Tuple[OpenAIResponse, bool, str]:
- pass
-
@overload
async def arequest(
self,
@@ -438,8 +396,8 @@ class APIRequestor:
"X-LLM-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
}
-
- headers.update(api_key_to_header(self.api_type, self.api_key))
+ if self.api_key:
+ headers.update(api_key_to_header(self.api_type, self.api_key))
if self.organization:
headers["LLM-Organization"] = self.organization
diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py
index 501a064e3..b8da1565d 100644
--- a/metagpt/provider/general_api_requestor.py
+++ b/metagpt/provider/general_api_requestor.py
@@ -3,25 +3,24 @@
# @Desc : General Async API for http-based LLM model
import asyncio
-from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
+from typing import AsyncGenerator, Iterator, Optional, Tuple, Union
import aiohttp
import requests
from metagpt.logs import logger
-from metagpt.provider.general_api_base import APIRequestor
+from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse
-def parse_stream_helper(line: bytes) -> Union[bytes, None]:
+def parse_stream_helper(line: bytes) -> Optional[bytes]:
if line and line.startswith(b"data:"):
if line.startswith(b"data: "):
- # SSE event may be valid when it contain whitespace
+ # SSE event may be valid when it contains whitespace
line = line[len(b"data: ") :]
else:
line = line[len(b"data:") :]
if line.strip() == b"[DONE]":
- # return here will cause GeneratorExit exception in urllib3
- # and it will close http connection with TCP Reset
+ # Returning None to indicate end of stream
return None
else:
return line
@@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
class GeneralAPIRequestor(APIRequestor):
"""
- usage
+ Usage example:
# full_url = "{base_url}{url}"
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
@@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
)
"""
- def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
- # just do nothing to meet the APIRequestor process and return the raw data
- # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
+ def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
+ """
+ Process and return the response data wrapped in OpenAIResponse.
- return rbody
+ Args:
+ rbody (bytes): The response body.
+ rcode (int): The response status code.
+ rheaders (dict): The response headers.
+ stream (bool): Whether the response is a stream.
+
+ Returns:
+ OpenAIResponse: The response data wrapped in OpenAIResponse.
+ """
+ return OpenAIResponse(rbody, rheaders)
def _interpret_response(
self, result: requests.Response, stream: bool
- ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
- """Returns the response(s) and a bool indicating whether it is a stream."""
+ ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
+ """
+ Interpret a synchronous response.
+
+ Args:
+ result (requests.Response): The response object.
+ stream (bool): Whether the response is a stream.
+
+ Returns:
+ Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+ """
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
return (
- self._interpret_response_line(line, result.status_code, result.headers, stream=True)
- for line in parse_stream(result.iter_lines())
- ), True
+ (
+ self._interpret_response_line(line, result.status_code, result.headers, stream=True)
+ for line in parse_stream(result.iter_lines())
+ ),
+ True,
+ )
else:
return (
self._interpret_response_line(
- result.content, # let the caller to decode the msg
+ result.content, # let the caller decode the msg
result.status_code,
result.headers,
stream=False,
@@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor):
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
- ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
+ ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+ """
+ Interpret an asynchronous response.
+
+ Args:
+ result (aiohttp.ClientResponse): The response object.
+ stream (bool): Whether the response is a stream.
+
+ Returns:
+ Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
+ """
content_type = result.headers.get("Content-Type", "")
if stream and (
"text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
):
- # the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
- self._interpret_response_line(line, result.status, result.headers, stream=True)
- async for line in result.content
- ), True
+ (
+ self._interpret_response_line(line, result.status, result.headers, stream=True)
+ async for line in result.content
+ ),
+ True,
+ )
else:
try:
- await result.read()
+ response_content = await result.read()
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise TimeoutError("Request timed out") from e
except aiohttp.ClientError as exp:
- logger.warning(f"response: {result.content}, exp: {exp}")
+ logger.warning(f"response: {result}, exp: {exp}")
+ response_content = b""
return (
self._interpret_response_line(
- await result.read(), # let the caller to decode the msg
+ response_content, # let the caller decode the msg
result.status,
result.headers,
stream=False,
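After this change both the sync and async paths hand back OpenAIResponse objects (or generators of them) rather than raw bytes, and the caller decodes resp.data itself. A rough sketch of consuming both shapes, following the arequest call style shown in the usage docstring above; consume is illustrative only:

import json

from metagpt.provider.general_api_base import OpenAIResponse
from metagpt.provider.general_api_requestor import GeneralAPIRequestor

async def consume(requestor: GeneralAPIRequestor, url: str, params: dict, stream: bool = False):
    resp, _, _ = await requestor.arequest(
        method="post", url=url, params=params, request_timeout=60, stream=stream
    )
    if isinstance(resp, OpenAIResponse):
        return json.loads(resp.data.decode("utf-8"))  # non-streaming: a single wrapped body
    chunks = []
    async for item in resp:  # streaming: an async generator of OpenAIResponse chunks
        chunks.append(item.data.decode("utf-8"))
    return chunks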
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index 454f0e3ee..3f7d20d0a 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -3,16 +3,189 @@
# @Desc : self-host open llm model with ollama which isn't openai-api-compatible
import json
+from enum import Enum, auto
+from typing import AsyncGenerator, Optional, Tuple
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
-from metagpt.provider.general_api_requestor import GeneralAPIRequestor
+from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import TokenCostManager
+class OllamaMessageAPI(Enum):
+ # CHAT is the default API used by OllamaLLM
+ CHAT = auto()
+ GENERATE = auto()
+ EMBED = auto()
+ EMBEDDINGS = auto()
+
+
+class OllamaMessageBase:
+ api_type = OllamaMessageAPI.CHAT
+
+ def __init__(self, model: str, **additional_kwargs) -> None:
+ self.model, self.additional_kwargs = model, additional_kwargs
+ self._image_b64_rms = len("data:image/jpeg;base64,")
+
+ @property
+ def api_suffix(self) -> str:
+ raise NotImplementedError
+
+ def apply(self, messages: list[dict]) -> dict:
+ raise NotImplementedError
+
+ def decode(self, response: OpenAIResponse) -> dict:
+ return json.loads(response.data.decode("utf-8"))
+
+ def get_choice(self, to_choice_dict: dict) -> str:
+ raise NotImplementedError
+
+ def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]:
+ if "type" in msg:
+ tpe = msg["type"]
+ if tpe == "text":
+ return msg["text"], None
+ elif tpe == "image_url":
+ return None, msg["image_url"]["url"][self._image_b64_rms :]
+ else:
+ raise ValueError
+ else:
+ raise ValueError
+
+
+class OllamaMessageMeta(type):
+ registered_message = {}
+
+ def __init__(cls, name, bases, attrs):
+ super().__init__(name, bases, attrs)
+ for base in bases:
+ if issubclass(base, OllamaMessageBase):
+ api_type = attrs["api_type"]
+ assert api_type not in OllamaMessageMeta.registered_message, "api_type already exists"
+ assert isinstance(api_type, OllamaMessageAPI), "api_type not supported"
+ OllamaMessageMeta.registered_message[api_type] = cls
+
+ @classmethod
+ def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]:
+ return cls.registered_message[input_type]
+
+
+class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta):
+ api_type = OllamaMessageAPI.CHAT
+
+ @property
+ def api_suffix(self) -> str:
+ return "/chat"
+
+ def apply(self, messages: list[dict]) -> dict:
+ content = messages[0]["content"]
+ prompts = []
+ images = []
+ if isinstance(content, list):
+ for msg in content:
+ prompt, image = self._parse_input_msg(msg)
+ if prompt:
+ prompts.append(prompt)
+ if image:
+ images.append(image)
+ else:
+ prompts.append(content)
+ messes = []
+ for prompt in prompts:
+ if len(images) > 0:
+ messes.append({"role": "user", "content": prompt, "images": images})
+ else:
+ messes.append({"role": "user", "content": prompt})
+ sends = {"model": self.model, "messages": messes}
+ sends.update(self.additional_kwargs)
+ return sends
+
+ def get_choice(self, to_choice_dict: dict) -> str:
+ message = to_choice_dict["message"]
+ if message["role"] == "assistant":
+ return message["content"]
+ else:
+ raise ValueError
+
+
+class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
+ api_type = OllamaMessageAPI.GENERATE
+
+ @property
+ def api_suffix(self) -> str:
+ return "/generate"
+
+ def apply(self, messages: list[dict]) -> dict:
+ content = messages[0]["content"]
+ prompts = []
+ images = []
+ if isinstance(content, list):
+ for msg in content:
+ prompt, image = self._parse_input_msg(msg)
+ if prompt:
+ prompts.append(prompt)
+ if image:
+ images.append(image)
+ else:
+ prompts.append(content)
+ if len(images) > 0:
+ sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
+ else:
+ sends = {"model": self.model, "prompt": "\n".join(prompts)}
+ sends.update(self.additional_kwargs)
+ return sends
+
+ def get_choice(self, to_choice_dict: dict) -> str:
+ return to_choice_dict["response"]
+
+
+class OllamaMessageEmbeddings(OllamaMessageBase, metaclass=OllamaMessageMeta):
+ api_type = OllamaMessageAPI.EMBEDDINGS
+
+ @property
+ def api_suffix(self) -> str:
+ return "/embeddings"
+
+ def apply(self, messages: list[dict]) -> dict:
+ content = messages[0]["content"]
+ prompts = []  # NOTE: images are not supported for embeddings
+ if isinstance(content, list):
+ for msg in content:
+ prompt, _ = self._parse_input_msg(msg)
+ if prompt:
+ prompts.append(prompt)
+ else:
+ prompts.append(content)
+ sends = {"model": self.model, "prompt": "\n".join(prompts)}
+ sends.update(self.additional_kwargs)
+ return sends
+
+
+class OllamaMessageEmbed(OllamaMessageEmbeddings, metaclass=OllamaMessageMeta):
+ api_type = OllamaMessageAPI.EMBED
+
+ @property
+ def api_suffix(self) -> str:
+ return "/embed"
+
+ def apply(self, messages: list[dict]) -> dict:
+ content = messages[0]["content"]
+ prompts = []  # NOTE: images are not supported for embeddings
+ if isinstance(content, list):
+ for msg in content:
+ prompt, _ = self._parse_input_msg(msg)
+ if prompt:
+ prompts.append(prompt)
+ else:
+ prompts.append(content)
+ sends = {"model": self.model, "input": prompts}
+ sends.update(self.additional_kwargs)
+ return sends
+
+
@register_provider(LLMType.OLLAMA)
class OllamaLLM(BaseLLM):
"""
@@ -20,83 +193,80 @@ class OllamaLLM(BaseLLM):
"""
def __init__(self, config: LLMConfig):
- self.__init_ollama(config)
self.client = GeneralAPIRequestor(base_url=config.base_url)
self.config = config
- self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self.cost_manager = TokenCostManager()
+ self.__init_ollama(config)
+
+ @property
+ def _llama_api_inuse(self) -> OllamaMessageAPI:
+ return OllamaMessageAPI.CHAT
+
+ @property
+ def _llama_api_kwargs(self) -> dict:
+ return {"options": {"temperature": 0.3}, "stream": self.config.stream}
def __init_ollama(self, config: LLMConfig):
assert config.base_url, "ollama base url is required!"
self.model = config.model
self.pricing_plan = self.model
-
- def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
- kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
- return kwargs
-
- def get_choice_text(self, resp: dict) -> str:
- """get the resp content from llm response"""
- assist_msg = resp.get("message", {})
- assert assist_msg.get("role", None) == "assistant"
- return assist_msg.get("content")
+ ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
+ self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)
def get_usage(self, resp: dict) -> dict:
return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
- def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
- chunk = chunk.decode(encoding)
- return json.loads(chunk)
-
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
- headers = (
- None
- if not self.config.api_key or self.config.api_key == "sk-"
- else {
- "Authorization": f"Bearer {self.config.api_key}",
- }
- )
resp, _, _ = await self.client.arequest(
method=self.http_method,
- url=self.suffix_url,
- headers=headers,
- params=self._const_kwargs(messages),
+ url=self.ollama_message.api_suffix,
+ params=self.ollama_message.apply(messages=messages),
request_timeout=self.get_timeout(timeout),
)
- resp = self._decode_and_load(resp)
- usage = self.get_usage(resp)
- self._update_costs(usage)
- return resp
+ if isinstance(resp, AsyncGenerator):
+ return await self._processing_openai_response_async_generator(resp)
+ elif isinstance(resp, OpenAIResponse):
+ return self._processing_openai_response(resp)
+ else:
+ raise ValueError
+
+ def get_choice_text(self, rsp):
+ return self.ollama_message.get_choice(rsp)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
- headers = (
- None
- if not self.config.api_key or self.config.api_key == "sk-"
- else {
- "Authorization": f"Bearer {self.config.api_key}",
- }
- )
- stream_resp, _, _ = await self.client.arequest(
+ resp, _, _ = await self.client.arequest(
method=self.http_method,
- url=self.suffix_url,
- headers=headers,
- stream=True,
- params=self._const_kwargs(messages, stream=True),
+ url=self.ollama_message.api_suffix,
+ params=self.ollama_message.apply(messages=messages),
request_timeout=self.get_timeout(timeout),
+ stream=True,
)
+ if isinstance(resp, AsyncGenerator):
+ return await self._processing_openai_response_async_generator(resp)
+ elif isinstance(resp, OpenAIResponse):
+ return self._processing_openai_response(resp)
+ else:
+ raise ValueError
+ def _processing_openai_response(self, openai_resp: OpenAIResponse):
+ resp = self.ollama_message.decode(openai_resp)
+ usage = self.get_usage(resp)
+ self._update_costs(usage)
+ return resp
+
+ async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
collected_content = []
usage = {}
- async for raw_chunk in stream_resp:
- chunk = self._decode_and_load(raw_chunk)
+ async for raw_chunk in ag_openai_resp:
+ chunk = self.ollama_message.decode(raw_chunk)
if not chunk.get("done", False):
- content = self.get_choice_text(chunk)
+ content = self.ollama_message.get_choice(chunk)
collected_content.append(content)
log_llm_stream(content)
else:
@@ -107,3 +277,55 @@ class OllamaLLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
+
+
+@register_provider(LLMType.OLLAMA_GENERATE)
+class OllamaGenerate(OllamaLLM):
+ @property
+ def _llama_api_inuse(self) -> OllamaMessageAPI:
+ return OllamaMessageAPI.GENERATE
+
+ @property
+ def _llama_api_kwargs(self) -> dict:
+ return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+
+
+@register_provider(LLMType.OLLAMA_EMBEDDINGS)
+class OllamaEmbeddings(OllamaLLM):
+ @property
+ def _llama_api_inuse(self) -> OllamaMessageAPI:
+ return OllamaMessageAPI.EMBEDDINGS
+
+ @property
+ def _llama_api_kwargs(self) -> dict:
+ return {"options": {"temperature": 0.3}}
+
+ @property
+ def _llama_embedding_key(self) -> str:
+ return "embedding"
+
+ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+ resp, _, _ = await self.client.arequest(
+ method=self.http_method,
+ url=self.ollama_message.api_suffix,
+ params=self.ollama_message.apply(messages=messages),
+ request_timeout=self.get_timeout(timeout),
+ )
+ return self.ollama_message.decode(resp)[self._llama_embedding_key]
+
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+ return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
+
+ def get_choice_text(self, rsp):
+ return rsp
+
+
+@register_provider(LLMType.OLLAMA_EMBED)
+class OllamaEmbed(OllamaEmbeddings):
+ @property
+ def _llama_api_inuse(self) -> OllamaMessageAPI:
+ return OllamaMessageAPI.EMBED
+
+ @property
+ def _llama_embedding_key(self) -> str:
+ return "embeddings"
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index ce3a06ec8..1d2057e50 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -103,7 +103,7 @@ class OpenAILLM(BaseLLM):
if has_finished:
# for oneapi, there has a usage chunk after finish_reason not none chunk
if chunk_has_usage:
- usage = CompletionUsage(**chunk.usage)
+ usage = CompletionUsage(**chunk.usage) if isinstance(chunk.usage, dict) else chunk.usage
if finish_reason:
if chunk_has_usage:
# Some services have usage as an attribute of the chunk, such as Fireworks
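The one-line guard above exists because some services deliver chunk.usage as a plain dict while others already return a parsed CompletionUsage; unpacking the latter with ** would raise. The same guard in isolation (normalize_usage is a hypothetical helper):

from openai.types import CompletionUsage

def normalize_usage(raw) -> CompletionUsage:
    # dict -> construct the model; already a CompletionUsage -> pass it through unchanged
    return CompletionUsage(**raw) if isinstance(raw, dict) else raw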
diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py
index 6da4900a0..4e6d6b167 100644
--- a/metagpt/rag/factories/index.py
+++ b/metagpt/rag/factories/index.py
@@ -30,7 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
- MilvusIndexConfig: self._create_milvus
+ MilvusIndexConfig: self._create_milvus,
}
super().__init__(creators)
diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py
index 9fd19cab5..c1069cc6c 100644
--- a/metagpt/rag/factories/llm.py
+++ b/metagpt/rag/factories/llm.py
@@ -23,10 +23,12 @@ class RAGLLM(CustomLLM):
"""LlamaIndex's LLM is different from MetaGPT's LLM.
Inherit CustomLLM from llamaindex, making MetaGPT's LLM can be used by LlamaIndex.
+
+ Set context_length or max_token of the LLM in config.yaml if you encounter the "Calculated available context size -xxx was not non-negative" error.
"""
model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.")
- context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
+ context_window: int = config.llm.context_length or TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
num_output: int = config.llm.max_token
model_name: str = config.llm.model
diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py
index 3342b8905..490df4906 100644
--- a/metagpt/rag/factories/retriever.py
+++ b/metagpt/rag/factories/retriever.py
@@ -139,7 +139,9 @@ class RetrieverFactory(ConfigBasedFactory):
@get_or_build_index
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
- vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions)
+ vector_store = MilvusVectorStore(
+ uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions
+ )
return self._build_index_from_vector_store(config, vector_store, **kwargs)
diff --git a/metagpt/rag/retrievers/milvus_retriever.py b/metagpt/rag/retrievers/milvus_retriever.py
index ff2562bd8..bcc66330b 100644
--- a/metagpt/rag/retrievers/milvus_retriever.py
+++ b/metagpt/rag/retrievers/milvus_retriever.py
@@ -14,4 +14,4 @@ class MilvusRetriever(VectorIndexRetriever):
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist.
- Milvus automatically saves, so there is no need to implement."""
\ No newline at end of file
+ Milvus automatically saves, so there is no need to implement."""
diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py
index e4d97068d..1e04a546f 100644
--- a/metagpt/rag/schema.py
+++ b/metagpt/rag/schema.py
@@ -8,7 +8,7 @@ from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
@@ -199,6 +199,7 @@ class ChromaIndexConfig(VectorIndexConfig):
default=None, description="Optional metadata to associate with the collection"
)
+
class MilvusIndexConfig(VectorIndexConfig):
"""Config for milvus-based index."""
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index cf490084d..47f2768cd 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -581,6 +581,30 @@ def write_json_file(json_file: str, data: list, encoding: str = None, indent: in
json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python)
+def read_jsonl_file(jsonl_file: str, encoding="utf-8") -> list[dict]:
+ if not Path(jsonl_file).exists():
+ raise FileNotFoundError(f"jsonl_file: {jsonl_file} does not exist")
+ datas = []
+ with open(jsonl_file, "r", encoding=encoding) as fin:
+ try:
+ for line in fin:
+ data = json.loads(line)
+ datas.append(data)
+ except Exception as e:
+ raise ValueError(f"failed to read jsonl file: {jsonl_file}") from e
+ return datas
+
+
+def add_jsonl_file(jsonl_file: str, data: list[dict], encoding: str = None):
+ folder_path = Path(jsonl_file).parent
+ if not folder_path.exists():
+ folder_path.mkdir(parents=True, exist_ok=True)
+
+ with open(jsonl_file, "a", encoding=encoding) as fout:
+ for json_item in data:
+ fout.write(json.dumps(json_item) + "\n")
+
+
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
"""
Reads in a csv file to a list of list. If header is True, it returns a
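The two jsonl helpers above are read/append counterparts of the existing json utilities. A small usage sketch (the file path is illustrative):

from metagpt.utils.common import add_jsonl_file, read_jsonl_file

records = [{"task": "demo", "ok": True}]
add_jsonl_file("workspace/runs.jsonl", records)   # appends one JSON object per line, creating parent dirs
loaded = read_jsonl_file("workspace/runs.jsonl")  # -> list[dict]; raises FileNotFoundError if missing
assert loaded[-1]["task"] == "demo"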
diff --git a/metagpt/utils/sanitize.py b/metagpt/utils/sanitize.py
new file mode 100644
index 000000000..a9becbb98
--- /dev/null
+++ b/metagpt/utils/sanitize.py
@@ -0,0 +1,183 @@
+"""
+@Time : 2024/7/24 16:37
+@Author : didi
+@File : sanitize.py
+@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
+"""
+
+import ast
+import traceback
+from enum import Enum
+from typing import Dict, Generator, List, Optional, Set, Tuple
+
+import tree_sitter_python
+from tree_sitter import Language, Node, Parser
+
+
+class NodeType(Enum):
+ CLASS = "class_definition"
+ FUNCTION = "function_definition"
+ IMPORT = ["import_statement", "import_from_statement"]
+ IDENTIFIER = "identifier"
+ ATTRIBUTE = "attribute"
+ RETURN = "return_statement"
+ EXPRESSION = "expression_statement"
+ ASSIGNMENT = "assignment"
+
+
+def traverse_tree(node: Node) -> Generator[Node, None, None]:
+ """
+ Traverse the tree structure starting from the given node.
+
+ :param node: The root node to start the traversal from.
+ :return: A generator object that yields nodes in the tree.
+ """
+ cursor = node.walk()
+ depth = 0
+
+ visited_children = False
+ while True:
+ if not visited_children:
+ yield cursor.node
+ if not cursor.goto_first_child():
+ depth += 1
+ visited_children = True
+ elif cursor.goto_next_sibling():
+ visited_children = False
+ elif not cursor.goto_parent() or depth == 0:
+ break
+ else:
+ depth -= 1
+
+
+def syntax_check(code, verbose=False):
+ try:
+ ast.parse(code)
+ return True
+ except (SyntaxError, MemoryError):
+ if verbose:
+ traceback.print_exc()
+ return False
+
+
+def code_extract(text: str) -> str:
+ lines = text.split("\n")
+ longest_line_pair = (0, 0)
+ longest_so_far = 0
+
+ for i in range(len(lines)):
+ for j in range(i + 1, len(lines)):
+ current_lines = "\n".join(lines[i : j + 1])
+ if syntax_check(current_lines):
+ current_length = sum(1 for line in lines[i : j + 1] if line.strip())
+ if current_length > longest_so_far:
+ longest_so_far = current_length
+ longest_line_pair = (i, j)
+
+ return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
+
+
+def get_definition_name(node: Node) -> str:
+ for child in node.children:
+ if child.type == NodeType.IDENTIFIER.value:
+ return child.text.decode("utf8")
+
+
+def has_return_statement(node: Node) -> bool:
+ traverse_nodes = traverse_tree(node)
+ for node in traverse_nodes:
+ if node.type == NodeType.RETURN.value:
+ return True
+ return False
+
+
+def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
+ def dfs_get_deps(node: Node, deps: Set[str]) -> None:
+ for child in node.children:
+ if child.type == NodeType.IDENTIFIER.value:
+ deps.add(child.text.decode("utf8"))
+ else:
+ dfs_get_deps(child, deps)
+
+ name2deps = {}
+ for name, node in nodes:
+ deps = set()
+ dfs_get_deps(node, deps)
+ name2deps[name] = deps
+ return name2deps
+
+
+def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
+ queue = [entrypoint]
+ visited = {entrypoint}
+ while queue:
+ current = queue.pop(0)
+ if current not in call_graph:
+ continue
+ for neighbour in call_graph[current]:
+ if neighbour not in visited:
+ visited.add(neighbour)
+ queue.append(neighbour)
+ return visited
+
+
+def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
+ """
+ Sanitize and extract relevant parts of the given Python code.
+ This function parses the input code, extracts import statements, class and function definitions,
+ and variable assignments. If an entrypoint is provided, it only includes definitions that are
+ reachable from the entrypoint in the call graph.
+
+ :param code: The input Python code as a string.
+ :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
+ :return: A sanitized version of the input code, containing only relevant parts.
+ """
+ code = code_extract(code)
+ code_bytes = bytes(code, "utf8")
+ parser = Parser(Language(tree_sitter_python.language()))
+ tree = parser.parse(code_bytes)
+ class_names = set()
+ function_names = set()
+ variable_names = set()
+
+ root_node = tree.root_node
+ import_nodes = []
+ definition_nodes = []
+
+ for child in root_node.children:
+ if child.type in NodeType.IMPORT.value:
+ import_nodes.append(child)
+ elif child.type == NodeType.CLASS.value:
+ name = get_definition_name(child)
+ if not (name in class_names or name in variable_names or name in function_names):
+ definition_nodes.append((name, child))
+ class_names.add(name)
+ elif child.type == NodeType.FUNCTION.value:
+ name = get_definition_name(child)
+ if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
+ child
+ ):
+ definition_nodes.append((name, child))
+ function_names.add(get_definition_name(child))
+ elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
+ subchild = child.children[0]
+ name = get_definition_name(subchild)
+ if not (name in variable_names or name in function_names or name in class_names):
+ definition_nodes.append((name, subchild))
+ variable_names.add(name)
+
+ if entrypoint:
+ name2deps = get_deps(definition_nodes)
+ reachable = get_function_dependency(entrypoint, name2deps)
+
+ sanitized_output = b""
+
+ for node in import_nodes:
+ sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
+
+ for pair in definition_nodes:
+ name, node = pair
+ if entrypoint and name not in reachable:
+ continue
+ sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
+ return sanitized_output[:-1].decode("utf8")
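sanitize first extracts the longest syntactically valid block from raw LLM output, keeps imports plus class/function/assignment definitions, and, when an entrypoint is given, drops definitions not reachable from it. A short sketch of the intended call (the snippet is illustrative):

from metagpt.utils.sanitize import sanitize

raw = '''Here is the solution:
import math

def helper(x):
    return math.sqrt(x)

def solve(n):
    return helper(n) + 1
'''
print(sanitize(raw, entrypoint="solve"))  # keeps the import plus helper/solve, drops the leading prose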
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index c922f2cb4..3b9533571 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -10,6 +10,7 @@ ref3: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/t
ref4: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
ref5: https://ai.google.dev/models/gemini
"""
+import anthropic
import tiktoken
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
@@ -377,6 +378,10 @@ SPARK_TOKENS = {
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""
+ if "claude" in model:
+ vo = anthropic.Client()
+ num_tokens = vo.count_tokens(str(messages))
+ return num_tokens
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@@ -463,6 +468,10 @@ def count_output_tokens(string: str, model: str) -> int:
Returns:
int: The number of tokens in the text string.
"""
+ if "claude" in model:
+ vo = anthropic.Client()
+ num_tokens = vo.count_tokens(string)
+ return num_tokens
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
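For claude-family models, both counters above now bypass tiktoken and defer to the Anthropic SDK's token counter (present in the anthropic 0.18.x release pinned in requirements.txt; later SDK versions move counting to a server-side endpoint). A rough sketch of the same call in isolation:

import anthropic

def claude_token_count(text: str) -> int:
    # count_tokens() is the synchronous helper exposed by anthropic 0.18.x clients;
    # note: client construction may still expect ANTHROPIC_API_KEY to be configured
    return anthropic.Client().count_tokens(text)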
diff --git a/requirements.txt b/requirements.txt
index b4f3f563d..5344ffb8c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,9 +12,9 @@ typer==0.9.0
lancedb==0.4.0
loguru==0.6.0
meilisearch==0.21.0
-numpy>=1.24.3
-openai>=1.6.1
-openpyxl
+numpy~=1.26.4
+openai~=1.39.0
+openpyxl~=3.1.5
beautifulsoup4==4.12.3
pandas==2.1.1
pydantic>=2.5.3
@@ -35,6 +35,9 @@ anthropic==0.18.1
typing-inspect==0.8.0
libcst==1.0.1
qdrant-client==1.7.0
+grpcio~=1.67.0
+grpcio-tools~=1.62.3
+grpcio-status~=1.62.3
# pytest-mock==3.11.1 # test extras require
# open-interpreter==0.1.7; python_version>"3.9" # Conflict with openai 1.x
ta==0.10.2
@@ -72,7 +75,7 @@ qianfan~=0.4.4
dashscope~=1.19.3
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
-volcengine-python-sdk[ark]~=1.0.94
+volcengine-python-sdk[ark]~=1.0.94 # Solution for installation error in Windows: https://github.com/volcengine/volcengine-python-sdk/issues/5
# llama-index-vector-stores-elasticsearch~=0.2.5 # Used by `metagpt/memory/longterm_memory.py`
# llama-index-vector-stores-chroma~=0.1.10 # Used by `metagpt/memory/longterm_memory.py`
gymnasium==0.29.1
diff --git a/setup.py b/setup.py
index 8ae4a3e1e..a996a1eb7 100644
--- a/setup.py
+++ b/setup.py
@@ -61,8 +61,6 @@ extras_require["test"] = [
"azure-cognitiveservices-speech~=1.31.0",
"aioboto3~=12.4.0",
"gradio==3.0.0",
- "grpcio-status==1.48.2",
- "grpcio-tools==1.48.2",
"google-api-core==2.17.1",
"protobuf==3.19.6",
"pylint==3.0.3",
diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py
index b9c9e0f93..28d1d7008 100644
--- a/tests/metagpt/provider/test_bedrock_api.py
+++ b/tests/metagpt/provider/test_bedrock_api.py
@@ -3,7 +3,7 @@ import json
import pytest
from metagpt.provider.bedrock.utils import (
- NOT_SUUPORT_STREAM_MODELS,
+ NOT_SUPPORT_STREAM_MODELS,
SUPPORT_STREAM_MODELS,
)
from metagpt.provider.bedrock_api import BedrockLLM
@@ -14,7 +14,7 @@ from tests.metagpt.provider.req_resp_const import (
)
# all available model from bedrock
-models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
+models = SUPPORT_STREAM_MODELS | NOT_SUPPORT_STREAM_MODELS
messages = [{"role": "user", "content": "Hi!"}]
usage = {
"prompt_tokens": 1000000,
diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py
index af2e929e9..75cfa86d5 100644
--- a/tests/metagpt/provider/test_ollama_api.py
+++ b/tests/metagpt/provider/test_ollama_api.py
@@ -3,11 +3,11 @@
# @Desc : the unittest of ollama api
import json
-from typing import Any, Tuple
+from typing import Any, AsyncGenerator, Tuple
import pytest
-from metagpt.provider.ollama_api import OllamaLLM
+from metagpt.provider.ollama_api import OllamaLLM, OpenAIResponse
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
@@ -23,21 +23,19 @@ default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
if stream:
- class Iterator(object):
+ async def async_event_generator() -> AsyncGenerator[Any, None]:
events = [
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
]
+ for event in events:
+ yield OpenAIResponse(event, {})
- async def __aiter__(self):
- for event in self.events:
- yield event
-
- return Iterator(), None, None
+ return async_event_generator(), None, None
else:
raw_default_resp = default_resp.copy()
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
- return json.dumps(raw_default_resp).encode(), None, None
+ return OpenAIResponse(json.dumps(raw_default_resp).encode(), {}), None, None
@pytest.mark.asyncio
diff --git a/tests/metagpt/rag/factories/test_index.py b/tests/metagpt/rag/factories/test_index.py
index 9861e1242..e084eb6e7 100644
--- a/tests/metagpt/rag/factories/test_index.py
+++ b/tests/metagpt/rag/factories/test_index.py
@@ -7,7 +7,8 @@ from metagpt.rag.schema import (
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
- FAISSIndexConfig, MilvusIndexConfig,
+ FAISSIndexConfig,
+ MilvusIndexConfig,
)