diff --git a/.devcontainer/README.md b/.devcontainer/README.md index dd088aab1..be692c14d 100644 --- a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -1,39 +1,34 @@ -# Dev container +# Dev Container -This project includes a [dev container](https://containers.dev/), which lets you use a container as a full-featured dev environment. +This project includes a [Dev Container](https://containers.dev/), offering you a comprehensive and fully-featured development environment within a container. By leveraging the Dev Container configuration in this folder, you can seamlessly build and initiate MetaGPT locally. For detailed information, please refer to the main README in the home directory. -You can use the dev container configuration in this folder to build and start running MetaGPT locally! For more, refer to the main README under the home directory. -You can use it in [GitHub Codespaces](https://github.com/features/codespaces) or the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). +You can utilize this Dev Container in [GitHub Codespaces](https://github.com/features/codespaces) or with the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). ## GitHub Codespaces -Open in GitHub Codespaces +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/geekan/MetaGPT) -You may use the button above to open this repo in a Codespace +Click the button above to open this repository in a Codespace. For additional information, refer to the [GitHub documentation on creating a Codespace](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). 
-For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). - ## VS Code Dev Containers -Open in Dev Containers +[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/geekan/MetaGPT) -Note: If you click this link you will open the main repo and not your local cloned repo, you can use this link and replace with your username and cloned repo name: -https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/geekan/MetaGPT +Note: Clicking the link above opens the main repository. To open your local cloned repository, replace the URL with your username and cloned repository's name: `https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com//` +If you have VS Code and Docker installed, use the button above to get started. This will prompt VS Code to install the Dev Containers extension if it's not already installed, clone the source code into a container volume, and set up a dev container for you. -If you already have VS Code and Docker installed, you can use the button above to get started. This will cause VS Code to automatically install the Dev Containers extension if needed, clone the source code into a container volume, and spin up a dev container for use. +Alternatively, follow these steps to open this repository in a container using the VS Code Dev Containers extension: -You can also follow these steps to open this repo in a container using the VS Code Dev Containers extension: +1. 
For first-time users of a development container, ensure your system meets the prerequisites (e.g., Docker installation) as outlined in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). -1. If this is your first time using a development container, please ensure your system meets the pre-reqs (i.e. have Docker installed) in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). - -2. Open a locally cloned copy of the code: - - - Fork and Clone this repository to your local filesystem. +2. To open a locally cloned copy of the code: + - Fork and clone this repository to your local file system. - Press F1 and select the **Dev Containers: Open Folder in Container...** command. - - Select the cloned copy of this folder, wait for the container to start, and try things out! + - Choose the cloned folder, wait for the container to initialize, and start exploring! -You can learn more in the [Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). +Learn more in the [VS Code Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). -## Tips and tricks +## Tips and Tricks -* If you are working with the same repository folder in a container and Windows, you'll want consistent line endings (otherwise you may see hundreds of changes in the SCM view). The `.gitattributes` file in the root of this repo will disable line ending conversion and should prevent this. See [tips and tricks](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files) for more info. -* If you'd like to review the contents of the image used in this dev container, you can check it out in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repo. 
+* When working with the same repository folder in both a container and on Windows, it's crucial to have consistent line endings to avoid numerous changes in the SCM view. The `.gitattributes` file in the root of this repository disables line ending conversion, helping to prevent this issue. For more information, see [resolving git line ending issues in containers](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files). + +* If you're curious about the contents of the image used in this Dev Container, you can review it in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repository. diff --git a/.devcontainer/postCreateCommand.sh b/.devcontainer/postCreateCommand.sh index 46788e306..3901193cd 100644 --- a/.devcontainer/postCreateCommand.sh +++ b/.devcontainer/postCreateCommand.sh @@ -4,4 +4,4 @@ sudo npm install -g @mermaid-js/mermaid-cli # Step 2: Ensure that Python 3.9+ is installed on your system. You can check this by using: python --version -pip install -e. \ No newline at end of file +pip install -e . 
\ No newline at end of file diff --git a/.dockerignore b/.dockerignore index 2968dd34d..8c09eaf73 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,7 +1,6 @@ workspace tmp build -workspace dist data geckodriver.log diff --git a/.gitattributes b/.gitattributes index 32555a806..7f1424434 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,29 @@ +# HTML code is incorrectly calculated into statistics, so ignore them *.html linguist-detectable=false +# Auto detect text files and perform LF normalization +* text=auto eol=lf + +# Ensure shell scripts use LF (Linux style) line endings on Windows +*.sh text eol=lf + +# Treat specific binary files as binary and prevent line ending conversion +*.png binary +*.jpg binary +*.gif binary +*.ico binary + +# Preserve original line endings for specific document files +*.doc text eol=crlf +*.docx text eol=crlf +*.pdf binary + +# Ensure source code and script files use LF line endings +*.py text eol=lf +*.js text eol=lf +*.html text eol=lf +*.css text eol=lf + +# Specify custom diff driver for specific file types +*.md diff=markdown +*.json diff=json diff --git a/.github/ISSUE_TEMPLATE/config.yaml b/.github/ISSUE_TEMPLATE/config.yaml new file mode 100644 index 000000000..622f76f1a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yaml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: "📑 Read online docs" + url: https://docs.deepwisdom.ai/ + about: Find the tutorials, use cases and blogs from the doc site. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/request_new_features.md b/.github/ISSUE_TEMPLATE/request_new_features.md new file mode 100644 index 000000000..c725cf6d2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/request_new_features.md @@ -0,0 +1,14 @@ +--- +name: "🤔 Request new features" +about: There are some ideas or demands want to discuss with the official and hope to be implemented in the future. 
+title: '' +labels: kind/features +assignees: '' +--- + +**Feature description** + + +**Your Feature** + + diff --git a/.github/ISSUE_TEMPLATE/show_me_the_bug.md b/.github/ISSUE_TEMPLATE/show_me_the_bug.md new file mode 100644 index 000000000..504a2bd12 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/show_me_the_bug.md @@ -0,0 +1,29 @@ +--- +name: "🪲 Show me the Bug" +about: Something happened when I use MetaGPT, I want to report it and hope to get help from the official and community. +title: '' +labels: kind/bug +assignees: '' +--- + +**Bug description** + + +**Bug solved method** + + + +**Environment information** + + +- LLM type and model name: +- System version: +- Python version: + + + +- packages version: +- installation method: + +**Screenshots or logs** + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..f5b280994 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,19 @@ + +**Features** + + + +- xx +- yy + +**Feature Docs** + + +**Influence** + + +**Result** + + +**Other** + \ No newline at end of file diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 000000000..ed4bbb144 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,30 @@ +name: Pre-commit checks + +on: + pull_request: + branches: + - '**' + push: + branches: + - '**' + +jobs: + pre-commit-check: + runs-on: ubuntu-latest + steps: + - name: Checkout Source Code + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: '3.9.17' + + - name: Install pre-commit + run: pip install pre-commit + + - name: Initialize pre-commit + run: pre-commit install + + - name: Run pre-commit hooks + run: pre-commit run --all-files \ No newline at end of file diff --git a/.gitignore b/.gitignore index 89e00e59a..93e24ba48 100644 --- a/.gitignore +++ b/.gitignore @@ -144,25 +144,19 @@ cython_debug/ allure-report 
allure-results -# idea +# idea / vscode / macos .idea .DS_Store .vscode -log.txt -docs/scripts/set_env.sh key.yaml -output.json data -data/output_add.json data.ms examples/nb/ .chroma *~$* workspace/* -*.mmd tmp -output.wav metagpt/roles/idea_agent.py .aider* *.bak diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1892a709..338f832ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_stages: [ commit ] # Install # 1. pip install pre-commit -# 2. pre-commit install(the first time you download the repo, it will be cached for future use) +# 2. pre-commit install repos: - repo: https://github.com/pycqa/isort rev: 5.11.5 diff --git a/Dockerfile b/Dockerfile index c6e22989b..9eeacbccb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ COPY . /app/metagpt WORKDIR /app/metagpt RUN mkdir workspace &&\ pip install --no-cache-dir -r requirements.txt &&\ - pip install -e. + pip install -e . # Running with an infinite loop using the tail command CMD ["sh", "-c", "tail -f /dev/null"] diff --git a/LICENSE b/LICENSE index 5b0c000cd..67460e101 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License -Copyright (c) Chenglin Wu +Copyright (c) 2023 Chenglin Wu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index b0faf85c7..a03c1eabf 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ # MetaGPT: The Multi-Agent Framework

Software Company Multi-Role Schematic (Gradually Implementing)

## News -- Dec 15: v0.5.0 is released! We introduce **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or exisiting human codebase. We also launch a whole collection of important features, including multilingual support (experimental), multiple programming languages support (experimental), incremental development (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! +- Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduce **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or existing codebase. We also launch a whole collection of important features, including **multilingual support** (experimental), multiple **programming languages support** (experimental), **incremental development** (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! ## Install @@ -50,19 +50,23 @@ # conda activate metagpt # Step 2: Clone the repository to your local machine for latest version, and install it. git clone https://github.com/geekan/MetaGPT.git cd MetaGPT -pip3 install -e. # or pip3 install metagpt # for stable version +pip3 install -e . # or pip3 install metagpt # for stable version -# Step 3: run metagpt cli -# setup your OPENAI_API_KEY in key.yaml copy from config.yaml -metagpt "Write a cli snake game" +# Step 3: set up your OPENAI_API_KEY, or make sure it exists in the env +mkdir ~/.metagpt +cp config/config.yaml ~/.metagpt/key.yaml +vim ~/.metagpt/key.yaml -# Step 4 [Optional]: If you want to save the artifacts like diagrams such as quadrant chart, system designs, sequence flow in the workspace, you can execute the step before Step 3. By default, the framework is compatible, and the entire process can be run completely without executing this step. 
+# Step 4: run metagpt cli +metagpt "Create a 2048 game in python" + +# Step 5 [Optional]: If you want to save the artifacts like diagrams such as quadrant chart, system designs, sequence flow in the workspace, you can execute the step before Step 3. By default, the framework is compatible, and the entire process can be run completely without executing this step. # If executing, ensure that NPM is installed on your system. Then install mermaid-js. (If you don't have npm in your computer, please go to the Node.js official website to install Node.js https://nodejs.org/ and then you will have npm tool in your computer.) npm --version sudo npm install -g @mermaid-js/mermaid-cli ``` -detail installation please refer to [cli_install](https://docs.deepwisdom.ai/guide/get_started/installation.html#install-stable-version) +For detailed installation, please refer to [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version) ### Docker installation > Note: In the Windows, you need to replace "/opt/metagpt" with a directory that Docker has permission to create, such as "D:\Users\x\metagpt" @@ -83,7 +87,7 @@ # Step 2: Run metagpt demo with container metagpt "Write a cli snake game" ``` -detail installation please refer to [docker_install](https://docs.deepwisdom.ai/guide/get_started/installation.html#install-with-docker) +For detailed installation, please refer to [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker) ### QuickStart & Demo Video - Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT) @@ -94,19 +98,19 @@ ### QuickStart & Demo Video ## Tutorial -- 🗒 [Online Document](https://docs.deepwisdom.ai/) -- 💻 [Usage](https://docs.deepwisdom.ai/guide/get_started/quickstart.html) -- 🔎 [What can MetaGPT do?](https://docs.deepwisdom.ai/guide/get_started/introduction.html) +- 🗒 [Online Document](https://docs.deepwisdom.ai/main/en/) +- 💻 
[Usage](https://docs.deepwisdom.ai/main/en/guide/get_started/quickstart.html) +- 🔎 [What can MetaGPT do?](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html) - 🛠 How to build your own agents? - - [MetaGPT Usage & Development Guide | Agent 101](https://docs.deepwisdom.ai/guide/tutorials/agent_101.html) - - [MetaGPT Usage & Development Guide | MultiAgent 101](https://docs.deepwisdom.ai/guide/tutorials/multi_agent_101.html) + - [MetaGPT Usage & Development Guide | Agent 101](https://docs.deepwisdom.ai/main/en/guide/tutorials/agent_101.html) + - [MetaGPT Usage & Development Guide | MultiAgent 101](https://docs.deepwisdom.ai/main/en/guide/tutorials/multi_agent_101.html) - 🧑‍💻 Contribution - [Develop Roadmap](docs/ROADMAP.md) - 🔖 Use Cases - - [Debate](https://docs.deepwisdom.ai/guide/use_cases/multi_agent/debate.html) - - [Researcher](https://docs.deepwisdom.ai/guide/use_cases/agent/researcher.html) - - [Recepit Assistant](https://docs.deepwisdom.ai/guide/use_cases/agent/receipt_assistant.html) -- ❓ [FAQs](https://docs.deepwisdom.ai/guide/faq.html) + - [Debate](https://docs.deepwisdom.ai/main/en/guide/use_cases/multi_agent/debate.html) + - [Researcher](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/researcher.html) + - [Receipt Assistant](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/receipt_assistant.html) +- ❓ [FAQs](https://docs.deepwisdom.ai/main/en/guide/faq.html) ## Support diff --git a/config/config.yaml b/config/config.yaml index d38240ae6..09f2895d1 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -5,10 +5,10 @@ # WORKSPACE_PATH: "Path for placing output files" #### if OpenAI -## The official OPENAI_API_BASE is https://api.openai.com/v1 -## If the official OPENAI_API_BASE is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward). -## Or, you can configure OPENAI_PROXY to access official OPENAI_API_BASE. 
-OPENAI_API_BASE: "https://api.openai.com/v1" +## The official OPENAI_BASE_URL is https://api.openai.com/v1 +## If the official OPENAI_BASE_URL is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward). +## Or, you can configure OPENAI_PROXY to access official OPENAI_BASE_URL. +OPENAI_BASE_URL: "https://api.openai.com/v1" #OPENAI_PROXY: "http://127.0.0.1:8118" #OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model OPENAI_API_MODEL: "gpt-4-1106-preview" @@ -24,20 +24,22 @@ LLM_TYPE: OpenAI # Except for these three major models – OpenAI, MetaGPT LLM, #SPARK_URL : "ws://spark-api.xf-yun.com/v2.1/chat" #### if Anthropic -#Anthropic_API_KEY: "YOUR_API_KEY" +#ANTHROPIC_API_KEY: "YOUR_API_KEY" #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb -#### You can use ENGINE or DEPLOYMENT mode #OPENAI_API_TYPE: "azure" -#OPENAI_API_BASE: "YOUR_AZURE_ENDPOINT" +#OPENAI_BASE_URL: "YOUR_AZURE_ENDPOINT" #OPENAI_API_KEY: "YOUR_AZURE_API_KEY" #OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION" #DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME" -#DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID" #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/docs/.pylintrc b/docs/.pylintrc new file mode 100644 index 000000000..9e8488bc7 --- /dev/null +++ b/docs/.pylintrc @@ -0,0 +1,639 @@ +[MAIN] + +# Analyse import fallback blocks. 
This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist=pydantic + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. 
+ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +#ignore-patterns=^\.# +ignore-patterns=(.)*_test\.py,test_(.)*\.py + + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=120 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. 
The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. 
If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + v, + e, + d, + m, + df, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. 
+module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. 
+valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. 
+single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. 
+confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + expression-not-assigned, + pointless-statement + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. 
When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work.. 
+spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. 
In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). 
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/docs/FAQ-EN.md b/docs/FAQ-EN.md index af6868509..d4a9f6097 100644 --- a/docs/FAQ-EN.md +++ b/docs/FAQ-EN.md @@ -83,10 +83,10 @@ 1. PRD stuck / unable to access/ connection interrupted - 1. The official OPENAI_API_BASE address is `https://api.openai.com/v1` - 1. If the official OPENAI_API_BASE address is inaccessible in your environment (this can be verified with curl), it's recommended to configure using the reverse proxy OPENAI_API_BASE provided by libraries such as openai-forward. For instance, `OPENAI_API_BASE: "``https://api.openai-forward.com/v1``"` - 1. If the official OPENAI_API_BASE address is inaccessible in your environment (again, verifiable via curl), another option is to configure the OPENAI_PROXY parameter. This way, you can access the official OPENAI_API_BASE via a local proxy. If you don't need to access via a proxy, please do not enable this configuration; if accessing through a proxy is required, modify it to the correct proxy address. Note that when OPENAI_PROXY is enabled, don't set OPENAI_API_BASE. - 1. Note: OpenAI's default API design ends with a v1. An example of the correct configuration is: `OPENAI_API_BASE: "``https://api.openai.com/v1``"` + 1. The official OPENAI_BASE_URL address is `https://api.openai.com/v1` + 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (this can be verified with curl), it's recommended to configure using the reverse proxy OPENAI_BASE_URL provided by libraries such as openai-forward. 
For instance, `OPENAI_BASE_URL: "``https://api.openai-forward.com/v1``"` + 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (again, verifiable via curl), another option is to configure the OPENAI_PROXY parameter. This way, you can access the official OPENAI_BASE_URL via a local proxy. If you don't need to access via a proxy, please do not enable this configuration; if accessing through a proxy is required, modify it to the correct proxy address. Note that when OPENAI_PROXY is enabled, don't set OPENAI_BASE_URL. + 1. Note: OpenAI's default API design ends with a v1. An example of the correct configuration is: `OPENAI_BASE_URL: "``https://api.openai.com/v1``"` 1. Absolutely! How can I assist you today? diff --git a/docs/README_CN.md b/docs/README_CN.md index dd65c2a25..2855b5500 100644 --- a/docs/README_CN.md +++ b/docs/README_CN.md @@ -78,7 +78,7 @@ # 步骤2: 使用容器运行metagpt演示 metagpt "Write a cli snake game" ``` -详细的安装请安装 [docker_install](https://docs.deepwisdom.ai/zhcn/guide/get_started/installation.html#%E4%BD%BF%E7%94%A8docker%E5%AE%89%E8%A3%85) +详细的安装请安装 [docker_install](https://docs.deepwisdom.ai/main/zh/guide/get_started/installation.html#%E4%BD%BF%E7%94%A8docker%E5%AE%89%E8%A3%85) ### 快速开始的演示视频 - 在 [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT) 上进行体验 @@ -88,19 +88,19 @@ ### 快速开始的演示视频 https://github.com/geekan/MetaGPT/assets/34952977/34345016-5d13-489d-b9f9-b82ace413419 ## 教程 -- 🗒 [在线文档](https://docs.deepwisdom.ai/zhcn/) -- 💻 [如何使用](https://docs.deepwisdom.ai/zhcn/guide/get_started/quickstart.html) -- 🔎 [MetaGPT的能力及应用场景](https://docs.deepwisdom.ai/zhcn/guide/get_started/introduction.html) +- 🗒 [在线文档](https://docs.deepwisdom.ai/main/zh/) +- 💻 [如何使用](https://docs.deepwisdom.ai/main/zh/guide/get_started/quickstart.html) +- 🔎 [MetaGPT的能力及应用场景](https://docs.deepwisdom.ai/main/zh/guide/get_started/introduction.html) - 🛠 如何构建你自己的智能体? 
- - [MetaGPT的使用和开发教程 | 智能体入门](https://docs.deepwisdom.ai/zhcn/guide/tutorials/agent_101.html) - - [MetaGPT的使用和开发教程 | 多智能体入门](https://docs.deepwisdom.ai/zhcn/guide/tutorials/multi_agent_101.html) + - [MetaGPT的使用和开发教程 | 智能体入门](https://docs.deepwisdom.ai/main/zh/guide/tutorials/agent_101.html) + - [MetaGPT的使用和开发教程 | 多智能体入门](https://docs.deepwisdom.ai/main/zh/guide/tutorials/multi_agent_101.html) - 🧑‍💻 贡献 - [开发路线图](ROADMAP.md) - 🔖 示例 - - [辩论](https://docs.deepwisdom.ai/zhcn/guide/use_cases/multi_agent/debate.html) - - [调研员](https://docs.deepwisdom.ai/zhcn/guide/use_cases/agent/researcher.html) - - [票据助手](https://docs.deepwisdom.ai/zhcn/guide/use_cases/agent/receipt_assistant.html) -- ❓ [常见问题解答](https://docs.deepwisdom.ai/zhcn/guide/faq.html) + - [辩论](https://docs.deepwisdom.ai/main/zh/guide/use_cases/multi_agent/debate.html) + - [调研员](https://docs.deepwisdom.ai/main/zh/guide/use_cases/agent/researcher.html) + - [票据助手](https://docs.deepwisdom.ai/main/zh/guide/use_cases/agent/receipt_assistant.html) +- ❓ [常见问题解答](https://docs.deepwisdom.ai/main/zh/guide/faq.html) ## 支持 diff --git a/docs/README_JA.md b/docs/README_JA.md index 05f718635..8b2bf1fae 100644 --- a/docs/README_JA.md +++ b/docs/README_JA.md @@ -219,7 +219,7 @@ # 設定ファイルをコピーし、必要な修正を加える。 | 変数名 | config/key.yaml | env | | --------------------------------------- | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # 自分のキーに置き換える | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_API_BASE # オプション | OPENAI_API_BASE: "https:///v1" | export OPENAI_API_BASE="https:///v1" | +| OPENAI_BASE_URL # オプション | OPENAI_BASE_URL: "https:///v1" | export OPENAI_BASE_URL="https:///v1" | ## チュートリアル: スタートアップの開始 diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index 3cb03f374..d3f7ea408 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -21,7 +21,7 @@ ### Tasks 3. 
~~Support human confirmation and modification during the process~~ (v0.3.0) New: Support human confirmation and modification with fewer constrainsts and a more user-friendly interface 4. Support process caching: Consider carefully whether to add server caching mechanism 5. ~~Resolve occasional failure to follow instruction under current prompts, causing code parsing errors, through stricter system prompts~~ (v0.4.0, with function call) - 6. Write documentation, describing the current features and usage at all levels (ongoing, continuously adding contents to [documentation site](https://docs.deepwisdom.ai/guide/get_started/introduction.html)) + 6. Write documentation, describing the current features and usage at all levels (ongoing, continuously adding contents to [documentation site](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html)) 7. ~~Support Docker~~ 2. Features 1. Support a more standard and stable parser (need to analyze the format that the current LLM is better at) diff --git a/docs/tutorial/usage.md b/docs/tutorial/usage.md index fbe4a8311..a08d92a22 100644 --- a/docs/tutorial/usage.md +++ b/docs/tutorial/usage.md @@ -13,7 +13,7 @@ # Copy the configuration file and make the necessary modifications. | Variable Name | config/key.yaml | env | | ------------------------------------------ | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # Replace with your own key | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." 
| -| OPENAI_API_BASE # Optional | OPENAI_API_BASE: "https:///v1" | export OPENAI_API_BASE="https:///v1" | +| OPENAI_BASE_URL # Optional | OPENAI_BASE_URL: "https:///v1" | export OPENAI_BASE_URL="https:///v1" | ### Initiating a startup diff --git a/docs/tutorial/usage_cn.md b/docs/tutorial/usage_cn.md index 1ef50d633..76a5d6b1b 100644 --- a/docs/tutorial/usage_cn.md +++ b/docs/tutorial/usage_cn.md @@ -13,7 +13,7 @@ # 复制配置文件并进行必要的修改 | 变量名 | config/key.yaml | env | | ----------------------------------- | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # 用您自己的密钥替换 | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_API_BASE # 可选 | OPENAI_API_BASE: "https:///v1" | export OPENAI_API_BASE="https:///v1" | +| OPENAI_BASE_URL # 可选 | OPENAI_BASE_URL: "https:///v1" | export OPENAI_BASE_URL="https:///v1" | ### 示例:启动一个创业公司 diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 05417d24a..26af8a287 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -12,9 +12,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -with open(METAGPT_ROOT / "examples/build_customized_agent.py", "r") as f: - # use official example script to guide AgentCreator - MULTI_ACTION_AGENT_CODE_EXAMPLE = f.read() +EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" +MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() class CreateAgent(Action): @@ -50,8 +49,8 @@ class CreateAgent(Action): match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" CONFIG.workspace_path.mkdir(parents=True, exist_ok=True) - with open(CONFIG.workspace_path / "agent_created_agent.py", "w") as f: - f.write(code_text) + new_file = CONFIG.workspace_path / "agent_created_agent.py" + new_file.write_text(code_text) return code_text diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 
677098399..76be1cc90 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -7,23 +7,20 @@ """ import asyncio -from metagpt.llm import LLM, Claude +from metagpt.llm import LLM from metagpt.logs import logger async def main(): llm = LLM() - claude = Claude() - logger.info(await claude.aask("你好,请进行自我介绍")) logger.info(await llm.aask("hello world")) logger.info(await llm.aask_batch(["hi", "write python hello world."])) hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}] logger.info(await llm.acompletion(hello_msg)) - logger.info(await llm.acompletion_batch([hello_msg])) - logger.info(await llm.acompletion_batch_text([hello_msg])) - logger.info(await llm.acompletion_text(hello_msg)) + + # streaming mode, much slower await llm.acompletion_text(hello_msg, stream=True) diff --git a/examples/search_kb.py b/examples/search_kb.py index 85d99854e..0afd7ad15 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -2,11 +2,10 @@ # -*- coding: utf-8 -*- """ @File : search_kb.py -@Modified By: mashenquan, 2023-8-9, fix-bug: cannot find metagpt module. +@Modified By: mashenquan, 2023-12-22. Delete useless codes. 
""" import asyncio -from metagpt.actions import Action from metagpt.const import DATA_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger @@ -30,10 +29,9 @@ from metagpt.schema import Message async def search(): store = FaissStore(DATA_PATH / "example.json") role = Sales(profile="Sales", store=store) - role._watch({Action}) queries = [ - Message("Which facial cleanser is good for oily skin?", cause_by=Action), - Message("Is L'Oreal good to use?", cause_by=Action), + Message(content="Which facial cleanser is good for oily skin?"), + Message(content="Is L'Oreal good to use?"), ] for query in queries: logger.info(f"User: {query}") diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 97db1624a..adb5665cb 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -@Modified By: mashenquan, 2023-8-9, fix-bug: cannot find metagpt module. """ import asyncio @@ -10,12 +9,13 @@ from metagpt.tools import SearchEngineType async def main(): + question = "What are the most interesting human facts?" 
# Serper API - # await Searcher(engine = SearchEngineType.SERPER_GOOGLE).run(["What are some good sun protection products?","What are some of the best beaches?"]) + # await Searcher(engine=SearchEngineType.SERPER_GOOGLE).run(question) # SerpAPI - # await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run("What are the best ski brands for skiers?") + # await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run(question) # Google API - await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run("What are the most interesting human facts?") + await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run(question) if __name__ == "__main__": diff --git a/examples/write_tutorial.py b/examples/write_tutorial.py index 0dba3cdb7..734afccc0 100644 --- a/examples/write_tutorial.py +++ b/examples/write_tutorial.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 # _*_ coding: utf-8 _*_ + """ @Time : 2023/9/4 21:40:57 @Author : Stitch-z @File : tutorial_assistant.py """ + import asyncio from metagpt.roles.tutorial_assistant import TutorialAssistant diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index 79ff94b3e..c34c72ed2 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -13,7 +13,6 @@ from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.debug_error import DebugError from metagpt.actions.design_api import WriteDesign from metagpt.actions.design_api_review import DesignReview -from metagpt.actions.design_filenames import DesignFilenames from metagpt.actions.project_management import AssignTasks, WriteTasks from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch from metagpt.actions.run_code import RunCode @@ -33,7 +32,6 @@ class ActionType(Enum): WRITE_PRD_REVIEW = WritePRDReview WRITE_DESIGN = WriteDesign DESIGN_REVIEW = DesignReview - DESIGN_FILENAMES = DesignFilenames WRTIE_CODE = WriteCode WRITE_CODE_REVIEW = WriteCodeReview WRITE_TEST = WriteTest diff --git 
a/metagpt/actions/action.py b/metagpt/actions/action.py index 147975f18..cd2b5148f 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -4,49 +4,60 @@ @Time : 2023/5/11 14:43 @Author : alexanderwu @File : action.py -@Modified By: mashenquan, 2023/8/20. Add function return annotations. -@Modified By: mashenquan, 2023/9/8. Replace LLM with LLMFactory """ from __future__ import annotations -from abc import ABC -from typing import Optional +from typing import Any, Optional, Union -from tenacity import retry, stop_after_attempt, wait_random_exponential +from pydantic import BaseModel, Field -from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM -from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess -from metagpt.utils.common import OutputParser, general_after_log +from metagpt.schema import ( + CodeSummarizeContext, + CodingContext, + RunCodeContext, + TestingContext, +) + +action_subclass_registry = {} -class Action(ABC): - def __init__(self, name: str = "", context=None, llm: BaseGPTAPI = None): - self.name: str = name - if llm is None: - llm = LLM() - self.llm = llm - self.context = context - self.prefix = "" # aask*时会加上prefix,作为system_message - self.profile = "" # FIXME: USELESS - self.desc = "" # for skill manager - self.nodes = ... 
+class Action(BaseModel): + name: str = "" + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = "" + prefix = "" # aask*时会加上prefix,作为system_message + desc = "" # for skill manager + # node: ActionNode = Field(default_factory=ActionNode, exclude=True) - # Output, useless - # self.content = "" - # self.instruct_content = None - # self.env = None + # builtin variables + builtin_class_name: str = "" - # def set_env(self, env): - # self.env = env + class Config: + arbitrary_types_allowed = True - def set_prefix(self, prefix, profile): + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(Action, self).dict(*args, **kwargs) + if "llm" in obj_dict: + obj_dict.pop("llm") + return obj_dict + + def set_prefix(self, prefix): """Set prefix for later usage""" self.prefix = prefix - self.profile = profile return self def __str__(self): @@ -62,33 +73,6 @@ class Action(ABC): system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - @retry( - wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(6), - after=general_after_log(logger), - ) - async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: - content = await self.llm.aask(prompt, system_msgs) - logger.debug(f"llm raw output:\n{content}") - output_class = 
ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") - - else: # using markdown parser - parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(f"parsed_data:\n{parsed_data}") - instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) - - async def run(self, *args, **kwargs) -> str | ActionOutput | None: + async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 4ed8bf22e..8a0aaf146 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -4,24 +4,26 @@ @Time : 2023/12/11 18:45 @Author : alexanderwu @File : action_node.py + +NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process, + we can use typing to extract the type of the node, but we cannot use built-in list to extract. """ import json -import re -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential -from metagpt.actions import ActionOutput from metagpt.llm import BaseGPTAPI from metagpt.logs import logger -from metagpt.utils.common import OutputParser -from metagpt.utils.custom_decoder import CustomDecoder +from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.utils.common import OutputParser, general_after_log + +TAG = "CONTENT" + +LANGUAGE_CONSTRAINT = "Language: Please use the same language as the user input." 
+FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." -CONSTRAINT = """ -- Language: Please use the same language as the user input. -- Format: output wrapped inside [CONTENT][/CONTENT] as format example, nothing else. -""" SIMPLE_TEMPLATE = """ ## context @@ -32,21 +34,21 @@ SIMPLE_TEMPLATE = """ ## format example {example} -## nodes: ": # " +## nodes: ": # " {instruction} ## constraint {constraint} ## action -Based on the 'context' content, fill in the {node_name} using the 'format example' format above." +Follow instructions of nodes, generate output and make sure it follows the format example. """ -def dict_to_markdown(d, prefix="###", postfix="\n"): +def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): markdown_str = "" for key, value in d.items(): - markdown_str += f"{prefix} {key}: {value}{postfix}" + markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" return markdown_str @@ -76,7 +78,7 @@ class ActionNode: key: str, expected_type: Type, instruction: str, - example: str, + example: Any, content: str = "", children: dict[str, "ActionNode"] = None, ): @@ -111,22 +113,22 @@ class ActionNode: obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Type]: + def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" return {k: (v.expected_type, ...) 
for k, v in self.children.items()} - def get_self_mapping(self) -> Dict[str, Type]: + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Type]: + def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): return self.get_children_mapping() return self.get_self_mapping() @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): + def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" new_class = create_model(class_name, **mapping) @@ -148,29 +150,6 @@ class ActionNode: new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - @classmethod - def create_model_class_v2(cls, class_name: str, mapping: Dict[str, Type]): - """基于pydantic v2的模型动态生成,用来检验结果类型正确性,待验证""" - new_class = create_model(class_name, **mapping) - - @model_validator(mode="before") - def check_missing_fields(data): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(data.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return data - - @field_validator("*") - def check_name(v: Any, field: str) -> Any: - if field not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field}") - return v - - new_class.__model_validator_check_missing_fields = classmethod(check_missing_fields) - new_class.__field_validator_check_name = classmethod(check_name) - return new_class - def create_children_class(self): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" @@ -197,46 +176,46 @@ class ActionNode: return node_dict # 遍历子节点并递归调用 to_dict 方法 - for child_key, child_node in self.children.items(): + for _, child_node in self.children.items(): 
node_dict.update(child_node.to_dict(format_func)) return node_dict - def compile_to(self, i: Dict, to) -> str: - if to == "json": + def compile_to(self, i: Dict, schema, kv_sep) -> str: + if schema == "json": return json.dumps(i, indent=4) - elif to == "markdown": - return dict_to_markdown(i) + elif schema == "markdown": + return dict_to_markdown(i, kv_sep=kv_sep) else: return str(i) - def tagging(self, text, to, tag="") -> str: + def tagging(self, text, schema, tag="") -> str: if not tag: return text - if to == "json": + if schema == "json": return f"[{tag}]\n" + text + f"\n[/{tag}]" - else: + else: # markdown return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, to, mode, tag, format_func) -> str: + def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str: nodes = self.to_dict(format_func=format_func, mode=mode) - text = self.compile_to(nodes, to) - return self.tagging(text, to, tag) + text = self.compile_to(nodes, schema, kv_sep) + return self.tagging(text, schema, tag) - def compile_instruction(self, to="raw", mode="children", tag="") -> str: + def compile_instruction(self, schema="markdown", mode="children", tag="") -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(to, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ") - def compile_example(self, to="raw", mode="children", tag="") -> str: + def compile_example(self, schema="json", mode="children", tag="") -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(to, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n") - def compile(self, context, to="json", mode="children", 
template=SIMPLE_TEMPLATE) -> str: + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: """ mode: all/root/children mode="children": 编译所有子节点为一个统一模板,包括instruction与example @@ -245,48 +224,47 @@ class ActionNode: """ # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", - self.instruction = self.compile_instruction(to="markdown", mode=mode) - self.example = self.compile_example(to=to, tag="CONTENT", mode=mode) - node_name = "nodes" if template != SIMPLE_TEMPLATE else f'"{list(self.children.keys())[0]}" node' + # compile example暂时不支持markdown + self.instruction = self.compile_instruction(schema="markdown", mode=mode) + self.example = self.compile_example(schema=schema, tag=TAG, mode=mode) + # nodes = ", ".join(self.to_dict(mode=mode).keys()) + constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] + constraint = "\n".join(constraints) + prompt = template.format( context=context, example=self.example, instruction=self.instruction, - constraint=CONSTRAINT, - node_name=node_name, + constraint=constraint, ) return prompt - @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6)) + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) async def _aask_v1( self, prompt: str, output_class_name: str, output_data_mapping: dict, system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: + schema="markdown", # compatible to original format + ) -> (str, BaseModel): + """Use ActionOutput to wrap the output of aask""" content = await self.llm.aask(prompt, system_msgs) - logger.debug(content) - output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" - matches = re.findall(pattern, content, re.DOTALL) - - for match in matches: - if match: - content = match - break - - 
parsed_data = CustomDecoder(strict=False).decode(content) + logger.debug(f"llm raw output:\n{content}") + output_class = self.create_model_class(output_class_name, output_data_mapping) + if schema == "json": + parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key=f"[/{TAG}]") else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - logger.debug(parsed_data) + logger.debug(f"parsed_data:\n{parsed_data}") instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) + return content, instruct_content def get(self, key): return self.instruct_content.dict()[key] @@ -302,23 +280,22 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, to, mode): - prompt = self.compile(context=self.context, to=to, mode=mode) + async def simple_fill(self, schema, mode): + prompt = self.compile(context=self.context, schema=schema, mode=mode) mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" - print(prompt) - output = await self._aask_v1(prompt, class_name, mapping, format=to) - self.content = output.content - self.instruct_content = output.instruct_content + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema) + self.content = content + self.instruct_content = scontent return self - async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"): """Fill the node(s) with mode. :param context: Everything we should know when filling node. :param llm: Large Language Model with pre-defined system message. - :param to: json/markdown, determine example and output format. + :param schema: json/markdown, determine example and output format. 
- json: it's easy to open source LLM with json format - markdown: when generating code, markdown is always better :param mode: auto/children/root @@ -334,12 +311,12 @@ class ActionNode: self.set_context(context) if strgy == "simple": - return await self.simple_fill(to, mode) + return await self.simple_fill(schema, mode) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(to, mode) + child = await i.simple_fill(schema, mode) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/action_output.py b/metagpt/actions/action_output.py index 87d1c31ff..6be8dac50 100644 --- a/metagpt/actions/action_output.py +++ b/metagpt/actions/action_output.py @@ -4,40 +4,15 @@ @Time : 2023/7/11 10:03 @Author : chengmaoyu @File : action_output -@Modified By: mashenquan, 2023/8/20. Allow 'instruct_content' to be blank. """ -from typing import Dict, Optional, Type - -from pydantic import BaseModel, create_model, root_validator, validator +from pydantic import BaseModel class ActionOutput: content: str - instruct_content: Optional[BaseModel] = None + instruct_content: BaseModel - def __init__(self, content: str, instruct_content: BaseModel = None): + def __init__(self, content: str, instruct_content: BaseModel): self.content = content self.instruct_content = instruct_content - - @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): - new_class = create_model(class_name, **mapping) - - @validator("*", allow_reuse=True) - def check_name(v, field): - if field.name not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field.name}") - return v - - @root_validator(pre=True, allow_reuse=True) - def check_missing_fields(values): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(values.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - 
return values - - new_class.__validator_check_name = classmethod(check_name) - new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) - return new_class diff --git a/metagpt/actions/analyze_dep_libs.py b/metagpt/actions/analyze_dep_libs.py deleted file mode 100644 index 53d40200a..000000000 --- a/metagpt/actions/analyze_dep_libs.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 12:01 -@Author : alexanderwu -@File : analyze_dep_libs.py -""" - -from metagpt.actions import Action - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. - -For the user's prompt: - ---- -The API is: {prompt} ---- - -We decide the generated files are: {filepaths_string} - -Now that we have a file list, we need to understand the shared dependencies they have. -Please list and briefly describe the shared contents between the files we are generating, including exported variables, -data patterns, id names of all DOM elements that javascript functions will use, message names and function names. -Focus only on the names of shared dependencies, do not add any other explanations. 
-""" - - -class AnalyzeDepLibs(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = "Analyze the runtime dependencies of the program based on the context" - - async def run(self, requirement, filepaths_string): - # prompt = f"Below is the product requirement document (PRD):\n\n{prd}\n\n{PROMPT}" - prompt = PROMPT.format(prompt=requirement, filepaths_string=filepaths_string) - design_filenames = await self._aask(prompt) - return design_filenames diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 39f3bc1bc..9dc6862f9 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -10,11 +10,14 @@ """ import re +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult +from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -47,8 +50,9 @@ Now you should start rewriting the code: class DebugError(Action): - def __init__(self, name="DebugError", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "DebugError" + context: RunCodeContext = Field(default_factory=RunCodeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 5a5f52de7..055365421 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -11,6 +11,9 @@ """ import json from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import 
DESIGN_API_NODE @@ -22,16 +25,13 @@ from metagpt.const import ( SYSTEM_DESIGN_FILE_REPO, SYSTEM_DESIGN_PDF_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.schema import Document, Documents +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository - -# from metagpt.utils.get_template import get_template from metagpt.utils.mermaid import mermaid_to_file -# from typing import List - - NEW_REQ_TEMPLATE = """ ### Legacy Content {old_design} @@ -42,15 +42,16 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, think about the system design, and design the corresponding APIs, " - "data structures, library tables, processes, and paths. Please provide your design, feedback " - "clearly and in detail." - ) + name: str = "" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + desc: str = ( + "Based on the PRD, think about the system design, and design the corresponding APIs, " + "data structures, library tables, processes, and paths. Please provide your design, feedback " + "clearly and in detail." + ) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files @@ -80,13 +81,13 @@ class WriteDesign(Action): # leaving room for global optimization in subsequent steps. 
return ActionOutput(content=changed_files.json(), instruct_content=changed_files) - async def _new_system_design(self, context, format=CONFIG.prompt_format): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + async def _new_system_design(self, context, schema=CONFIG.prompt_schema): + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) return node - async def _merge(self, prd_doc, system_design_doc, format=CONFIG.prompt_format): + async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) system_design_doc.content = node.instruct_content.json(ensure_ascii=False) return system_design_doc diff --git a/metagpt/actions/design_api_an.py b/metagpt/actions/design_api_an.py index 0a303cdd5..7d6802381 100644 --- a/metagpt/actions/design_api_an.py +++ b/metagpt/actions/design_api_an.py @@ -5,6 +5,8 @@ @Author : alexanderwu @File : design_api_an.py """ +from typing import List + from metagpt.actions.action_node import ActionNode from metagpt.logs import logger from metagpt.utils.mermaid import MMC1, MMC2 @@ -22,7 +24,7 @@ PROJECT_NAME = ActionNode( FILE_LIST = ActionNode( key="File list", - expected_type=list[str], + expected_type=List[str], instruction="Only need relative paths. 
ALWAYS write a main.py or app.py here", example=["main.py", "game.py"], ) diff --git a/metagpt/actions/design_filenames.py b/metagpt/actions/design_filenames.py deleted file mode 100644 index ffa171d7b..000000000 --- a/metagpt/actions/design_filenames.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 11:50 -@Author : alexanderwu -@File : design_filenames.py -""" -from metagpt.actions import Action -from metagpt.logs import logger - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. -When given their intentions, provide a complete and exhaustive list of file paths needed to write the program for the user. -Only list the file paths you will write and return them as a Python string list. -Do not add any other explanations, just return a Python string list.""" - - -class DesignFilenames(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, consider system design, and carry out the basic design of the corresponding " - "APIs, data structures, and database tables. Please give your design, feedback clearly and in detail." 
- ) - - async def run(self, prd): - prompt = f"The following is the Product Requirement Document (PRD):\n\n{prd}\n\n{PROMPT}" - design_filenames = await self._aask(prompt) - logger.debug(prompt) - logger.debug(design_filenames) - return design_filenames diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/detail_mining.py deleted file mode 100644 index 5afcf52c6..000000000 --- a/metagpt/actions/detail_mining.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/9/12 17:45 -@Author : fisherdeng -@File : detail_mining.py -""" -from metagpt.actions import Action, ActionOutput - -PROMPT_TEMPLATE = """ -##TOPIC -{topic} - -##RECORD -{record} - -##Format example -{format_example} ------ - -Task: Refer to the "##TOPIC" (discussion objectives) and "##RECORD" (discussion records) to further inquire about the details that interest you, within a word limit of 150 words. -Special Note 1: Your intention is solely to ask questions without endorsing or negating any individual's viewpoints. -Special Note 2: This output should only include the topic "##OUTPUT". Do not add, remove, or modify the topic. Begin the output with '##OUTPUT', followed by an immediate line break, and then proceed to provide the content in the specified format as outlined in the "##Format example" section. -Special Note 3: The output should be in the same language as the input. -""" -FORMAT_EXAMPLE = """ - -## - -##OUTPUT -...(Please provide the specific details you would like to inquire about here.) 
- -## - -## -""" -OUTPUT_MAPPING = { - "OUTPUT": (str, ...), -} - - -class DetailMining(Action): - """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" - - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) - - async def run(self, topic, record) -> ActionOutput: - prompt = PROMPT_TEMPLATE.format(topic=topic, record=record, format_example=FORMAT_EXAMPLE) - rsp = await self._aask_v1(prompt, "detail_mining", OUTPUT_MAPPING) - return rsp diff --git a/metagpt/actions/fix_bug.py b/metagpt/actions/fix_bug.py index 6bd550d3d..56b488218 100644 --- a/metagpt/actions/fix_bug.py +++ b/metagpt/actions/fix_bug.py @@ -10,5 +10,7 @@ from metagpt.actions import Action class FixBug(Action): """Fix bug action without any implementation details""" + name: str = "FixBug" + async def run(self, *args, **kwargs): raise NotImplementedError diff --git a/metagpt/actions/generate_questions.py b/metagpt/actions/generate_questions.py new file mode 100644 index 000000000..c38c463bc --- /dev/null +++ b/metagpt/actions/generate_questions.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/12 17:45 +@Author : fisherdeng +@File : generate_questions.py +""" +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="Task: Refer to the context to further inquire about the details that interest you, within a word limit" + " of 150 words. Please provide the specific details you would like to inquire about here", + example=["1. What ...", "2. How ...", "3. 
..."], +) + + +class GenerateQuestions(Action): + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" + + async def run(self, context): + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 8d3445ae4..696dc9a89 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -9,31 +9,41 @@ """ import shutil from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG -from metagpt.const import DEFAULT_WORKSPACE_ROOT, DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) + """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" + + name: str = "PrepareDocuments" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + + def _init_repo(self): + """Initialize the Git environment.""" + path = CONFIG.project_path + if not path: + name = CONFIG.project_name or FileRepository.new_filename() + path = Path(CONFIG.workspace_path) / name + + if path.exists() and not CONFIG.inc: + shutil.rmtree(path) + CONFIG.git_repo = GitRepository(local_path=path, auto_init=True) async def run(self, with_messages, **kwargs): - if not CONFIG.git_repo: - # Create and initialize the workspace folder, initialize the Git environment. 
- project_name = CONFIG.project_name or FileRepository.new_filename() - workdir = CONFIG.project_path - if not workdir and CONFIG.workspace_path: - workdir = Path(CONFIG.workspace_path) / project_name - workdir = Path(workdir or DEFAULT_WORKSPACE_ROOT / project_name) - if not CONFIG.inc and workdir.exists(): - shutil.rmtree(workdir) - CONFIG.git_repo = GitRepository() - CONFIG.git_repo.open(local_path=workdir, auto_init=True) + """Create and initialize the workspace folder, initialize the Git environment.""" + self._init_repo() # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content) diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index b2704616e..7ed42d590 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -6,35 +6,18 @@ @File : prepare_interview.py """ from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode -PROMPT_TEMPLATE = """ -# Context -{context} - -## Format example ---- -Q1: question 1 here -References: - - point 1 - - point 2 - -Q2: question 2 here... ---- - ------ -Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="""Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; Requirement: Provide a list of questions for the interviewer to ask the interviewee, by reading the resume of the interviewee in the context. -Attention: Provide as markdown block as the format above, at least 10 questions. -""" - -# prepare for a interview +Attention: Provide as markdown block as the format above, at least 10 questions.""", + example=["1. What ...", "2. 
How ..."], +) class PrepareInterview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - async def run(self, context): - prompt = PROMPT_TEMPLATE.format(context=context) - question_list = await self._aask_v1(prompt) - return question_list + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 780d87a03..095881e60 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -9,7 +9,11 @@ 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. """ + import json +from typing import Optional + +from pydantic import Field from metagpt.actions import ActionOutput from metagpt.actions.action import Action @@ -21,7 +25,9 @@ from metagpt.const import ( TASK_FILE_REPO, TASK_PDF_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository @@ -35,10 +41,11 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): - def __init__(self, name="CreateTasks", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "CreateTasks" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files @@ -85,16 +92,16 @@ class WriteTasks(Action): await self._save_pdf(task_doc=task_doc) return task_doc - async def _run_new_tasks(self, context, format=CONFIG.prompt_format): - node = await 
PM_NODE.fill(context, self.llm, format) + async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema): + node = await PM_NODE.fill(context, self.llm, schema) # prompt_template, format_example = get_template(templates, format) # prompt = prompt_template.format(context=context, format_example=format_example) # rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format) return node - async def _merge(self, system_design_doc, task_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content) - node = await PM_NODE.fill(context, self.llm, format) + node = await PM_NODE.fill(context, self.llm, schema) task_doc.content = node.instruct_content.json(ensure_ascii=False) return task_doc diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py index 6208c1051..215a67202 100644 --- a/metagpt/actions/project_management_an.py +++ b/metagpt/actions/project_management_an.py @@ -5,26 +5,28 @@ @Author : alexanderwu @File : project_management_an.py """ +from typing import List + from metagpt.actions.action_node import ActionNode from metagpt.logs import logger REQUIRED_PYTHON_PACKAGES = ActionNode( key="Required Python packages", - expected_type=list[str], + expected_type=List[str], instruction="Provide required Python packages in requirements.txt format.", example=["flask==1.1.2", "bcrypt==3.2.0"], ) REQUIRED_OTHER_LANGUAGE_PACKAGES = ActionNode( key="Required Other language third-party packages", - expected_type=list[str], + expected_type=List[str], instruction="List down the required packages for languages other than Python.", example=["No third-party dependencies required"], ) LOGIC_ANALYSIS = ActionNode( key="Logic Analysis", - expected_type=list[list[str]], + expected_type=List[List[str]], instruction="Provide a list of files with the 
classes/methods/functions to be implemented, " "including dependency analysis and imports.", example=[ @@ -35,7 +37,7 @@ LOGIC_ANALYSIS = ActionNode( TASK_LIST = ActionNode( key="Task list", - expected_type=list[str], + expected_type=List[str], instruction="Break down the tasks into a list of filenames, prioritized by dependency order.", example=["game.py", "main.py"], ) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 1b9fd252f..bca9b337d 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -18,10 +18,13 @@ import subprocess from typing import Tuple +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult +from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception PROMPT_TEMPLATE = """ @@ -74,8 +77,9 @@ standard errors: class RunCode(Action): - def __init__(self, name="RunCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "RunCode" + context: RunCodeContext = Field(default_factory=RunCodeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @classmethod @handle_exception diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 5c7577e17..bc1319291 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -4,14 +4,19 @@ @Time : 2023/5/23 17:26 @Author : alexanderwu @File : search_google.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. 
""" +from typing import Optional + import pydantic +from pydantic import Field, root_validator from metagpt.actions import Action -from metagpt.config import CONFIG +from metagpt.config import CONFIG, Config +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message +from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements @@ -55,7 +60,6 @@ SEARCH_AND_SUMMARIZE_PROMPT = """ """ - SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. - The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. @@ -102,16 +106,31 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): - def __init__(self, name="", context=None, llm=None, engine=None, search_func=None): - self.engine = engine or CONFIG.search_engine + name: str = "" + content: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + config: None = Field(default_factory=Config) + engine: Optional[SearchEngineType] = CONFIG.search_engine + search_func: Optional[str] = None + search_engine: SearchEngine = None + result = "" + + @root_validator + def validate_engine_and_run_func(cls, values): + engine = values.get("engine") + search_func = values.get("search_func") + config = Config() + + if engine is None: + engine = config.search_engine try: - self.search_engine = SearchEngine(self.engine, run_func=search_func) + search_engine = SearchEngine(engine=engine, run_func=search_func) except pydantic.ValidationError: - self.search_engine = None + search_engine = None - self.result = "" - super().__init__(name, context, llm) + 
values["search_engine"] = search_engine + return values async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: if self.search_engine is None: @@ -130,8 +149,7 @@ class SearchAndSummarize(Action): system_prompt = [system_text] prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( - # PREFIX = self.prefix, - ROLE=self.profile, + ROLE=self.prefix, CONTEXT=rsp, QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]), QUERY=str(context[-1]), diff --git a/metagpt/actions/skill_action.py b/metagpt/actions/skill_action.py index f629cfcbf..c95a83cbb 100644 --- a/metagpt/actions/skill_action.py +++ b/metagpt/actions/skill_action.py @@ -12,6 +12,7 @@ import ast import importlib import traceback from copy import deepcopy +from typing import Dict, Optional from metagpt.actions import Action, ActionOutput from metagpt.learn.skill_loader import Skill @@ -19,12 +20,10 @@ from metagpt.logs import logger class ArgumentsParingAction(Action): - def __init__(self, last_talk: str, skill: Skill, context=None, llm=None, **kwargs): - super(ArgumentsParingAction, self).__init__(name="", context=context, llm=llm) - self.skill = skill - self.ask = last_talk - self.rsp = None - self.args = None + skill: Skill + ask: str + rsp: Optional[ActionOutput] + args: Optional[Dict] @property def prompt(self): @@ -70,25 +69,23 @@ class ArgumentsParingAction(Action): class SkillAction(Action): - def __init__(self, skill: Skill, args: dict, context=None, llm=None, **kwargs): - super(SkillAction, self).__init__(name="", context=context, llm=llm) - self._skill = skill - self._args = args - self.rsp = None + skill: Skill + args: Dict + rsp: str = "" async def run(self, *args, **kwargs) -> str | ActionOutput | None: """Run action""" options = deepcopy(kwargs) - if self._args: - for k in self._args.keys(): + if self.args: + for k in self.args.keys(): if k in options: options.pop(k) try: - self.rsp = await self.find_and_call_function(self._skill.name, args=self._args, **options) + 
self.rsp = await self.find_and_call_function(self.skill.name, args=self.args, **options) except Exception as e: logger.exception(f"{e}, traceback:{traceback.format_exc()}") self.rsp = f"Error: {e}" - return ActionOutput(content=self.rsp, instruct_content=self._skill.json()) + return ActionOutput(content=self.rsp, instruct_content=self.skill.json()) @staticmethod async def find_and_call_function(function_name, args, **kwargs): diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index f8d8d2b47..0aec15937 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -7,12 +7,15 @@ """ from pathlib import Path +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger +from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository PROMPT_TEMPLATE = """ @@ -89,8 +92,9 @@ flowchart TB class SummarizeCode(Action): - def __init__(self, name="SummarizeCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "SummarizeCode" + context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 5960e2621..4d0690e0f 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -14,8 +14,10 @@ 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. 
""" + import json +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action @@ -27,7 +29,9 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -84,8 +88,9 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): - def __init__(self, name="WriteCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteCode" + context: Document = Field(default_factory=Document) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: @@ -126,7 +131,9 @@ class WriteCode(Action): logger.info(f"Writing {coding_context.filename}..") code = await self.write_code(prompt) if not coding_context.code_doc: - coding_context.code_doc = Document(filename=coding_context.filename, root_path=CONFIG.src_workspace) + # avoid root_path pydantic ValidationError if use WriteCode alone + root_path = CONFIG.src_workspace if CONFIG.src_workspace else "" + coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/actions/write_code_an_draft.py b/metagpt/actions/write_code_an_draft.py new file mode 100644 index 000000000..968c8924b --- /dev/null +++ b/metagpt/actions/write_code_an_draft.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Author : alexanderwu +@File : write_review.py +""" +import asyncio +from typing import List + +from metagpt.actions import Action +from 
metagpt.actions.action_node import ActionNode + +REVIEW = ActionNode( + key="Review", + expected_type=List[str], + instruction="Act as an experienced reviewer and critically assess the given output. Provide specific and" + " constructive feedback, highlighting areas for improvement and suggesting changes.", + example=[ + "The logic in the function `calculate_total` seems flawed. Shouldn't it consider the discount rate as well?", + "The TODO function is not implemented yet? Should we implement it before commit?", + ], +) + +LGTM = ActionNode( + key="LGTM", + expected_type=str, + instruction="LGTM/LBTM. If the code is fully implemented, " + "give a LGTM (Looks Good To Me), otherwise provide a LBTM (Looks Bad To Me).", + example="LBTM", +) + +ACTIONS = ActionNode( + key="Actions", + expected_type=str, + instruction="Based on the code review outcome, suggest actionable steps. This can include code changes, " + "refactoring suggestions, or any follow-up tasks.", + example="""1. Refactor the `process_data` method to improve readability and efficiency. +2. Cover edge cases in the `validate_user` function. +3. Implement a the TODO in the `calculate_total` function. +4. Fix the `handle_events` method to update the game state only if a move is successful. 
+ ```python + def handle_events(self): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + return False + if event.type == pygame.KEYDOWN: + moved = False + if event.key == pygame.K_UP: + moved = self.game.move('UP') + elif event.key == pygame.K_DOWN: + moved = self.game.move('DOWN') + elif event.key == pygame.K_LEFT: + moved = self.game.move('LEFT') + elif event.key == pygame.K_RIGHT: + moved = self.game.move('RIGHT') + if moved: + # Update the game state only if a move was successful + self.render() + return True + ``` +""", +) + +WRITE_DRAFT = ActionNode( + key="WriteDraft", + expected_type=str, + instruction="Could you write draft code for move function in order to implement it?", + example="Draft: ...", +) + + +WRITE_MOVE_FUNCTION = ActionNode( + key="WriteFunction", + expected_type=str, + instruction="write code for the function not implemented.", + example=""" +```Code +... +``` +""", +) + + +REWRITE_CODE = ActionNode( + key="RewriteCode", + expected_type=str, + instruction="""rewrite code based on the Review and Actions""", + example=""" +```python +## example.py +def calculate_total(price, quantity): + total = price * quantity +``` +""", +) + + +CODE_REVIEW_CONTEXT = """ +# System +Role: You are a professional software engineer, and your main task is to review and revise the code. You need to ensure that the code conforms to the google-style standards, is elegantly designed and modularized, easy to read and maintain. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. 
+ +# Context +## System Design +{"Implementation approach": "我们将使用HTML、CSS和JavaScript来实现这个单机的响应式2048游戏。为了确保游戏性能流畅和响应式设计,我们会选择使用Vue.js框架,因为它易于上手且适合构建交互式界面。我们还将使用localStorage来记录玩家的最高分。", "File list": ["index.html", "styles.css", "main.js", "game.js", "storage.js"], "Data structures and interfaces": "classDiagram\ + class Game {\ + -board Array\ + -score Number\ + -bestScore Number\ + +constructor()\ + +startGame()\ + +move(direction: String)\ + +getBoard() Array\ + +getScore() Number\ + +getBestScore() Number\ + +setBestScore(score: Number)\ + }\ + class Storage {\ + +getBestScore() Number\ + +setBestScore(score: Number)\ + }\ + class Main {\ + +init()\ + +bindEvents()\ + }\ + Game --> Storage : uses\ + Main --> Game : uses", "Program call flow": "sequenceDiagram\ + participant M as Main\ + participant G as Game\ + participant S as Storage\ + M->>G: init()\ + G->>S: getBestScore()\ + S-->>G: return bestScore\ + M->>G: bindEvents()\ + M->>G: startGame()\ + loop Game Loop\ + M->>G: move(direction)\ + G->>S: setBestScore(score)\ + S-->>G: return\ + end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} + +## Tasks +{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} + +## Code Files +----- index.html + + + + + + 2048游戏 + + + + +
+

2048

+
+
+
分数
+
{{ score }}
+
+
+
最高分
+
{{ bestScore }}
+
+
+
+
+
+ {{ cell !== 0 ? cell : \'\' }} +
+
+
+ +
+ + + + + + + + +----- styles.css +/* styles.css */ +body, html { + margin: 0; + padding: 0; + font-family: \'Arial\', sans-serif; +} + +#app { + text-align: center; + font-size: 18px; + color: #776e65; +} + +h1 { + color: #776e65; + font-size: 72px; + font-weight: bold; + margin: 20px 0; +} + +.scores-container { + display: flex; + justify-content: center; + margin-bottom: 20px; +} + +.score-container, .best-container { + background: #bbada0; + padding: 10px; + border-radius: 5px; + margin: 0 10px; + min-width: 100px; + text-align: center; +} + +.score-header, .best-header { + color: #eee4da; + font-size: 18px; + margin-bottom: 5px; +} + +.game-container { + max-width: 500px; + margin: 0 auto 20px; + background: #bbada0; + padding: 15px; + border-radius: 10px; + position: relative; +} + +.grid-row { + display: flex; +} + +.grid-cell { + background: #cdc1b4; + width: 100px; + height: 100px; + margin: 5px; + display: flex; + justify-content: center; + align-items: center; + font-size: 35px; + font-weight: bold; + color: #776e65; + border-radius: 3px; +} + +/* Dynamic classes for different number cells */ +.number-cell-2 { + background: #eee4da; +} + +.number-cell-4 { + background: #ede0c8; +} + +.number-cell-8 { + background: #f2b179; + color: #f9f6f2; +} + +.number-cell-16 { + background: #f59563; + color: #f9f6f2; +} + +.number-cell-32 { + background: #f67c5f; + color: #f9f6f2; +} + +.number-cell-64 { + background: #f65e3b; + color: #f9f6f2; +} + +.number-cell-128 { + background: #edcf72; + color: #f9f6f2; +} + +.number-cell-256 { + background: #edcc61; + color: #f9f6f2; +} + +.number-cell-512 { + background: #edc850; + color: #f9f6f2; +} + +.number-cell-1024 { + background: #edc53f; + color: #f9f6f2; +} + +.number-cell-2048 { + background: #edc22e; + color: #f9f6f2; +} + +/* Larger numbers need smaller font sizes */ +.number-cell-1024, .number-cell-2048 { + font-size: 30px; +} + +button { + background-color: #8f7a66; + color: #f9f6f2; + border: none; + 
border-radius: 3px; + padding: 10px 20px; + font-size: 18px; + cursor: pointer; + outline: none; +} + +button:hover { + background-color: #9f8b76; +} + +----- storage.js +## storage.js +class Storage { + // 获取最高分 + getBestScore() { + // 尝试从localStorage中获取最高分,如果不存在则默认为0 + const bestScore = localStorage.getItem(\'bestScore\'); + return bestScore ? Number(bestScore) : 0; + } + + // 设置最高分 + setBestScore(score) { + // 将最高分设置到localStorage中 + localStorage.setItem(\'bestScore\', score.toString()); + } +} + + + +## Code to be Reviewed: game.js +```Code +## game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 
2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SMALLEST_CONTEXT = """ +## Code to be Reviewed: game.js +```Code +// game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 
2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SAMPLE = """ +## Code Review: game.js +1. The code partially implements the requirements. The `Game` class is missing the full implementation of the `move` method, which is crucial for the game\'s functionality. +2. The code logic is not completely correct. The `move` method is not implemented, which means the game cannot process player moves. +3. The existing code follows the "Data structures and interfaces" in terms of class structure but lacks full method implementations. +4. Not all functions are implemented. The `move` method is incomplete and does not handle the logic for moving and merging tiles. +5. All necessary pre-dependencies seem to be imported since the code does not indicate the need for additional imports. +6. The methods from other files (such as `Storage`) are not being used in the provided code snippet, but the class structure suggests that they will be used correctly. + +## Actions +1. Implement the `move` method to handle tile movements and merging. This is a complex task that requires careful consideration of the game\'s rules and logic. 
Here is a simplified version of how one might begin to implement the `move` method: + ```javascript + move(direction) { + // Simplified logic for moving tiles up + if (direction === \'up\') { + for (let col = 0; col < 4; col++) { + let tiles = this.board.map(row => row[col]).filter(val => val !== 0); + let merged = []; + for (let i = 0; i < tiles.length; i++) { + if (tiles[i] === tiles[i + 1]) { + tiles[i] *= 2; + this.score += tiles[i]; + tiles[i + 1] = 0; + merged.push(i); + } + } + tiles = tiles.filter(val => val !== 0); + while (tiles.length < 4) { + tiles.push(0); + } + for (let row = 0; row < 4; row++) { + this.board[row][col] = tiles[row]; + } + } + } + // Additional logic needed for \'down\', \'left\', \'right\' + // ... + this.addRandomTile(); + } + ``` +2. Integrate the `Storage` class methods to handle the best score. This means updating the `startGame` and `setBestScore` methods to use `Storage` for retrieving and setting the best score: + ```javascript + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = new Storage().getBestScore(); // Retrieve the best score from storage + this.addRandomTile(); + this.addRandomTile(); + } + + setBestScore(score) { + if (score > this.bestScore) { + this.bestScore = score; + new Storage().setBestScore(score); // Set the new best score in storage + } + } + ``` + +## Code Review Result +LBTM + +``` +""" + + +WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM, ACTIONS]) +WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_MOVE_FUNCTION]) + + +CR_FOR_MOVE_FUNCTION_BY_3 = """ +The move function implementation provided appears to be well-structured and follows a clear logic for moving and merging tiles in the specified direction. However, there are a few potential improvements that could be made to enhance the code: + +1. 
Encapsulation: The logic for moving and merging tiles could be encapsulated into smaller, reusable functions to improve readability and maintainability. + +2. Magic Numbers: There are some magic numbers (e.g., 4, 3) used in the loops that could be replaced with named constants for improved readability and easier maintenance. + +3. Comments: Adding comments to explain the logic and purpose of each section of the code can improve understanding for future developers who may need to work on or maintain the code. + +4. Error Handling: It's important to consider error handling for unexpected input or edge cases to ensure the function behaves as expected in all scenarios. + +Overall, the code could benefit from refactoring to improve readability, maintainability, and extensibility. If you would like, I can provide a refactored version of the move function that addresses these considerations. +""" + + +class WriteCodeAN(Action): + """Write a code review for the context.""" + + async def run(self, context): + self.llm.system_prompt = "You are an outstanding engineer and can implement any code" + return await WRITE_MOVE_FUNCTION.fill(context=context, llm=self.llm, schema="json") + # return await WRITE_CODE_NODE.fill(context=context, llm=self.llm, schema="markdown") + + +async def main(): + await WriteCodeAN().run(CODE_REVIEW_SMALLEST_CONTEXT) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 365c87063..1eba672a5 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -8,12 +8,15 @@ WriteCode object, rather than passing them in when calling the run function. 
""" +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action from metagpt.config import CONFIG +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -32,7 +35,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc ``` """ - EXAMPLE_AND_INSTRUCTION = """ {format_example} @@ -119,8 +121,9 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): - def __init__(self, name="WriteCodeReview", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteCodeReview" + context: CodingContext = Field(default_factory=CodingContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): @@ -158,7 +161,8 @@ class WriteCodeReview(Action): format_example=format_example, ) logger.info( - f"Code review and rewrite {self.context.code_doc.filename}: {i+1}/{k} | {len(iterative_code)=}, {len(self.context.code_doc.content)=}" + f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, " + f"{len(self.context.code_doc.content)=}" ) result, rewrited_code = await self.write_code_review_and_rewrite( context_prompt, cr_prompt, self.context.code_doc.filename diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index bb0cf8fb9..1223e5486 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -10,10 +10,14 @@ 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. 
""" + from __future__ import annotations import json from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -32,17 +36,14 @@ from metagpt.const import ( PRDS_FILE_REPO, REQUIREMENT_FILENAME, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository - -# from metagpt.utils.get_template import get_template from metagpt.utils.mermaid import mermaid_to_file -# from typing import List - - CONTEXT_TEMPLATE = """ ### Project Name {project_name} @@ -64,15 +65,16 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "" + content: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) - async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: + async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are # related to the PRD. If they are related, rewrite the PRD. 
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME) - if await self._is_bugfix(requirement_doc.content): + if requirement_doc and await self._is_bugfix(requirement_doc.content): await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content) await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="") bug_fix = BugFixContext(filename=BUGFIX_FILENAME) @@ -111,7 +113,7 @@ class WritePRD(Action): # optimization in subsequent steps. return ActionOutput(content=change_files.json(), instruct_content=change_files) - async def _run_new_requirement(self, requirements, format=CONFIG.prompt_format) -> ActionOutput: + async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput: # sas = SearchAndSummarize() # # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) # rsp = "" @@ -121,7 +123,7 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=format) + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema) await self._rename_workspace(node) return node @@ -130,18 +132,20 @@ class WritePRD(Action): node = await WP_IS_RELATIVE_NODE.fill(context, self.llm) return node.get("is_relative") == "YES" - async def _merge(self, new_requirement_doc, prd_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document: if not CONFIG.project_name: CONFIG.project_name = Path(CONFIG.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content) - node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=format) + 
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema) prd_doc.content = node.instruct_content.json(ensure_ascii=False) await self._rename_workspace(node) return prd_doc async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None: if not prd_doc: - prd = await self._run_new_requirement(requirements=[requirement_doc.content], *args, **kwargs) + prd = await self._run_new_requirement( + requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs + ) new_prd_doc = Document( root_path=PRDS_FILE_REPO, filename=FileRepository.new_filename() + ".json", @@ -182,7 +186,7 @@ class WritePRD(Action): return if not CONFIG.project_name: - if isinstance(prd, ActionOutput) or isinstance(prd, ActionNode): + if isinstance(prd, (ActionOutput, ActionNode)): ws_name = prd.instruct_content.dict()["Project Name"] else: ws_name = CodeParser.parse_str(block="Project Name", text=prd) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index d96c0aeac..d58d72f64 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -5,6 +5,7 @@ @Author : alexanderwu @File : write_prd_an.py """ +from typing import List from metagpt.actions.action_node import ActionNode from metagpt.logs import logger @@ -26,8 +27,8 @@ PROGRAMMING_LANGUAGE = ActionNode( ORIGINAL_REQUIREMENTS = ActionNode( key="Original Requirements", expected_type=str, - instruction="Place the polished, complete original requirements here.", - example="The game should have a leaderboard and multiple difficulty levels.", + instruction="Place the original user's requirements here.", + example="Create a 2048 game", ) PROJECT_NAME = ActionNode( @@ -39,26 +40,33 @@ PROJECT_NAME = ActionNode( PRODUCT_GOALS = ActionNode( key="Product Goals", - expected_type=list[str], + expected_type=List[str], instruction="Provide up to three clear, orthogonal product goals.", - example=["Create an engaging user 
experience", "Ensure high performance", "Provide customizable features"], + example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"], ) USER_STORIES = ActionNode( key="User Stories", - expected_type=list[str], - instruction="Provide up to five scenario-based user stories.", + expected_type=List[str], + instruction="Provide up to 3 to 5 scenario-based user stories.", example=[ - "As a user, I want to be able to choose difficulty levels", + "As a player, I want to be able to choose difficulty levels", "As a player, I want to see my score after each game", + "As a player, I want to get restart button when I lose", + "As a player, I want to see beautiful UI that make me feel good", + "As a player, I want to play game via mobile phone", ], ) COMPETITIVE_ANALYSIS = ActionNode( key="Competitive Analysis", - expected_type=list[str], - instruction="Provide analyses for up to seven competitive products.", - example=["Python Snake Game: Simple interface, lacks advanced features"], + expected_type=List[str], + instruction="Provide 5 to 7 competitive products.", + example=[ + "2048 Game A: Simple interface, lacks responsive features", + "play2048.co: Beautiful and responsive UI with my best score shown", + "2048game.com: Responsive UI with my best score shown, but many ads", + ], ) COMPETITIVE_QUADRANT_CHART = ActionNode( @@ -86,14 +94,14 @@ REQUIREMENT_ANALYSIS = ActionNode( key="Requirement Analysis", expected_type=str, instruction="Provide a detailed analysis of the requirements.", - example="The product should be user-friendly.", + example="", ) REQUIREMENT_POOL = ActionNode( key="Requirement Pool", - expected_type=list[list[str]], - instruction="List down the requirements with their priority (P0, P1, P2).", - example=[["P0", "..."], ["P1", "..."]], + expected_type=List[List[str]], + instruction="List down the top-5 requirements with their priority (P0, P1, P2).", + example=[["P0", "The main code ..."], ["P0", "The game 
algorithm ..."]], ) UI_DESIGN_DRAFT = ActionNode( @@ -107,7 +115,7 @@ ANYTHING_UNCLEAR = ActionNode( key="Anything UNCLEAR", expected_type=str, instruction="Mention any aspects of the project that are unclear and try to clarify them.", - example="...", + example="", ) ISSUE_TYPE = ActionNode( diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 5ff9624c5..6ed73b6a2 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -5,20 +5,28 @@ @Author : alexanderwu @File : write_prd_review.py """ + +from typing import Optional + +from pydantic import Field + from metagpt.actions.action import Action +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI class WritePRDReview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.prd = None - self.desc = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" - self.prd_review_prompt_template = """ - Given the following Product Requirement Document (PRD): - {prd} + name: str = "" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + prd: Optional[str] = None + desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" + prd_review_prompt_template: str = """ +Given the following Product Requirement Document (PRD): +{prd} - As a project manager, please review it and provide your feedback and suggestions. - """ +As a project manager, please review it and provide your feedback and suggestions. 
+""" async def run(self, prd): self.prd = prd diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py new file mode 100644 index 000000000..8a4856317 --- /dev/null +++ b/metagpt/actions/write_review.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Author : alexanderwu +@File : write_review.py +""" +from typing import List + +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +REVIEW = ActionNode( + key="Review", + expected_type=List[str], + instruction="Act as an experienced Reviewer and review the given output. Ask a series of critical questions, " + "concisely and clearly, to help the writer improve their work.", + example=[ + "This is a good PRD, but I think it can be improved by adding more details.", + ], +) + +LGTM = ActionNode( + key="LGTM", + expected_type=str, + instruction="LGTM/LBTM. If the output is good enough, give a LGTM (Looks Good To Me) to the writer, " + "else LBTM (Looks Bad To Me).", + example="LGTM", +) + +WRITE_REVIEW_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM]) + + +class WriteReview(Action): + """Write a review for the given context.""" + + async def run(self, context): + return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 9dd967788..9eb0bdbb6 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -7,10 +7,17 @@ @Modified By: mashenquan, 2023-11-27. Following the think-act principle, solidify the task parameters when creating the WriteTest object, rather than passing them in when calling the run function. 
""" + +from typing import Optional + +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document, TestingContext from metagpt.utils.common import CodeParser @@ -36,8 +43,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/config.py b/metagpt/config.py index fae2622db..3c773d780 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -9,7 +9,9 @@ Provide configuration, singleton import datetime import json import os +import warnings from copy import deepcopy +from enum import Enum from pathlib import Path from typing import Any from uuid import uuid4 @@ -19,6 +21,7 @@ import yaml from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.common import require_python_version from metagpt.utils.cost_manager import CostManager from metagpt.utils.singleton import Singleton @@ -35,6 +38,18 @@ class NotConfiguredException(Exception): super().__init__(self.message) +class LLMProviderEnum(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + SPARK = "spark" + ZHIPUAI = "zhipuai" + FIREWORKS = "fireworks" + OPEN_LLM = "open_llm" + GEMINI = "gemini" + METAGPT = "metagpt" + AZURE_OPENAI = "azure_openai" + + class Config(metaclass=Singleton): """ Regular usage method: @@ -49,47 +64,63 @@ class Config(metaclass=Singleton): default_yaml_file = METAGPT_ROOT 
/ "config/config.yaml" def __init__(self, yaml_file=default_yaml_file, cost_data=""): + global_options = OPTIONS.get() + # cli paras + self.project_path = "" + self.project_name = "" + self.inc = False + self.reqa_file = "" + self.max_auto_summarize_code = 0 + self._init_with_config_files_and_env(yaml_file) # The agent needs to be billed per user, so billing information cannot be destroyed when the session ends. self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager() - logger.info("Config loading done.") self._update() - logger.info(f"OpenAI API Model: {self.openai_api_model}") + global_options.update(OPTIONS.get()) + logger.debug("Config loading done.") + + def get_default_llm_provider_enum(self) -> LLMProviderEnum: + for k, v in [ + (self.openai_api_key, LLMProviderEnum.OPENAI), + (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC), + (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI), + (self.fireworks_api_key, LLMProviderEnum.FIREWORKS), + (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), + (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. 
but not a key + ]: + if self._is_valid_llm_key(k): + # logger.debug(f"Use LLMProvider: {v.value}") + if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): + warnings.warn("Use Gemini requires Python >= 3.10") + if self.openai_api_key and self.openai_api_model: + logger.info(f"OpenAI API Model: {self.openai_api_model}") + return v + raise NotConfiguredException("You should config a LLM configuration first") + + @staticmethod + def _is_valid_llm_key(k: str) -> bool: + return k and k != "YOUR_API_KEY" def _update(self): self.global_proxy = self._get("GLOBAL_PROXY") + self.openai_api_key = self._get("OPENAI_API_KEY") - self.anthropic_api_key = self._get("Anthropic_API_KEY") + self.anthropic_api_key = self._get("ANTHROPIC_API_KEY") self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") - if ( - (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) - and (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) - and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) - and (not self.open_llm_api_base) - and (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) - ): - error_info = ( - "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" - ) - val = self._get("RAISE_NOT_CONFIG_ERROR") - if val is None or val.lower() == "true": - raise NotConfiguredException(error_info) - else: # for agent - logger.warning(error_info) + self.gemini_api_key = self._get("GEMINI_API_KEY") + _ = self.get_default_llm_provider_enum() - self.openai_api_base = self._get("OPENAI_API_BASE") + self.openai_base_url = self._get("OPENAI_BASE_URL") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy self.openai_api_type = self._get("OPENAI_API_TYPE") self.openai_api_version 
= self._get("OPENAI_API_VERSION") self.openai_api_rpm = self._get("RPM", 3) self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4-1106-preview") self.max_tokens_rsp = self._get("MAX_TOKENS", 2048) - self.deployment_name = self._get("DEPLOYMENT_NAME") - self.deployment_id = self._get("DEPLOYMENT_ID") + self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4") self.spark_appid = self._get("SPARK_APPID") self.spark_api_secret = self._get("SPARK_API_SECRET") @@ -100,7 +131,7 @@ class Config(metaclass=Singleton): self.fireworks_api_base = self._get("FIREWORKS_API_BASE") self.fireworks_api_model = self._get("FIREWORKS_API_MODEL") - self.claude_api_key = self._get("Anthropic_API_KEY") + self.claude_api_key = self._get("ANTHROPIC_API_KEY") self.serpapi_api_key = self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") self.google_api_key = self._get("GOOGLE_API_KEY") @@ -124,11 +155,11 @@ class Config(metaclass=Singleton): self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs") self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "") - self.prompt_format = self._get("PROMPT_FORMAT", "json") workspace_uid = ( self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}" ) self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False) + self.prompt_format = self._get("PROMPT_FORMAT", "json") self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) val = self._get("WORKSPACE_PATH_WITH_UID") if val and val.lower() == "true": # for agent @@ -136,6 +167,19 @@ class Config(metaclass=Singleton): self._ensure_workspace_exists() self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1) + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): + """update config via cli""" + + # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. 
+ if project_path: + inc = True + project_name = project_name or Path(project_path).name + self.project_path = project_path + self.project_name = project_name + self.inc = inc + self.reqa_file = reqa_file + self.max_auto_summarize_code = max_auto_summarize_code + def _ensure_workspace_exists(self): self.workspace_path.mkdir(parents=True, exist_ok=True) logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}") @@ -158,8 +202,8 @@ class Config(metaclass=Singleton): @staticmethod def _get(*args, **kwargs): - m = OPTIONS.get() - return m.get(*args, **kwargs) + i = OPTIONS.get() + return i.get(*args, **kwargs) def get(self, key, *args, **kwargs): """Retrieve values from config/key.yaml, config/config.yaml, and environment variables. @@ -173,8 +217,8 @@ class Config(metaclass=Singleton): OPTIONS.get()[name] = value def __getattr__(self, name: str) -> Any: - m = OPTIONS.get() - return m.get(name) + i = OPTIONS.get() + return i.get(name) def set_context(self, options: dict): """Update current config""" @@ -193,8 +237,8 @@ class Config(metaclass=Singleton): def new_environ(self): """Return a new os.environ object""" env = os.environ.copy() - m = self.options - env.update({k: v for k, v in m.items() if isinstance(v, str)}) + i = self.options + env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env diff --git a/metagpt/const.py b/metagpt/const.py index 53f797001..76ddc077c 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -55,11 +55,14 @@ DATA_PATH = METAGPT_ROOT / "data" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" + UT_PATH = DATA_PATH / "ut" SWAGGER_PATH = UT_PATH / "files/api/" UT_PY_PATH = UT_PATH / "files/ut/" API_QUESTIONS_PATH = UT_PATH / "files/question/" +SERDESER_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project + TMP = METAGPT_ROOT / "tmp" SOURCE_ROOT = METAGPT_ROOT / "metagpt" @@ -110,7 
+113,7 @@ COMMAND_TOKENS = 500 BRAIN_MEMORY = "BRAIN_MEMORY" SKILL_PATH = "SKILL_PATH" SERPER_API_KEY = "SERPER_API_KEY" - +DEFAULT_TOKEN_SIZE = 500 # format BASE64_FORMAT = "base64" diff --git a/metagpt/environment.py b/metagpt/environment.py index 5c3a6f97f..e0b5010d9 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,29 +12,85 @@ functionality is to be consolidated into the `Environment` class. """ import asyncio +from pathlib import Path from typing import Iterable, Set from pydantic import BaseModel, Field +from metagpt.config import CONFIG from metagpt.logs import logger -from metagpt.roles import Role +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message -from metagpt.utils.common import is_subscribed +from metagpt.utils.common import is_subscribed, read_json_file, write_json_file class Environment(BaseModel): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles - """ roles: dict[str, Role] = Field(default_factory=dict) members: dict[Role, Set] = Field(default_factory=dict) - history: str = Field(default="") # For debug + history: str = "" # For debug class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + roles = [] + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + roles.append(current_role) + super().__init__(**kwargs) + + self.add_roles(roles) # add_roles again to init the Role.set_env + + def serialize(self, stg_path: Path): + roles_path = 
stg_path.joinpath("roles.json") + roles_info = [] + for role_key, role in self.roles.items(): + roles_info.append( + { + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name, + "role_sub_tags": list(self.members.get(role)), + } + ) + role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) + write_json_file(roles_path, roles_info) + + history_path = stg_path.joinpath("history.json") + write_json_file(history_path, {"content": self.history}) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": + """stg_path: ./storage/team/environment/""" + roles_path = stg_path.joinpath("roles.json") + roles_info = read_json_file(roles_path) + roles = [] + for role_info in roles_info: + # role stored in ./environment/roles/{role_class}_{role_name} + role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}") + role = Role.deserialize(role_path) + roles.append(role) + + history = read_json_file(stg_path.joinpath("history.json")) + history = history.get("content") + + environment = Environment(**{"history": history}) + environment.add_roles(roles) + + return environment + def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment @@ -111,3 +167,8 @@ class Environment(BaseModel): def set_subscription(self, obj, tags): """Set the labels for message to be consumed by the object""" self.members[obj] = tags + + @staticmethod + def archive(auto_archive=True): + if auto_archive and CONFIG.git_repo: + CONFIG.git_repo.archive() diff --git a/metagpt/llm.py b/metagpt/llm.py index ad4bf6cb2..8763642f0 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -4,50 +4,16 @@ @Time : 2023/5/11 14:45 @Author : alexanderwu @File : llm.py -@Modified By: mashenquan, 2023 """ -from metagpt.config import CONFIG -from metagpt.provider import LLMType +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.base_gpt_api import 
BaseGPTAPI -from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI -from metagpt.provider.open_llm_api import OpenLLMGPTAPI -from metagpt.provider.openai_api import OpenAIGPTAPI -from metagpt.provider.spark_api import SparkAPI -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -# Used in agents -class LLMFactory: - @staticmethod - def new_llm() -> "BaseGPTAPI": - # Determine which type of LLM to use based on the validity of the key. - if CONFIG.spark_api_key: - return SparkAPI() - elif CONFIG.zhipuai_api_key: - return ZhiPuAIGPTAPI() - elif CONFIG.open_llm_api_base: - return OpenLLMGPTAPI() - elif CONFIG.fireworks_api_key: - return FireWorksGPTAPI() - - # MetaGPT uses the same parameters as OpenAI. - constructors = { - LLMType.OPENAI.value: OpenAIGPTAPI, - LLMType.METAGPT.value: MetaGPTLLMAPI, - } - constructor = constructors.get(CONFIG.LLM_TYPE) - if constructor: - return constructor() - - raise RuntimeError("You should config a LLM configuration first") - - -# Used in metagpt -def LLM() -> "BaseGPTAPI": - """initialize different LLM instance according to the key field existence""" - return LLMFactory.new_llm() +def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI: + """get the default llm provider""" + return LLM_REGISTRY.get_provider(provider) diff --git a/metagpt/manager.py b/metagpt/manager.py deleted file mode 100644 index a063608be..000000000 --- a/metagpt/manager.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:42 -@Author : alexanderwu -@File : manager.py -""" -from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.schema import Message - - -class Manager: - def __init__(self, llm: LLM = LLM()): - self.llm = 
llm # Large Language Model - self.role_directions = { - "User": "Product Manager", - "Product Manager": "Architect", - "Architect": "Engineer", - "Engineer": "QA Engineer", - "QA Engineer": "Product Manager", - } - self.prompt_template = """ - Given the following message: - {message} - - And the current status of roles: - {roles} - - Which role should handle this message? - """ - - async def handle(self, message: Message, environment): - """ - 管理员处理信息,现在简单的将信息递交给下一个人 - The administrator processes the information, now simply passes the information on to the next person - :param message: - :param environment: - :return: - """ - # Get all roles from the environment - roles = environment.get_roles() - # logger.debug(f"{roles=}, {message=}") - - # Build a context for the LLM to understand the situation - # context = { - # "message": str(message), - # "roles": {role.name: role.get_info() for role in roles}, - # } - # Ask the LLM to decide which role should handle the message - # chosen_role_name = self.llm.ask(self.prompt_template.format(context)) - - # FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程 - # The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards - next_role_profile = self.role_directions[message.role] - # logger.debug(f"{next_role_profile}") - for _, role in roles.items(): - if next_role_profile == role.profile: - next_role = role - break - else: - logger.error(f"No available role can handle message: {message}.") - return - - # Find the chosen role and handle the message - return await next_role.handle(message) diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 5e68e99e5..1497b8910 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -5,6 +5,10 @@ @Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. 
""" +from typing import Optional + +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -18,11 +22,12 @@ class LongTermMemory(Memory): - update memory when it changed """ - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False + + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) @@ -38,7 +43,7 @@ class LongTermMemory(Memory): self.msg_from_recover = False def add(self, message: Message): - super(LongTermMemory, self).add(message) + super().add(message) for action in self.rc.watch: if message.cause_by == action and not self.msg_from_recover: # currently, only add role's watching messages to its memory_storage @@ -51,7 +56,7 @@ class LongTermMemory(Memory): 1. find the short-term memory(stm) news 2. 
furthermore, filter out similar messages based on ltm(long-term memory), get the final news """ - stm_news = super(LongTermMemory, self).find_news(observed, k=k) # shot-term memory news + stm_news = super().find_news(observed, k=k) # shot-term memory news if not self.memory_storage.is_initialized: # memory_storage hasn't initialized, use default `find_news` to get stm_news return stm_news @@ -65,9 +70,9 @@ class LongTermMemory(Memory): return ltm_news[-k:] def delete(self, message: Message): - super(LongTermMemory, self).delete(message) + super().delete(message) # TODO delete message in memory_storage def clear(self): - super(LongTermMemory, self).clear() + super().clear() self.memory_storage.clean() diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 7a6aa1c45..d964cc1dc 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -7,19 +7,50 @@ @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ from collections import defaultdict +from pathlib import Path from typing import Iterable, Set +from pydantic import BaseModel, Field + from metagpt.schema import Message -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import ( + any_to_str, + any_to_str_set, + read_json_file, + write_json_file, +) -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[str, list[Message]] = defaultdict(list) + storage: list[Message] = [] + index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) + + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + new_index[action_str] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = 
new_index + + def serialize(self, stg_path: Path): + """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + storage = self.dict() + write_json_file(memory_path, storage) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Memory": + """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) + + return memory def add(self, message: Message): """Add a new message to storage, while updating the index""" @@ -41,6 +72,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index fc2f0a419..3017c23ad 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -61,7 +61,7 @@ class MemoryStorage(FaissStore): return index_fpath, storage_fpath def persist(self): - super(MemoryStorage, self).persist() + super().persist() logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 3517e1376..a9f46eb03 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -4,23 +4,12 @@ @Time : 2023/5/5 
22:59 @Author : alexanderwu @File : __init__.py -@Modified By: mashenquan, 2023-12-15. Add LLMType """ -from enum import Enum +from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.open_llm_api import OpenLLMGPTAPI +from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -class LLMType(Enum): - OPENAI = "OpenAI" - METAGPT = "MetaGPT" - UNKNOWN = "UNKNOWN" - - @classmethod - def get(cls, value): - for member in cls: - if member.value == value: - return member - return cls.UNKNOWN - - @classmethod - def __missing__(cls, value): - return cls.UNKNOWN +__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"] diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 03802a716..f5b06c855 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -14,7 +14,7 @@ from metagpt.config import CONFIG class Claude2: def ask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", @@ -24,7 +24,7 @@ class Claude2: return res.completion async def aask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py new file mode 100644 index 000000000..ec5eed3f6 --- /dev/null +++ b/metagpt/provider/azure_openai_api.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/5 23:08 +@Author : alexanderwu +@File : openai.py +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation; + Change cost control from global to company level. 
+@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout. +@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x. +""" + + +from openai import AsyncAzureOpenAI, AzureOpenAI +from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper + +from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter + + +@register_provider(LLMProviderEnum.AZURE_OPENAI) +class AzureOpenAIGPTAPI(OpenAIGPTAPI): + """ + Check https://platform.openai.com/examples for examples + """ + + def __init__(self): + self.config: Config = CONFIG + self.__init_openai() + self.auto_max_tokens = False + # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix + self._client = AsyncAzureOpenAI( + api_key=CONFIG.openai_api_key, + api_version=CONFIG.openai_api_version, + azure_endpoint=CONFIG.openai_api_base, + ) + RateLimiter.__init__(self, rpm=self.rpm) + + def _make_client(self): + kwargs, async_kwargs = self._make_client_kwargs() + self.client = AzureOpenAI(**kwargs) + self.async_client = AsyncAzureOpenAI(**async_kwargs) + + def _make_client_kwargs(self) -> (dict, dict): + kwargs = dict( + api_key=self.config.openai_api_key, + api_version=self.config.openai_api_version, + azure_endpoint=self.config.openai_base_url, + ) + async_kwargs = kwargs.copy() + + # to use proxy, openai v1 needs http_client + proxy_params = self._get_proxy_params() + if proxy_params: + kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params) + async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) + + return kwargs, async_kwargs + + def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: + kwargs = { + "messages": messages, + "max_tokens": self.get_max_tokens(messages), + "n": 1, + "stop": None, + "temperature": 0.3, + "model": CONFIG.deployment_id, + } 
+ if configs: + kwargs.update(configs) + try: + default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0 + except ValueError: + default_timeout = 0 + kwargs["timeout"] = max(default_timeout, timeout) + + return kwargs diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 6625cda97..bfe85f490 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -4,18 +4,18 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter -from metagpt.utils.cost_manager import CostManager +@register_provider(LLMProviderEnum.FIREWORKS) class FireWorksGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_fireworks(CONFIG) self.llm = openai self.model = CONFIG.fireworks_api_model self.auto_max_tokens = False - self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) def __init_fireworks(self, config: "Config"): diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..eb91cc32b --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +import google.generativeai as genai +from google.ai import generativelanguage as glm +from google.generativeai.generative_models import GenerativeModel +from google.generativeai.types import content_types +from google.generativeai.types.generation_types import ( + AsyncGenerateContentResponse, + GenerateContentResponse, + GenerationConfig, +) +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) + +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.logs import logger +from 
metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.provider.openai_api import log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) + + +@register_provider(LLMProviderEnum.GEMINI) +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + def __init__(self): + self.use_system_prompt = False # google gemini has no system prompt when use api + + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = GeminiGenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. 
+ return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} + return kwargs + + def _update_costs(self, usage: dict): + """update each request's token cost""" + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error(f"google gemini updates costs failed! exp: {e}") + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage)
+ return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async( + **self._const_kwargs(messages, stream=True) + ) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(min=1, max=60), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise, + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """response in async with stream or non-stream mode""" + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py new file mode 100644 index 000000000..2b3ef93a3 --- /dev/null +++ b/metagpt/provider/llm_provider_registry.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 17:26 +@Author : alexanderwu +@File : llm_provider_registry.py +""" +from metagpt.config import LLMProviderEnum + + +class LLMProviderRegistry: + def __init__(self): + self.providers = {} + + def register(self, key, provider_cls): + self.providers[key] = 
provider_cls + + def get_provider(self, enum: LLMProviderEnum): + """get provider instance according to the enum""" + return self.providers[enum]() + + +# Registry instance +LLM_REGISTRY = LLMProviderRegistry() + + +def register_provider(key): + """register provider to registry""" + + def decorator(cls): + LLM_REGISTRY.register(key, cls) + return cls + + return decorator diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py new file mode 100644 index 000000000..00a42ee2a --- /dev/null +++ b/metagpt/provider/metagpt_api.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/5 23:08 +@Author : alexanderwu +@File : metagpt_api.py +@Desc : MetaGPT LLM provider. +""" +from metagpt.config import LLMProviderEnum +from metagpt.provider import OpenAIGPTAPI +from metagpt.provider.llm_provider_registry import register_provider + + +@register_provider(LLMProviderEnum.METAGPT) +class METAGPTAPI(OpenAIGPTAPI): + def __init__(self): + super().__init__() diff --git a/metagpt/provider/metagpt_llm_api.py b/metagpt/provider/metagpt_llm_api.py deleted file mode 100644 index 994fc39ff..000000000 --- a/metagpt/provider/metagpt_llm_api.py +++ /dev/null @@ -1,279 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@Time : 2023/8/30 -@Author : mashenquan -@File : metagpt_llm_api.py -@Desc : MetaGPT LLM related APIs -""" - -from metagpt.provider.openai_api import OpenAIGPTAPI - -# from metagpt.provider.base_gpt_api import BaseGPTAPI -# from metagpt.provider.openai_api import RateLimiter - - -class MetaGPTLLMAPI(OpenAIGPTAPI): - """MetaGPT LLM api""" - - def __init__(self): - super(MetaGPTLLMAPI, self).__init__() - - # def __init__(self): - # self.__init_openai(CONFIG) - # self.llm = openai - # self.model = CONFIG.openai_api_model - # self.auto_max_tokens = False - # self._cost_manager = CostManager() - # RateLimiter.__init__(self, rpm=self.rpm) - # - # def __init_openai(self, config): - # openai.api_key = config.openai_api_key - # if config.openai_api_base: - # 
openai.api_base = config.openai_api_base - # if config.openai_api_type: - # openai.api_type = config.openai_api_type - # openai.api_version = config.openai_api_version - # self.rpm = int(config.get("RPM", 10)) - # - # async def _achat_completion_stream(self, messages: list[dict]) -> str: - # response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True) - # - # # create variables to collect the stream of chunks - # collected_chunks = [] - # collected_messages = [] - # # iterate through the stream of events - # async for chunk in response: - # collected_chunks.append(chunk) # save the event response - # choices = chunk["choices"] - # if len(choices) > 0: - # chunk_message = chunk["choices"][0].get("delta", {}) # extract the message - # collected_messages.append(chunk_message) # save the message - # if "content" in chunk_message: - # print(chunk_message["content"], end="") - # print() - # - # full_reply_content = "".join([m.get("content", "") for m in collected_messages]) - # usage = self._calc_usage(messages, full_reply_content) - # self._update_costs(usage) - # return full_reply_content - # - # def _cons_kwargs(self, messages: list[dict], **configs) -> dict: - # kwargs = { - # "messages": messages, - # "max_tokens": self.get_max_tokens(messages), - # "n": 1, - # "stop": None, - # "temperature": 0.3, - # "timeout": 3, - # } - # if configs: - # kwargs.update(configs) - # - # if CONFIG.openai_api_type == "azure": - # if CONFIG.deployment_name and CONFIG.deployment_id: - # raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model") - # elif not CONFIG.deployment_name and not CONFIG.deployment_id: - # raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter") - # kwargs_mode = ( - # {"engine": CONFIG.deployment_name} - # if CONFIG.deployment_name - # else {"deployment_id": CONFIG.deployment_id} - # ) - # else: - # kwargs_mode = {"model": self.model} - # kwargs.update(kwargs_mode) - # 
return kwargs - # - # async def _achat_completion(self, messages: list[dict]) -> dict: - # rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages)) - # self._update_costs(rsp.get("usage")) - # return rsp - # - # def _chat_completion(self, messages: list[dict]) -> dict: - # rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages)) - # self._update_costs(rsp) - # return rsp - # - # def completion(self, messages: list[dict]) -> dict: - # # if isinstance(messages[0], Message): - # # messages = self.messages_to_dict(messages) - # return self._chat_completion(messages) - # - # async def acompletion(self, messages: list[dict]) -> dict: - # # if isinstance(messages[0], Message): - # # messages = self.messages_to_dict(messages) - # return await self._achat_completion(messages) - # - # @retry( - # wait=wait_random_exponential(min=1, max=60), - # stop=stop_after_attempt(6), - # after=after_log(logger, logger.level("WARNING").name), - # retry=retry_if_exception_type(APIConnectionError), - # retry_error_callback=log_and_reraise, - # ) - # async def acompletion_text(self, messages: list[dict], stream=False) -> str: - # """when streaming, print each token in place.""" - # if stream: - # return await self._achat_completion_stream(messages) - # rsp = await self._achat_completion(messages) - # return self.get_choice_text(rsp) - # - # def _func_configs(self, messages: list[dict], **kwargs) -> dict: - # """ - # Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create - # """ - # if "tools" not in kwargs: - # configs = { - # "tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}], - # "tool_choice": GENERAL_TOOL_CHOICE, - # } - # kwargs.update(configs) - # - # return self._cons_kwargs(messages, **kwargs) - # - # def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict: - # rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs)) - # 
self._update_costs(rsp.get("usage")) - # return rsp - # - # async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict: - # rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs)) - # self._update_costs(rsp.get("usage")) - # return rsp - # - # def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: - # """convert messages to list[dict].""" - # if isinstance(messages, list): - # messages = [Message(msg) if isinstance(msg, str) else msg for msg in messages] - # return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] - # - # if isinstance(messages, Message): - # messages = [messages.to_dict()] - # elif isinstance(messages, str): - # messages = [{"role": "user", "content": messages}] - # else: - # raise ValueError( - # f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!" - # ) - # return messages - # - # def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: - # """Use function of tools to ask a code. - # - # Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create - # - # Examples: - # - # >>> llm = OpenAIGPTAPI() - # >>> llm.ask_code("Write a python hello world code.") - # {'language': 'python', 'code': "print('Hello, World!')"} - # >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}] - # >>> llm.ask_code(msg) - # {'language': 'python', 'code': "print('Hello, World!')"} - # """ - # messages = self._process_message(messages) - # rsp = self._chat_completion_function(messages, **kwargs) - # return self.get_choice_function_arguments(rsp) - # - # async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: - # """Use function of tools to ask a code. 
- # - # Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create - # - # Examples: - # - # >>> llm = OpenAIGPTAPI() - # >>> rsp = await llm.ask_code("Write a python hello world code.") - # >>> rsp - # {'language': 'python', 'code': "print('Hello, World!')"} - # >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}] - # >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - # """ - # messages = self._process_message(messages) - # rsp = await self._achat_completion_function(messages, **kwargs) - # return self.get_choice_function_arguments(rsp) - # - # def _calc_usage(self, messages: list[dict], rsp: str) -> dict: - # usage = {} - # if CONFIG.calc_usage: - # try: - # prompt_tokens = count_message_tokens(messages, self.model) - # completion_tokens = count_string_tokens(rsp, self.model) - # usage["prompt_tokens"] = prompt_tokens - # usage["completion_tokens"] = completion_tokens - # return usage - # except Exception as e: - # logger.error("usage calculation failed!", e) - # else: - # return usage - # - # async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]: - # """Return full JSON""" - # split_batches = self.split_batches(batch) - # all_results = [] - # - # for small_batch in split_batches: - # logger.info(small_batch) - # await self.wait_if_needed(len(small_batch)) - # - # future = [self.acompletion(prompt) for prompt in small_batch] - # results = await asyncio.gather(*future) - # logger.info(results) - # all_results.extend(results) - # - # return all_results - # - # async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]: - # """Only return plain text""" - # raw_results = await self.acompletion_batch(batch) - # results = [] - # for idx, raw_result in enumerate(raw_results, start=1): - # result = self.get_choice_text(raw_result) - # results.append(result) - # logger.info(f"Result of task {idx}: {result}") - 
# return results - # - # def _update_costs(self, usage: dict): - # if CONFIG.calc_usage: - # try: - # prompt_tokens = int(usage["prompt_tokens"]) - # completion_tokens = int(usage["completion_tokens"]) - # self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - # except Exception as e: - # logger.error("updating costs failed!", e) - # - # def get_costs(self) -> Costs: - # return self._cost_manager.get_costs() - # - # def get_max_tokens(self, messages: list[dict]): - # if not self.auto_max_tokens: - # return CONFIG.max_tokens_rsp - # return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) - # - # def moderation(self, content: Union[str, list[str]]): - # try: - # if not content: - # logger.error("content cannot be empty!") - # else: - # rsp = self._moderation(content=content) - # return rsp - # except Exception as e: - # logger.error(f"moderating failed:{e}") - # - # def _moderation(self, content: Union[str, list[str]]): - # rsp = self.llm.Moderation.create(input=content) - # return rsp - # - # async def amoderation(self, content: Union[str, list[str]]): - # try: - # if not content: - # logger.error("content cannot be empty!") - # else: - # rsp = await self._amoderation(content=content) - # return rsp - # except Exception as e: - # logger.error(f"moderating failed:{e}") - # - # async def _amoderation(self, content: Union[str, list[str]]): - # rsp = await self.llm.Moderation.acreate(input=content) - # return rsp diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index cd30c4a58..2e8c03ba1 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -2,40 +2,39 @@ # -*- coding: utf-8 -*- # @Desc : self-host open llm model with openai-compatible interface -from metagpt.config import CONFIG -from metagpt.logs import logger + +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider from 
metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter -from metagpt.utils.cost_manager import CostManager - - -class OpenLLMCostManager(CostManager): - """open llm model is self-host, it's free and without cost""" - - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - - Args: - prompt_tokens (int): The number of tokens used in the prompt. - completion_tokens (int): The number of tokens used in the completion. - model (str): The model used for the API call. - """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - - logger.info( - f"Max budget: ${CONFIG.max_budget:.3f} | " - f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" - ) - CONFIG.total_cost = self.total_cost + +# class OpenLLMCostManager(CostManager): +# """open llm model is self-host, it's free and without cost""" +# +# def update_cost(self, prompt_tokens, completion_tokens, model): +# """ +# Update the total cost, prompt tokens, and completion tokens. +# +# Args: +# prompt_tokens (int): The number of tokens used in the prompt. +# completion_tokens (int): The number of tokens used in the completion. +# model (str): The model used for the API call. 
+# """ +# self.total_prompt_tokens += prompt_tokens +# self.total_completion_tokens += completion_tokens +# +# logger.info( +# f"Max budget: ${CONFIG.max_budget:.3f} | " +# f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" +# ) +# CONFIG.total_cost = self.total_cost +@register_provider(LLMProviderEnum.OPEN_LLM) class OpenLLMGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_openllm(CONFIG) self.model = CONFIG.open_llm_api_model self.auto_max_tokens = False - self._cost_manager = OpenLLMCostManager() RateLimiter.__init__(self, rpm=self.rpm) def __init_openllm(self, config: "Config"): diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 5e9c9fc4d..ca130ce15 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -10,12 +10,15 @@ """ import asyncio +import json import time -from typing import Union +from typing import List, Union import openai -from openai import APIConnectionError, AsyncAzureOpenAI, AsyncOpenAI, RateLimitError +from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI, RateLimitError +from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion, ChatCompletionChunk from tenacity import ( after_log, retry, @@ -24,13 +27,15 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.const import DEFAULT_MAX_TOKENS from metagpt.logs import logger -from metagpt.provider import LLMType from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE +from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.cost_manager import Costs +from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter 
import ( count_message_tokens, count_string_tokens, @@ -74,25 +79,18 @@ See FAQ 5.8 raise retry_state.outcome.exception() +@register_provider(LLMProviderEnum.OPENAI) class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ Check https://platform.openai.com/examples for examples """ def __init__(self): - self.model = CONFIG.openai_api_model + self.config: Config = CONFIG + self.__init_openai() self.auto_max_tokens = False - self.rpm = int(CONFIG.get("RPM", 10)) - if CONFIG.openai_api_type == "azure": - # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix - self._client = AsyncAzureOpenAI( - api_key=CONFIG.openai_api_key, - api_version=CONFIG.openai_api_version, - azure_endpoint=CONFIG.openai_api_base, - ) - else: - # https://github.com/openai/openai-python#async-usage - self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base) + # https://github.com/openai/openai-python#async-usage + self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base) RateLimiter.__init__(self, rpm=self.rpm) async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: @@ -103,6 +101,59 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): chunk_message = chunk.choices[0].delta.content or "" # extract the message yield chunk_message + def __init_openai(self): + self.rpm = int(self.config.get("RPM", 10)) + self._make_client() + + def _make_client(self): + kwargs, async_kwargs = self._make_client_kwargs() + self.client = OpenAI(**kwargs) + self.async_client = AsyncOpenAI(**async_kwargs) + + def _make_client_kwargs(self) -> (dict, dict): + kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url) + async_kwargs = kwargs.copy() + + # to use proxy, openai v1 needs http_client + proxy_params = self._get_proxy_params() + if proxy_params: + kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params) + async_kwargs["http_client"] = 
AsyncHttpxClientWrapper(**proxy_params) + + return kwargs, async_kwargs + + def _get_proxy_params(self) -> dict: + params = {} + if self.config.openai_proxy: + params = {"proxies": self.config.openai_proxy} + if self.config.openai_base_url: + params["base_url"] = self.config.openai_base_url + + return params + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( + **self._cons_kwargs(messages), stream=True + ) + + # create variables to collect the stream of chunks + collected_chunks = [] + collected_messages = [] + # iterate through the stream of events + async for chunk in response: + collected_chunks.append(chunk) # save the event response + if chunk.choices: + chunk_message = chunk.choices[0].delta # extract the message + collected_messages.append(chunk_message) # save the message + if chunk_message.content: + print(chunk_message.content, end="") + print() + + full_reply_content = "".join([m.content for m in collected_messages if m.content]) + usage = self._calc_usage(messages, full_reply_content) + self._update_costs(usage) + return full_reply_content + def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: kwargs = { "messages": messages, @@ -110,14 +161,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): "n": 1, "stop": None, "temperature": 0.3, + "model": self.config.openai_api_model, } if configs: kwargs.update(configs) - - if CONFIG.openai_api_type == "azure": - kwargs["model"] = CONFIG.deployment_id - else: - kwargs["model"] = self.model try: default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0 except ValueError: @@ -126,19 +173,17 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return kwargs - async def _achat_completion(self, messages: list[dict], timeout=3) -> dict: + async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion: kwargs = self._cons_kwargs(messages, timeout=timeout) - 
rsp = await self._client.chat.completions.create(**kwargs) + rsp: ChatCompletion = await self._client.chat.completions.create(**kwargs) self._update_costs(rsp.usage) - return rsp.dict() + return rsp - def completion(self, messages: list[dict], timeout=3) -> dict: + def completion(self, messages: list[dict], timeout=3) -> ChatCompletion: loop = self.get_event_loop() return loop.run_until_complete(self.acompletion(messages, timeout=timeout)) - async def acompletion(self, messages: list[dict], timeout=3) -> dict: - # if isinstance(messages[0], Message): - # messages = self.messages_to_dict(messages) + async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion: return await self._achat_completion(messages, timeout=timeout) @retry( @@ -188,20 +233,20 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs) - def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> dict: + def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion: loop = self.get_event_loop() return loop.run_until_complete(self._achat_completion_function(messages=messages, timeout=timeout, **kwargs)) - async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> dict: + async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion: kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs) - rsp = await self._client.chat.completions.create(**kwargs) + rsp: ChatCompletion = await self._client.chat.completions.create(**kwargs) self._update_costs(rsp.usage) - return rsp.dict() + return rsp def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: """convert messages to list[dict].""" if isinstance(messages, list): - messages = [Message(msg) if isinstance(msg, str) else msg for msg in messages] + messages = 
[Message(content=msg) if isinstance(msg, str) else msg for msg in messages] return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] if isinstance(messages, Message): @@ -269,7 +314,35 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): logger.error(f"{self.model} usage calculation failed!", e) return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[dict]: + def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict: + """Required to provide the first function arguments of choice. + + :return dict: return the first function arguments of choice, for example, + {'language': 'python', 'code': "print('Hello, World!')"} + """ + try: + return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments) + except json.JSONDecodeError: + return {} + + def get_choice_text(self, rsp: ChatCompletion) -> str: + """Required to provide the first text of choice""" + return rsp.choices[0].message.content if rsp.choices else "" + + def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage: + usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + if not CONFIG.calc_usage: + return usage + + try: + usage.prompt_tokens = count_message_tokens(messages, self.model) + usage.completion_tokens = count_string_tokens(rsp, self.model) + except Exception as e: + logger.error(f"usage calculation failed!: {e}") + + return usage + + async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]: """Return full JSON""" split_batches = self.split_batches(batch) all_results = [] @@ -296,11 +369,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return results def _update_costs(self, usage: CompletionUsage): - if CONFIG.calc_usage: + if CONFIG.calc_usage and usage: try: - prompt_tokens = usage.prompt_tokens - completion_tokens = usage.completion_tokens - CONFIG.cost_manager.update_cost(prompt_tokens, 
completion_tokens, self.model) + CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) except Exception as e: logger.error("updating costs failed!", e) @@ -316,19 +387,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): loop = self.get_event_loop() loop.run_until_complete(self.amoderation(content=content)) + @handle_exception async def amoderation(self, content: Union[str, list[str]]): - try: - if not content: - logger.error("content cannot be empty!") - else: - rsp = await self._amoderation(content=content) - return rsp - except Exception as e: - logger.error(f"moderating failed:{e}") - - async def _amoderation(self, content: Union[str, list[str]]): - rsp = await self._client.moderations.create(input=content) - return rsp + return await self._client.moderations.create(input=content) async def close(self): """Close connection""" @@ -349,8 +410,73 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): else: raise e - async def get_summary(self, text: str, max_words=200, keep_language: bool = False, **kwargs) -> str: - from metagpt.memory.brain_memory import BrainMemory + async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs) -> str: + max_token_count = DEFAULT_MAX_TOKENS + max_count = 100 + text_length = len(text) + if limit > 0 and text_length < limit: + return text + summary = "" + while max_count > 0: + if text_length < max_token_count: + summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) + break - memory = BrainMemory(llm_type=LLMType.OPENAI.value, historical_summary=text, cacheable=False) - return await memory.summarize(llm=self, max_words=max_words, keep_language=keep_language) + padding_size = 20 if max_token_count > 20 else 0 + text_windows = self.split_texts(text, window_size=max_token_count - padding_size) + part_max_words = min(int(max_words / len(text_windows)) + 1, 100) + summaries = [] + for ws in text_windows: + response = await 
self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) + summaries.append(response) + if len(summaries) == 1: + summary = summaries[0] + break + + # Merged and retry + text = "\n".join(summaries) + text_length = len(text) + + max_count -= 1 # safeguard + return summary + + async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): + """Generate text summary""" + if len(text) < max_words: + return text + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." + else: + command = f"Translate the above content into a summary of less than {max_words} words." + msg = text + "\n\n" + command + logger.debug(f"summary ask:{msg}") + response = await self.aask(msg=msg, system_msgs=[]) + logger.debug(f"summary rsp: {response}") + return response + + @staticmethod + def split_texts(text: str, window_size) -> List[str]: + """Splitting long text into sliding windows text""" + if window_size <= 0: + window_size = DEFAULT_TOKEN_SIZE + total_len = len(text) + if total_len <= window_size: + return [text] + + padding_size = 20 if window_size > 20 else 0 + windows = [] + idx = 0 + data_len = window_size - padding_size + while idx < total_len: + if window_size + idx > total_len: # 不足一个滑窗 + windows.append(text[idx:]) + break + # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] + # window_size=3, padding_size=1: + # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... + # idx=2, | idx=5 | idx=8 | ... 
+ w = text[idx : idx + window_size] + windows.append(w) + idx += data_len + + return windows diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py index 0d1cfbb11..46646be91 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -4,7 +4,6 @@ from typing import Union -from metagpt.logs import logger from metagpt.utils.repair_llm_raw_output import ( RepairType, extract_content_from_output, @@ -44,7 +43,7 @@ class BasePostPrecessPlugin(object): def run_retry_parse_json_text(self, content: str) -> Union[dict, list]: """inherited class can re-implement the function""" - logger.info(f"extracted json CONTENT from output:\n{content}") + # logger.info(f"extracted json CONTENT from output:\n{content}") parsed_data = retry_parse_json_text(output=content) # should use output=content return parsed_data diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 60c86f4dc..484fa7956 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -19,11 +19,13 @@ from wsgiref.handlers import format_date_time import websocket # 使用websocket_client -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider +@register_provider(LLMProviderEnum.SPARK) class SparkAPI(BaseGPTAPI): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index ff8e5531e..54f0ddcbb 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -15,12 +15,12 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import 
logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI -from metagpt.utils.cost_manager import CostManager class ZhiPuEvent(Enum): @@ -30,6 +30,7 @@ class ZhiPuEvent(Enum): FINISH = "finish" +@register_provider(LLMProviderEnum.ZHIPUAI) class ZhiPuAIGPTAPI(BaseGPTAPI): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` @@ -42,7 +43,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): self.__init_zhipuai(CONFIG) self.llm = ZhiPuModelAPI self.model = "chatglm_turbo" # so far only one model, just use it - self._cost_manager = CostManager() def __init_zhipuai(self, config: CONFIG): assert config.zhipuai_api_key @@ -60,9 +60,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): try: prompt_tokens = int(usage.get("prompt_tokens", 0)) completion_tokens = int(usage.get("completion_tokens", 0)) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! 
exp: {e}") def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index ff34257a6..9f3a1bac4 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -336,4 +336,4 @@ def error(): if __name__ == "__main__": - error() + main() diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index fce6c3425..c6ceaccb7 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -8,7 +8,7 @@ from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class Architect(Role): @@ -22,17 +22,16 @@ class Architect(Role): constraints (str): Constraints or guidelines for the architect. """ - def __init__( - self, - name: str = "Bob", - profile: str = "Architect", - goal: str = "design a concise, usable, complete software system", - constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries." - "Use same language as user requirement", - ) -> None: - """Initializes the Architect with given attributes.""" - super().__init__(name, profile, goal, constraints) + name: str = "Bob" + profile: str = "Architect" + goal: str = "design a concise, usable, complete software system" + constraints: str = ( + "make sure the architecture is simple enough and use appropriate open source " + "libraries. 
Use same language as user requirement" + ) + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) # Initialize actions specific to the Architect role self._init_actions([WriteDesign]) diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 188182d47..777f62731 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,8 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,5 +26,11 @@ DESC = """ class CustomerService(Sales): - def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None): - super().__init__(name, profile, desc=desc, store=store) + name: str = "Xiaomei" + profile: str = "Human customer service" + desc: str = DESC + + store: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index f5286e450..12deaa5bb 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -16,6 +16,7 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. """ + from __future__ import annotations import json @@ -67,24 +68,25 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. """ - def __init__( - self, - name: str = "Alex", - profile: str = "Engineer", - goal: str = "write elegant, readable, extensible, efficient code", - constraints: str = "the code should conform to standards like google-style and be modular and maintainable. 
" - "Use same language as user requirement", - n_borg: int = 1, - use_code_review: bool = False, - ) -> None: - """Initializes the Engineer role with given attributes.""" - super().__init__(name, profile, goal, constraints) - self.use_code_review = use_code_review + name: str = "Alex" + profile: str = "Engineer" + goal: str = "write elegant, readable, extensible, efficient code" + constraints: str = ( + "the code should conform to standards like google-style and be modular and maintainable. " + "Use same language as user requirement" + ) + n_borg: int = 1 + use_code_review: bool = False + code_todos: list = [] + summarize_todos = [] + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._init_actions([WriteCode]) self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug]) self.code_todos = [] self.summarize_todos = [] - self.n_borg = n_borg self._next_todo = any_to_name(WriteCode) @staticmethod @@ -307,4 +309,5 @@ class Engineer(Role): @property def todo(self) -> str: + """AgentStore uses this attribute to display to the user what actions the current role should take.""" return self._next_todo diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index f022237f5..0f18c9cb2 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -7,10 +7,11 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.config import CONFIG -from metagpt.roles import Role +from metagpt.roles.role import Role from metagpt.utils.common import any_to_name @@ -25,23 +26,13 @@ class ProductManager(Role): constraints (str): Constraints or limitations for the product manager. 
""" - def __init__( - self, - name: str = "Alice", - profile: str = "Product Manager", - goal: str = "efficiently create a successful product", - constraints: str = "use same language as user requirement", - ) -> None: - """ - Initializes the ProductManager role with given attributes. + name: str = "Alice" + profile: str = "Product Manager" + goal: str = "efficiently create a successful product that meets market demands and user expectations" + constraints: str = "utilize the same language as the user requirements for seamless communication" - Args: - name (str): Name of the product manager. - profile (str): Role profile. - goal (str): Goal of the product manager. - constraints (str): Constraints or limitations for the product manager. - """ - super().__init__(name, profile, goal, constraints) + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) self._init_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) @@ -61,4 +52,5 @@ class ProductManager(Role): @property def todo(self) -> str: + """AgentStore uses this attribute to display to the user what actions the current role should take.""" return self._todo diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 657737513..1fad4afc2 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,9 +5,10 @@ @Author : alexanderwu @File : project_manager.py """ + from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class ProjectManager(Role): @@ -21,23 +22,16 @@ class ProjectManager(Role): constraints (str): Constraints or limitations for the project manager. 
""" - def __init__( - self, - name: str = "Eve", - profile: str = "Project Manager", - goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " - "dependencies to start with the prerequisite modules", - constraints: str = "use same language as user requirement", - ) -> None: - """ - Initializes the ProjectManager role with given attributes. + name: str = "Eve" + profile: str = "Project Manager" + goal: str = ( + "break down tasks according to PRD/technical design, generate a task list, and analyze task " + "dependencies to start with the prerequisite modules" + ) + constraints: str = "use same language as user requirement" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) - Args: - name (str): Name of the project manager. - profile (str): Role profile. - goal (str): Goal of the project manager. - constraints (str): Constraints or limitations for the project manager. - """ - super().__init__(name, profile, goal, constraints) self._init_actions([WriteTasks]) self._watch([WriteDesign]) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4439b9b19..39246364e 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,6 +14,8 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. 
""" + + from metagpt.actions import DebugError, RunCode, WriteTest from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG @@ -30,21 +32,23 @@ from metagpt.utils.file_repository import FileRepository class QaEngineer(Role): - def __init__( - self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, - ): - super().__init__(name, profile, goal, constraints) - self._init_actions( - [WriteTest] - ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates + name: str = "Edward" + profile: str = "QaEngineer" + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = ( + "The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain" + ) + test_round_allowed: int = 5 + test_round: int = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # FIXME: a bit hack here, only init one action to circumvent _think() logic, + # will overwrite _think() in future updates + self._init_actions([WriteTest]) self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed async def _write_test(self, message: Message) -> None: src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) @@ -111,7 +115,8 @@ class QaEngineer(Role): ) run_code_context.code = None run_code_context.test_code = None - recipient = parse_recipient(result.summary) # the recipient might be Engineer or myself + # the recipient might be Engineer or myself + recipient = parse_recipient(result.summary) mappings = {"Engineer": "Alex", "QaEngineer": "Edward"} self.publish_message( Message( @@ 
-178,4 +183,4 @@ class QaEngineer(Role): async def _observe(self, ignore_memory=False) -> int: # This role has events that trigger and execute themselves based on conditions, and cannot rely on the # content of memory to activate. - return await super(QaEngineer, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index d13d43495..fc6afa1fd 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -9,7 +9,7 @@ import asyncio from pydantic import BaseModel -from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize +from metagpt.actions import Action, CollectLinks, ConductResearch, WebBrowseAndSummarize from metagpt.actions.research import get_research_system_text from metagpt.const import RESEARCH_PATH from metagpt.logs import logger @@ -62,24 +62,45 @@ class Researcher(Role): else: topic = msg.content - research_system_text = get_research_system_text(topic, self.language) + research_system_text = self.research_system_text(topic, todo) if isinstance(todo, CollectLinks): links = await todo.run(topic, 4, 4) - ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=todo) + ret = Message( + content="", instruct_content=Report(topic=topic, links=links), role=self.profile, cause_by=todo + ) elif isinstance(todo, WebBrowseAndSummarize): links = instruct_content.links todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items()) summaries = await asyncio.gather(*todos) summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary) - ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo) + ret = Message( + content="", instruct_content=Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo + ) else: summaries = instruct_content.summaries summary_text = "\n---\n".join(f"url: 
{url}\nsummary: {summary}" for (url, summary) in summaries) content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text) - ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=self._rc.todo) + ret = Message( + content="", + instruct_content=Report(topic=topic, content=content), + role=self.profile, + cause_by=self._rc.todo, + ) self._rc.memory.add(ret) return ret + def research_system_text(self, topic, current_task: Action) -> str: + """BACKWARD compatible + This allows sub-class able to define its own system prompt based on topic. + return the previous implementation to have backward compatible + Args: + topic: + language: + + Returns: str + """ + return get_research_system_text(topic, self.language) + async def react(self) -> Message: msg = await super().react() report = msg.instruct_content diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 757c6b8e3..9636a1f30 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -19,20 +19,33 @@ @Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing functionality is to be consolidated into the `Environment` class. 
""" + from __future__ import annotations from enum import Enum -from typing import Iterable, Set, Type +from pathlib import Path +from typing import Any, Iterable, Set, Type from pydantic import BaseModel, Field -from metagpt.actions import Action, ActionOutput, UserRequirement +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import action_subclass_registry from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.const import SERDESER_PATH from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message, MessageQueue -from metagpt.utils.common import any_to_name, any_to_str +from metagpt.utils.common import ( + any_to_name, + any_to_str, + import_class, + read_json_file, + role_raise_decorator, + write_json_file, +) from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. 
""" @@ -75,34 +88,21 @@ class RoleReactMode(str, Enum): return [item.value for item in cls] -class RoleSetting(BaseModel): - """Role properties""" - - name: str - profile: str - goal: str - constraints: str - desc: str - is_human: bool - - def __str__(self): - return f"{self.name}({self.profile})" - - def __repr__(self): - return self.__str__() - - class RoleContext(BaseModel): """Role Runtime Context""" - env: "Environment" = Field(default=None) - msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` + env: "Environment" = Field(default=None, exclude=True) + # TODO judge if ser&deser + msg_buffer: MessageQueue = Field( + default_factory=MessageQueue, exclude=True + ) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) # long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None - todo: Action = Field(default=None) + todo: Action = Field(default=None, exclude=True) watch: set[str] = Field(default_factory=set) - news: list[Type[Message]] = Field(default=[]) + news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used react_mode: RoleReactMode = ( RoleReactMode.REACT ) # see `Role._set_react_mode` for definitions of the following two attributes @@ -127,35 +127,154 @@ class RoleContext(BaseModel): return self.memory.get() -class Role: +role_subclass_registry = {} + + +class Role(BaseModel): """Role/Agent""" - def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False): - self._llm = LLM() if not is_human else HumanProvider() - self._setting = RoleSetting( - name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, is_human=is_human - ) + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + 
desc: str = "" + is_human: bool = False + + _llm: BaseGPTAPI = Field(default_factory=LLM) + _role_id: str = "" + _states: list[str] = [] + _actions: list[Action] = [] + _rc: RoleContext = Field(default_factory=RoleContext) + _subscription: tuple[str] = set() + + # builtin variables + recovered: bool = False # to tag if a recovered role + latest_observed_msg: Message = None # record the latest observed message when interrupted + builtin_class_name: str = "" + + _private_attributes = { + "_llm": LLM() if not is_human else HumanProvider(), + "_role_id": _role_id, + "_states": [], + "_actions": [], + "_rc": RoleContext(), + "_subscription": set(), + } + + __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` + + class Config: + arbitrary_types_allowed = True + exclude = ["_llm"] + + def __init__(self, **kwargs: Any): + for index in range(len(kwargs.get("_actions", []))): + current_action = kwargs["_actions"][index] + if isinstance(current_action, dict): + item_class_name = current_action.get("builtin_class_name", None) + for name, subclass in action_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_action = subclass(**current_action) + break + kwargs["_actions"][index] = current_action + + super().__init__(**kwargs) + + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() + self._private_attributes["_role_id"] = str(self._setting) + self._private_attributes["_subscription"] = {any_to_str(self), self.name} if self.name else {any_to_str(self)} + + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + if key == "_rc": + _rc = RoleContext(**kwargs["_rc"]) + object.__setattr__(self, "_rc", _rc) + else: + if key == "_rc": + # # Warning, if use self._private_attributes["_rc"], + # # self._rc 
will be a shared object between roles, so init one or reset it inside `_reset` + object.__setattr__(self, key, RoleContext()) + else: + object.__setattr__(self, key, self._private_attributes[key]) + self._llm.system_prompt = self._get_prefix() - self._states = [] - self._actions = [] - self._role_id = str(self._setting) - self._rc = RoleContext(watch={any_to_str(UserRequirement)}) - self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} + + # deserialize child classes dynamically for inherited `role` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + if "actions" in kwargs: + self._init_actions(kwargs["actions"]) + + self._watch(kwargs.get("watch") or [UserRequirement]) + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls def _reset(self): - self._states = [] - self._actions = [] + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) + + @property + def _setting(self): + return f"{self.name}({self.profile})" + + def serialize(self, stg_path: Path = None): + stg_path = ( + SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") + if stg_path is None + else stg_path + ) + + role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True}) + role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__}) + role_info_path = stg_path.joinpath("role_info.json") + write_json_file(role_info_path, role_info) + + self._rc.memory.serialize(stg_path) # serialize role's memory alone + + @classmethod + def deserialize(cls, stg_path: Path) -> "Role": + """stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + role_info_path = stg_path.joinpath("role_info.json") + role_info = read_json_file(role_info_path) + + role_class_str = 
role_info.pop("role_class") + module_name = role_info.pop("module_name") + role_class = import_class(class_name=role_class_str, module_name=module_name) + + role = role_class(**role_info) # initiate particular Role + role.set_recovered(True) # set True to make a tag + + role_memory = Memory.deserialize(stg_path) + role.set_memory(role_memory) + + return role def _init_action_system_message(self, action: Action): - action.set_prefix(self._get_prefix(), self.profile) + action.set_prefix(self._get_prefix()) + + def set_recovered(self, recovered: bool = False): + self.recovered = recovered + + def set_memory(self, memory: Memory): + self._rc.memory = memory + + def init_actions(self, actions): + self._init_actions(actions) def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): if not isinstance(action, Action): - i = action("", llm=self._llm) + ## 默认初始化 + i = action(name="", llm=self._llm) else: - if self._setting.is_human and not isinstance(action.llm, HumanProvider): + if self.is_human and not isinstance(action.llm, HumanProvider): logger.warning( f"is_human attribute does not take effect, " f"as Role's {str(action)} was initialized using LLM, " @@ -211,7 +330,7 @@ class Role: def _set_state(self, state: int): """Update the current state.""" self._rc.state = state - logger.debug(self._actions) + logger.debug(f"actions={self._actions}, state={state}") self._rc.todo = self._actions[self._rc.state] if state >= 0 else None def set_env(self, env: "Environment"): @@ -221,35 +340,10 @@ class Role: if env: env.set_subscription(self, self._subscription) - @property - def profile(self): - """Get the role description (position)""" - return self._setting.profile - - @property - def name(self): - """Get virtual user name""" - return self._setting.name - @property def subscription(self) -> Set: """The labels for messages to be consumed by the Role object.""" - return self._subscription - - @property - def desc(self): - """Return role `desc`, read 
only""" - return self._setting.desc - - @property - def goal(self): - """Return role `goal`, read only""" - return self._setting.goal - - @property - def constraints(self): - """Return role `constraints`, read only""" - return self._setting.constraints + return set(self._subscription) @property def action_count(self): @@ -258,16 +352,25 @@ class Role: def _get_prefix(self): """Get the role prefix""" - if self._setting.desc: - return self._setting.desc - return PREFIX_TEMPLATE.format(**self._setting.dict()) + if self.desc: + return self.desc + return PREFIX_TEMPLATE.format( + **{"profile": self.profile, "name": self.name, "goal": self.goal, "constraints": self.constraints} + ) async def _think(self) -> bool: """Consider what to do and decide on the next course of action. Return false if nothing can be done.""" if len(self._actions) == 1: # If there is only one action, then only this one can be performed self._set_state(0) + return True + + if self.recovered and self._rc.state >= 0: + self._set_state(self._rc.state) # action to run from recovered state + self.recovered = False # avoid max_react_loop out of work + return True + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( history=self._rc.history, @@ -275,10 +378,11 @@ class Role: n_states=len(self._states) - 1, previous_state=self._rc.state, ) - # print(prompt) + next_state = await self._llm.aask(prompt) next_state = extract_state_value_from_output(next_state) logger.debug(f"{prompt=}") + if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)): logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1") next_state = -1 @@ -292,7 +396,7 @@ class Role: async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg 
= Message( content=response.content, instruct_content=response.instruct_content, @@ -308,15 +412,30 @@ class Role: return msg + def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]: + news = [] + # Warning, remove `id` here to make it work for recover + observed_pure = [msg.dict(exclude={"id": True}) for msg in observed] + existed_pure = [msg.dict(exclude={"id": True}) for msg in existed] + for idx, new in enumerate(observed_pure): + if new["cause_by"] in self._rc.watch and new not in existed_pure: + news.append(observed[idx]) + return news + async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. news = self._rc.msg_buffer.pop_all() + if self.recovered: + news = [self.latest_observed_msg] if self.latest_observed_msg else [] + else: + self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg + # Store the read messages in your own memory to prevent duplicate processing. old_messages = [] if ignore_memory else self._rc.memory.get() self._rc.memory.add_batch(news) # Filter out messages of interest. - self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages] + self._rc.news = self._find_news(news, old_messages) # Design Rules: # If you need to further categorize Message objects, you can do so using the Message.set_meta function. @@ -347,7 +466,7 @@ class Role: Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act + rsp = Message(content="No actions taken yet") # will be overwritten after Role _act while actions_taken < self._rc.max_react_loop: # think await self._think() @@ -361,7 +480,8 @@ class Role: async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. 
_act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): + start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state + for i in range(start_idx, len(self._states)): self._set_state(i) rsp = await self._act() return rsp # return output from the last action @@ -369,7 +489,7 @@ class Role: async def _plan_and_act(self) -> Message: """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" # TODO: to be implemented - return Message("") + return Message(content="") async def react(self) -> Message: """Entry to one of three strategies by which Role reacts to the observed Message""" @@ -386,16 +506,17 @@ class Role: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) + @role_raise_decorator async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" if with_message: msg = None if isinstance(with_message, str): - msg = Message(with_message) + msg = Message(content=with_message) elif isinstance(with_message, Message): msg = with_message elif isinstance(with_message, list): - msg = Message("\n".join(with_message)) + msg = Message(content="\n".join(with_message)) if not msg.cause_by: msg.cause_by = UserRequirement self.put_message(msg) @@ -430,6 +551,7 @@ class Role: @property def todo(self) -> str: + """AgentStore uses this attribute to display to the user what actions the current role should take.""" if self._actions: return any_to_name(self._actions[0]) return "" diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index d5aac1824..76abf10f3 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,30 +5,33 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import 
SearchEngineType class Sales(Role): - def __init__( - self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None, - ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + name: str = "Xiaomei" + profile: str = "Retail sales guide" + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide" + + store: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: - action = SearchAndSummarize("", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) + action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) else: action = SearchAndSummarize() self._init_actions([action]) diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 5760202ff..e4a672176 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -7,6 +7,9 @@ @Modified By: mashenquan, 2023-11-1. 
According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.actions.action_node import ActionNode from metagpt.logs import logger @@ -27,15 +30,13 @@ class Searcher(Role): engine (SearchEngineType): The type of search engine to use. """ - def __init__( - self, - name: str = "Alice", - profile: str = "Smart Assistant", - goal: str = "Provide search services for users", - constraints: str = "Answer is rich and complete", - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs, - ) -> None: + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -46,12 +47,12 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. 
""" - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" - action = SearchAndSummarize("", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) + action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) self._init_actions([action]) async def _act_sp(self) -> Message: @@ -59,7 +60,7 @@ class Searcher(Role): logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.memory.get(k=0)) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 2a514f433..e0be4de61 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -42,17 +42,7 @@ class TutorialAssistant(Role): self.main_title = "" self.total_content = "" self.language = language - - async def _think(self) -> None: - """Determine the next action to be taken by the role.""" - if self._rc.todo is None: - self._set_state(0) - return - - if self._rc.state + 1 < len(self._states): - self._set_state(self._rc.state + 1) - else: - self._rc.todo = None + self._set_react_mode(react_mode="by_order") async def _handle_directory(self, titles: Dict) -> Message: """Handle the directories for the tutorial document. @@ -75,8 +65,6 @@ class TutorialAssistant(Role): for second_dir in first_dir[key]: directory += f" - {second_dir}\n" self._init_actions(actions) - self._rc.todo = None - return Message(content=directory) async def _act(self) -> Message: """Perform an action as determined by the role. 
@@ -90,7 +78,8 @@ class TutorialAssistant(Role): self.topic = msg.content resp = await todo.run(topic=self.topic) logger.info(resp) - return await self._handle_directory(resp) + await self._handle_directory(resp) + return await super().react() resp = await todo.run(topic=self.topic) logger.info(resp) if self.total_content != "": @@ -98,17 +87,8 @@ class TutorialAssistant(Role): self.total_content += resp return Message(content=resp, role=self.profile) - async def _react(self) -> Message: - """Execute the assistant's think and actions. - - Returns: - A message containing the final result of the assistant's actions. - """ - while True: - await self._think() - if self._rc.todo is None: - break - msg = await self._act() + async def react(self) -> Message: + msg = await super().react() root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) return msg diff --git a/metagpt/schema.py b/metagpt/schema.py index b17565979..60b9a6998 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -12,16 +12,18 @@ between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. 
""" + from __future__ import annotations import asyncio import json import os.path import uuid +from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar +from typing import Any, Dict, List, Optional, Set, Type, TypedDict, TypeVar from pydantic import BaseModel, Field @@ -35,8 +37,13 @@ from metagpt.const import ( TASK_FILE_REPO, ) from metagpt.logs import logger -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, import_class from metagpt.utils.exceptions import handle_exception +from metagpt.utils.serialize import ( + actionoutout_schema_to_mapping, + actionoutput_mapping_to_str, + actionoutput_str_to_mapping, +) class RawMessage(TypedDict): @@ -97,41 +104,31 @@ class Message(BaseModel): id: str # According to Section 2.2.3.1.1 of RFC 135 content: str - instruct_content: BaseModel = Field(default=None) + instruct_content: BaseModel = None role: str = "user" # system / user / assistant cause_by: str = "" sent_from: str = "" send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) - def __init__( - self, - content, - instruct_content=None, - role="user", - cause_by="", - sent_from="", - send_to=MESSAGE_ROUTE_TO_ALL, - **kwargs, - ): - """ - Parameters not listed below will be stored as meta info, including custom parameters. - :param content: Message content. - :param instruct_content: Message content struct. - :param cause_by: Message producer - :param sent_from: Message route info tells who sent this message. - :param send_to: Specifies the target recipient or consumer for message delivery in the environment. - :param role: Message meta info tells who sent this message. 
- """ - super().__init__( - id=uuid.uuid4().hex, - content=content, - instruct_content=instruct_content, - role=role, - cause_by=any_to_str(cause_by), - sent_from=any_to_str(sent_from), - send_to=any_to_str_set(send_to), - **kwargs, + def __init__(self, content: str = "", **kwargs): + ic = kwargs.get("instruct_content", None) + if ic and not isinstance(ic, BaseModel) and "class" in ic: + # compatible with custom-defined ActionOutput + mapping = actionoutput_str_to_mapping(ic["mapping"]) + + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + + kwargs["id"] = kwargs.get("id", uuid.uuid4().hex) + kwargs["content"] = kwargs.get("content", content) + kwargs["cause_by"] = any_to_str( + kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement")) ) + kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) + kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) + super(Message, self).__init__(**kwargs) def __setattr__(self, key, val): """Override `@property.setter`, convert non-string parameters into string parameters.""" @@ -145,8 +142,26 @@ class Message(BaseModel): new_val = val super().__setattr__(key, new_val) + def dict(self, *args, **kwargs) -> "DictStrAny": + """overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(*args, **kwargs) + ic = self.instruct_content + if ic: + # compatible with custom-defined ActionOutput + schema = ic.schema() + # `Documents` contain definitions + if "definitions" not in schema: + # TODO refine with nested BaseModel + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + return 
obj_dict + def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) + if self.instruct_content: + return f"{self.role}: {self.instruct_content.dict()}" return f"{self.role}: {self.content}" def __repr__(self): @@ -164,6 +179,7 @@ class Message(BaseModel): @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" + try: m = json.loads(val) id = m.get("id") @@ -205,11 +221,22 @@ class AIMessage(Message): super().__init__(content=content, role="assistant") -class MessageQueue: +class MessageQueue(BaseModel): """Message queue which supports asynchronous updates.""" - def __init__(self): - self._queue = Queue() + _queue: Queue = Field(default_factory=Queue) + + _private_attributes = {"_queue": Queue()} + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **kwargs: Any): + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + else: + object.__setattr__(self, key, Queue()) def pop(self) -> Message | None: """Pop one message from the queue.""" @@ -257,16 +284,16 @@ class MessageQueue: return json.dumps(lst) @staticmethod - def load(i) -> "MessageQueue": + def load(data) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" queue = MessageQueue() try: - lst = json.loads(i) + lst = json.loads(data) for i in lst: msg = Message(**i) queue.push(msg) except JSONDecodeError as e: - logger.warning(f"JSON load failed: {i}, error:{e}") + logger.warning(f"JSON load failed: {data}, error:{e}") return queue @@ -275,7 +302,7 @@ class MessageQueue: T = TypeVar("T", bound="BaseModel") -class BaseContext(BaseModel): +class BaseContext(BaseModel, ABC): @classmethod @handle_exception def loads(cls: Type[T], val: str) -> Optional[T]: diff --git a/metagpt/startup.py b/metagpt/startup.py index e886ad2a4..767a19a9d 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -7,7 +7,7 @@ import typer from 
metagpt.config import CONFIG -app = typer.Typer() +app = typer.Typer(add_completion=False) @app.command() @@ -22,13 +22,17 @@ def startup( inc: bool = typer.Option(default=False, help="Incremental mode. Use it to coop with existing repo."), project_path: str = typer.Option( default="", - help="Specify the directory path of the old version project to fulfill the " "incremental requirements.", + help="Specify the directory path of the old version project to fulfill the incremental requirements.", + ), + reqa_file: str = typer.Option( + default="", help="Specify the source file name for rewriting the quality assurance code." ), - reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), max_auto_summarize_code: int = typer.Option( default=0, - help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. This parameter is used for debugging the workflow.", + help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " + "unlimited. This parameter is used for debugging the workflow.", ), + recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage"), ): """Run a startup. Be a boss.""" from metagpt.roles import ( @@ -40,30 +44,31 @@ def startup( ) from metagpt.team import Team - # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. 
- CONFIG.project_path = project_path - if project_path: - inc = True - project_name = project_name or Path(project_path).name - CONFIG.project_name = project_name - CONFIG.inc = inc - CONFIG.reqa_file = reqa_file - CONFIG.max_auto_summarize_code = max_auto_summarize_code + CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) - company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) + if not recover_path: + company = Team() + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) - if implement or code_review: - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + if implement or code_review: + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - if run_tests: - company.hire([QaEngineer()]) + if run_tests: + company.hire([QaEngineer()]) + else: + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.deserialize(stg_path=stg_path) + idea = company.idea # use original idea company.invest(investment) company.run_project(idea) diff --git a/metagpt/subscription.py b/metagpt/subscription.py index 0d2b30821..607cbdb8d 100644 --- a/metagpt/subscription.py +++ b/metagpt/subscription.py @@ -19,7 +19,7 @@ class SubscriptionRunner(BaseModel): >>> async def trigger(): ... while True: - ... yield Message("the latest news about OpenAI") + ... yield Message(content="the latest news about OpenAI") ... await asyncio.sleep(3600 * 24) >>> async def callback(msg: Message): diff --git a/metagpt/team.py b/metagpt/team.py index d613a04ef..ed370fd16 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -7,24 +7,31 @@ @Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in Section 2.2.3.3 of RFC 135. 
""" + import warnings +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import UserRequirement from metagpt.config import CONFIG -from metagpt.const import MESSAGE_ROUTE_TO_ALL +from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import NoMoneyException +from metagpt.utils.common import ( + NoMoneyException, + read_json_file, + serialize_decorator, + write_json_file, +) class Team(BaseModel): """ - Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a platform for instant messaging, - dedicated to perform any multi-agent activity, such as collaboratively writing executable code. + Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a env for instant messaging, + dedicated to env any multi-agent activity, such as collaboratively writing executable code. """ env: Environment = Field(default_factory=Environment) @@ -34,6 +41,38 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + + team_info_path = stg_path.joinpath("team_info.json") + write_json_file(team_info_path, self.dict(exclude={"env": True})) + + self.env.serialize(stg_path.joinpath("environment")) # save environment alone + + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": + """stg_path = ./storage/team""" + # recover team_info + team_info_path = stg_path.joinpath("team_info.json") + if not team_info_path.exists(): + raise FileNotFoundError( + "recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project." 
+ ) + + team_info: dict = read_json_file(team_info_path) + + # recover environment + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"env": environment}) + + team = Team(**team_info) + return team + def hire(self, roles: list[Role]): """Hire roles to cooperate""" self.env.add_roles(roles) @@ -77,6 +116,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json(ensure_ascii=False)) + @serialize_decorator async def run(self, n_round=3, auto_archive=True): """Run company until target round or no money""" while n_round > 0: @@ -84,7 +124,7 @@ class Team(BaseModel): n_round -= 1 logger.debug(f"max {n_round=} left.") self._check_balance() + await self.env.run() - if auto_archive and CONFIG.git_repo: - CONFIG.git_repo.archive() + self.env.archive(auto_archive) return self.env.history diff --git a/metagpt/tools/azure_tts.py b/metagpt/tools/azure_tts.py index 6864faf10..8fdb10c13 100644 --- a/metagpt/tools/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -1,10 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -@Time : 2023/8/17 -@Author : mashenquan +@Time : 2023/6/9 22:22 +@Author : Leo Xiao @File : azure_tts.py -@Desc : azure TTS OAS3 api, which provides text-to-speech functionality +@Modified by: mashenquan, 2023/8/17. 
Azure TTS OAS3 api, which provides text-to-speech functionality """ import asyncio import base64 diff --git a/metagpt/tools/code_interpreter.py b/metagpt/tools/code_interpreter.py index 1cba005fa..9575d6c13 100644 --- a/metagpt/tools/code_interpreter.py +++ b/metagpt/tools/code_interpreter.py @@ -46,7 +46,6 @@ class OpenCodeInterpreter(object): interpreter.auto_run = auto_run interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo" interpreter.api_key = CONFIG.openai_api_key - # interpreter.api_base = CONFIG.openai_api_base self.interpreter = interpreter def chat(self, query: str, reset: bool = True): diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py index c56a6afc4..5532e4f66 100644 --- a/metagpt/tools/moderation.py +++ b/metagpt/tools/moderation.py @@ -5,6 +5,7 @@ @Author : zhanglei @File : moderation.py """ +import asyncio from typing import Union from metagpt.llm import LLM @@ -14,16 +15,6 @@ class Moderation: def __init__(self): self.llm = LLM() - def moderation(self, content: Union[str, list[str]]): - resp = [] - if content: - moderation_results = self.llm.moderation(content=content) - results = moderation_results.results - for item in results: - resp.append(item.flagged) - - return resp - async def amoderation(self, content: Union[str, list[str]]): resp = [] if content: @@ -35,6 +26,13 @@ class Moderation: return resp -if __name__ == "__main__": +async def main(): moderation = Moderation() - print(moderation.moderation(content=["I will kill you", "The weather is really nice today", "I want to hit you"])) + rsp = await moderation.amoderation( + content=["I will kill you", "The weather is really nice today", "I want to hit you"] + ) + print(rsp) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 51d26e551..628c8dea2 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ 
b/metagpt/tools/web_browser_engine_selenium.py @@ -111,6 +111,8 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): options.add_argument("--headless") options.add_argument("--enable-javascript") if browser_type == "chrome": + options.add_argument("--disable-gpu") # This flag can help avoid renderer issue + options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--no-sandbox") for i in args: options.add_argument(i) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a5d2100cc..a1cb71c6f 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -13,15 +13,21 @@ from __future__ import annotations import ast import contextlib +import importlib import inspect +import json import os import platform import re +import sys +import traceback +import typing from pathlib import Path -from typing import Callable, List, Tuple, Union +from typing import Any, Callable, List, Tuple, Union, get_args, get_origin import aiofiles import loguru +from pydantic.json import pydantic_encoder from tenacity import RetryCallState, _utils from metagpt.config import CONFIG @@ -43,6 +49,12 @@ def check_cmd_exists(command) -> int: return result +def require_python_version(req_version: tuple[int]) -> bool: + if not (2 <= len(req_version) <= 3): + raise ValueError("req_version should be (3, 9) or (3, 10, 13)") + return True if sys.version_info > req_version else False + + class OutputParser: @classmethod def parse_blocks(cls, text: str): @@ -130,8 +142,32 @@ class OutputParser: parsed_data[block] = content return parsed_data + @staticmethod + def extract_content(text, tag="CONTENT"): + # Use regular expression to extract content between [CONTENT] and [/CONTENT] + extracted_content = re.search(rf"\[{tag}\](.*?)\[/{tag}\]", text, re.DOTALL) + + if extracted_content: + return extracted_content.group(1).strip() + else: + return "No content found between [CONTENT] and [/CONTENT] tags." 
+ + @staticmethod + def is_supported_list_type(i): + origin = get_origin(i) + if origin is not List: + return False + + args = get_args(i) + if args == (str,) or args == (Tuple[str, str],) or args == (List[str],): + return True + + return False + @classmethod def parse_data_with_mapping(cls, data, mapping): + if "[CONTENT]" in data: + data = cls.extract_content(text=data) block_dict = cls.parse_blocks(data) parsed_data = {} for block, content in block_dict.items(): @@ -198,7 +234,7 @@ class OutputParser: result = ast.literal_eval(structure_text) # Ensure the result matches the specified data type - if isinstance(result, list) or isinstance(result, dict): + if isinstance(result, (list, dict)): return result raise ValueError(f"The extracted structure is not a {data_type}.") @@ -437,6 +473,81 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C return log_it +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst + + +def format_trackback_info(limit: int = 2): + 
return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + result = await func(self, *args, **kwargs) + return result + except KeyboardInterrupt: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + except Exception: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self.latest_observed_msg: + self._rc.memory.delete(self.latest_observed_msg) + # raise again to make it captured outside + raise Exception(format_trackback_info(limit=None)) + except Exception: + if self.latest_observed_msg: + logger.warning( + "There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory." 
+ ) + # remove role newest observed msg to make it observed again + self._rc.memory.delete(self.latest_observed_msg) + # raise again to make it captured outside + raise Exception(format_trackback_info(limit=None)) + + return wrapper + + @handle_exception async def aread(file_path: str) -> str: """Read file asynchronously.""" diff --git a/metagpt/utils/exceptions.py b/metagpt/utils/exceptions.py index b4b5aa590..70ed45910 100644 --- a/metagpt/utils/exceptions.py +++ b/metagpt/utils/exceptions.py @@ -21,6 +21,7 @@ def handle_exception( _func: Callable[..., ReturnType] = None, *, exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + exception_msg: str = "", default_return: Any = None, ) -> Callable[..., ReturnType]: """handle exception, return default value""" @@ -32,8 +33,9 @@ def handle_exception( return await func(*args, **kwargs) except exception_type as e: logger.opt(depth=1).error( - f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, " - f"stack: {traceback.format_exc()}" + f"{e}: {exception_msg}, " + f"\nCalling {func.__name__} with args: {args}, kwargs: {kwargs} " + f"\nStack: {traceback.format_exc()}" ) return default_return diff --git a/metagpt/utils/get_template.py b/metagpt/utils/get_template.py index 86c1915f7..7e05e5d5e 100644 --- a/metagpt/utils/get_template.py +++ b/metagpt/utils/get_template.py @@ -8,10 +8,10 @@ from metagpt.config import CONFIG -def get_template(templates, format=CONFIG.prompt_format): - selected_templates = templates.get(format) +def get_template(templates, schema=CONFIG.prompt_schema): + selected_templates = templates.get(schema) if selected_templates is None: - raise ValueError(f"Can't find {format} in passed in templates") + raise ValueError(f"Can't find {schema} in passed in templates") # Extract the selected templates prompt_template = selected_templates["PROMPT_TEMPLATE"] diff --git a/metagpt/utils/make_sk_kernel.py b/metagpt/utils/make_sk_kernel.py index 
5e919abeb..83b4005ec 100644 --- a/metagpt/utils/make_sk_kernel.py +++ b/metagpt/utils/make_sk_kernel.py @@ -21,14 +21,12 @@ def make_sk_kernel(): if CONFIG.openai_api_type == "azure": kernel.add_chat_service( "chat_completion", - AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_api_base, CONFIG.openai_api_key), + AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_base_url, CONFIG.openai_api_key), ) else: kernel.add_chat_service( "chat_completion", - OpenAIChatCompletion( - CONFIG.openai_api_model, CONFIG.openai_api_key, org_id=None, endpoint=CONFIG.openai_api_base - ), + OpenAIChatCompletion(CONFIG.openai_api_model, CONFIG.openai_api_key), ) return kernel diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4aafd8e66..67ad4e963 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -253,7 +253,7 @@ def retry_parse_json_text(output: str) -> Union[list, dict]: if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times. it's a two-layer retry cycle """ - logger.debug(f"output to json decode:\n{output}") + # logger.debug(f"output to json decode:\n{output}") # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break parsed_data = CustomDecoder(strict=False).decode(output) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb..3939b1306 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,13 +4,11 @@ import copy import pickle -from typing import Dict, List -from metagpt.actions.action_output import ActionOutput -from metagpt.schema import Message +from metagpt.utils.common import import_class -def actionoutout_schema_to_mapping(schema: Dict) -> Dict: +def actionoutout_schema_to_mapping(schema: dict) -> dict: """ directly traverse the `properties` in the first level. 
schema structure likes @@ -35,14 +33,31 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: if property["type"] == "string": mapping[field] = (str, ...) elif property["type"] == "array" and property["items"]["type"] == "string": - mapping[field] = (List[str], ...) + mapping[field] = (list[str], ...) elif property["type"] == "array" and property["items"]["type"] == "array": - # here only consider the `List[List[str]]` situation - mapping[field] = (List[List[str]], ...) + # here only consider the `list[list[str]]` situation + mapping[field] = (list[list[str]], ...) return mapping -def serialize_message(message: Message): +def actionoutput_mapping_to_str(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + new_mapping[key] = str(value) + return new_mapping + + +def actionoutput_str_to_mapping(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + if value == "(, Ellipsis)": + new_mapping[key] = (str, ...) + else: + new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)` + return new_mapping + + +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -56,11 +71,12 @@ def serialize_message(message: Message): return msg_ser -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 
ebfb85de7..94b8d76d2 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -16,6 +17,8 @@ TOKEN_COSTS = { "gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002}, "gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004}, "gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004}, + "gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002}, + "gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004}, "gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002}, "gpt-4-0314": {"prompt": 0.03, "completion": 0.06}, "gpt-4": {"prompt": 0.03, "completion": 0.06}, @@ -25,6 +28,7 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, } @@ -34,6 +38,8 @@ TOKEN_MAX = { "gpt-3.5-turbo-0613": 4096, "gpt-3.5-turbo-16k": 16384, "gpt-3.5-turbo-16k-0613": 16384, + "gpt-35-turbo": 4096, + "gpt-35-turbo-16k": 16384, "gpt-3.5-turbo-1106": 16384, "gpt-4-0314": 8192, "gpt-4": 8192, @@ -43,6 +49,7 @@ TOKEN_MAX = { "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, + "gemini-pro": 32768, } @@ -56,6 +63,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): if model in { "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", + "gpt-35-turbo", + "gpt-35-turbo-16k", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-1106", "gpt-4-0314", diff --git a/requirements.txt b/requirements.txt index 
c4e674569..d221dc3c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ channels==4.0.0 # docx==0.2.4 #faiss==1.5.3 faiss_cpu==1.7.4 -# fire==0.4.0 +fire==0.4.0 typer # godot==0.1.1 # google_api_python_client==2.93.0 @@ -15,7 +15,7 @@ langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai>=1.3.6 +openai==1.6.0 openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 @@ -28,7 +28,7 @@ PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 tenacity==8.2.2 -tiktoken==0.4.0 +tiktoken==0.5.2 tqdm==4.64.0 #unstructured[local-inference] # playwright @@ -36,12 +36,12 @@ tqdm==4.64.0 # webdriver_manager<3.9 anthropic==0.3.6 typing-inspect==0.8.0 -typing_extensions==4.5.0 aiofiles +typing_extensions==4.7.0 libcst==1.0.1 qdrant-client==1.4.0 pytest-mock==3.11.1 -open-interpreter==0.1.7; python_version>"3.9" +# open-interpreter==0.1.7; python_version>"3.9" ta==0.10.2 semantic-kernel==0.4.0.dev0 wrapt==1.15.0 @@ -55,7 +55,8 @@ gitpython==3.1.40 zhipuai==1.0.7 socksio~=1.0.0 gitignore-parser==0.1.9 -connexion[swagger-ui] +# connexion[swagger-ui] websockets~=12.0 networkx~=3.2.1 -pylint~=3.0.3 \ No newline at end of file +pylint~=3.0.3 +google-generativeai==0.3.1 diff --git a/ruff.toml b/ruff.toml index 7835865e0..21de5ee14 100644 --- a/ruff.toml +++ b/ruff.toml @@ -31,7 +31,7 @@ exclude = [ ] # Same as Black. -line-length = 119 +line-length = 120 # Allow unused variables when underscore-prefixed. 
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" diff --git a/setup.py b/setup.py index 8285a4b67..8ef2a6946 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ with open(path.join(here, "requirements.txt"), encoding="utf-8") as f: setup( name="metagpt", version="0.5.2", - description="The Multi-Role Meta Programming Framework", + description="The Multi-Agent Framework", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/geekan/MetaGPT", diff --git a/tests/metagpt/actions/test_action_output.py b/tests/metagpt/actions/test_action_output.py index ef8e239bd..f1765cb03 100644 --- a/tests/metagpt/actions/test_action_output.py +++ b/tests/metagpt/actions/test_action_output.py @@ -7,7 +7,7 @@ """ from typing import List, Tuple -from metagpt.actions import ActionOutput +from metagpt.actions.action_node import ActionNode t_dict = { "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', @@ -37,12 +37,12 @@ WRITE_TASKS_OUTPUT_MAPPING = { def test_create_model_class(): - test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) + test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) assert test_class.__name__ == "test_class" def test_create_model_class_with_mapping(): - t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) + t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) t1 = t(**t_dict) value = t1.dict()["Task list"] assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] diff --git a/tests/metagpt/actions/test_azure_tts.py b/tests/metagpt/actions/test_azure_tts.py new file mode 100644 index 000000000..9995e9691 --- /dev/null +++ b/tests/metagpt/actions/test_azure_tts.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/1 22:50 +@Author : alexanderwu +@File : test_azure_tts.py 
+""" +from metagpt.tools.azure_tts import AzureTTS + + +def test_azure_tts(): + azure_tts = AzureTTS() + azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav") + + # 运行需要先配置 SUBSCRIPTION_KEY + # TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有 diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index 891dca6ca..a178ec840 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -3,21 +3,27 @@ """ @Time : 2023/9/13 00:26 @Author : fisherdeng -@File : test_detail_mining.py +@File : test_generate_questions.py """ import pytest -from metagpt.actions.detail_mining import DetailMining +from metagpt.actions.generate_questions import GenerateQuestions from metagpt.logs import logger +context = """ +## topic +如何做一个生日蛋糕 + +## record +我认为应该先准备好材料,然后再开始做蛋糕。 +""" + @pytest.mark.asyncio -async def test_detail_mining(): - topic = "如何做一个生日蛋糕" - record = "我认为应该先准备好材料,然后再开始做蛋糕。" - detail_mining = DetailMining("detail_mining") - rsp = await detail_mining.run(topic=topic, record=record) +async def test_generate_questions(): + detail_mining = GenerateQuestions() + rsp = await detail_mining.run(context) logger.info(f"{rsp.content=}") - assert "##OUTPUT" in rsp.content - assert "蛋糕" in rsp.content + assert "Questions" in rsp.content + assert "1." 
in rsp.content diff --git a/tests/metagpt/actions/test_prepare_interview.py b/tests/metagpt/actions/test_prepare_interview.py new file mode 100644 index 000000000..7c32882e0 --- /dev/null +++ b/tests/metagpt/actions/test_prepare_interview.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/13 00:26 +@Author : fisherdeng +@File : test_detail_mining.py +""" +import pytest + +from metagpt.actions.prepare_interview import PrepareInterview +from metagpt.logs import logger + + +@pytest.mark.asyncio +async def test_prepare_interview(): + action = PrepareInterview() + rsp = await action.run("I just graduated and hope to find a job as a Python engineer") + logger.info(f"{rsp.content=}") + + assert "Questions" in rsp.content + assert "1." in rsp.content diff --git a/tests/metagpt/actions/test_write_review.py b/tests/metagpt/actions/test_write_review.py new file mode 100644 index 000000000..2d188b720 --- /dev/null +++ b/tests/metagpt/actions/test_write_review.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/20 15:01 +@Author : alexanderwu +@File : test_write_review.py +""" +import pytest + +from metagpt.actions.write_review import WriteReview + +CONTEXT = """ +{ + "Language": "zh_cn", + "Programming Language": "Python", + "Original Requirements": "写一个简单的2048", + "Project Name": "game_2048", + "Product Goals": [ + "创建一个引人入胜的用户体验", + "确保高性能", + "提供可定制的功能" + ], + "User Stories": [ + "作为用户,我希望能够选择不同的难度级别", + "作为玩家,我希望在每局游戏结束后能看到我的得分" + ], + "Competitive Analysis": [ + "Python Snake Game: 界面简单,缺乏高级功能" + ], + "Competitive Quadrant Chart": "quadrantChart\n title \"Reach and engagement of campaigns\"\n x-axis \"Low Reach\" --> \"High Reach\"\n y-axis \"Low Engagement\" --> \"High Engagement\"\n quadrant-1 \"我们应该扩展\"\n quadrant-2 \"需要推广\"\n quadrant-3 \"重新评估\"\n quadrant-4 \"可能需要改进\"\n \"Campaign A\": [0.3, 0.6]\n \"Campaign B\": [0.45, 0.23]\n \"Campaign C\": [0.57, 0.69]\n \"Campaign D\": [0.78, 0.34]\n 
\"Campaign E\": [0.40, 0.34]\n \"Campaign F\": [0.35, 0.78]\n \"Our Target Product\": [0.5, 0.6]", + "Requirement Analysis": "产品应该用户友好。", + "Requirement Pool": [ + [ + "P0", + "主要代码..." + ], + [ + "P0", + "游戏算法..." + ] + ], + "UI Design draft": "基本功能描述,简单的风格和布局。", + "Anything UNCLEAR": "..." +} +""" + + +@pytest.mark.asyncio +async def test_write_review(): + write_review = WriteReview() + review = await write_review.run(CONTEXT) + assert review.instruct_content + assert review.get("LGTM") in ["LGTM", "LBTM"] diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index c67ca689f..7b74eb512 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -8,7 +8,7 @@ from typing import List from metagpt.actions import UserRequirement, WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message @@ -42,7 +42,7 @@ def test_idea_message(): def test_actionout_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" content = "The user has requested the creation of a command-line interface (CLI) snake game" diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..229d9b9a7 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +from abc import ABC +from dataclasses import dataclass + +import pytest + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + 
+messages = [{"role": "user", "content": "who are you"}] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 2b0af37b5..332d554cf 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock + import pytest from metagpt.provider.openai_api import OpenAIGPTAPI @@ -78,3 +80,70 @@ def test_ask_code_list_str(): assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 + + +class TestOpenAI: + @pytest.fixture + def config(self): + return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other") + + @pytest.fixture + def config_azure(self): + return Mock( + openai_api_key="test_key", + openai_api_version="test_version", + openai_base_url="test_url", + openai_proxy=None, + openai_api_type="azure", + ) + + @pytest.fixture + def config_proxy(self): + return Mock( + openai_api_key="test_key", + openai_base_url="test_url", + openai_proxy="http://proxy.com", + openai_api_type="other", + ) + + @pytest.fixture + def config_azure_proxy(self): + return Mock( + openai_api_key="test_key", + 
openai_api_version="test_version", + openai_base_url="test_url", + openai_proxy="http://proxy.com", + openai_api_type="azure", + ) + + def test_make_client_kwargs_without_proxy(self, config): + instance = OpenAIGPTAPI() + instance.config = config + kwargs, async_kwargs = instance._make_client_kwargs() + assert kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert "http_client" not in kwargs + assert "http_client" not in async_kwargs + + def test_make_client_kwargs_without_proxy_azure(self, config_azure): + instance = OpenAIGPTAPI() + instance.config = config_azure + kwargs, async_kwargs = instance._make_client_kwargs() + assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert "http_client" not in kwargs + assert "http_client" not in async_kwargs + + def test_make_client_kwargs_with_proxy(self, config_proxy): + instance = OpenAIGPTAPI() + instance.config = config_proxy + kwargs, async_kwargs = instance._make_client_kwargs() + assert "http_client" in kwargs + assert "http_client" in async_kwargs + + def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): + instance = OpenAIGPTAPI() + instance.config = config_azure_proxy + kwargs, async_kwargs = instance._make_client_kwargs() + assert "http_client" in kwargs + assert "http_client" in async_kwargs diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py new file mode 100644 index 000000000..72cd84a9a --- /dev/null +++ b/tests/metagpt/roles/test_role.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of Role + +from metagpt.roles.role import Role + + +def test_role_desc(): + role = Role(profile="Sales", desc="Best Seller") + assert role.profile == "Sales" + assert role._setting.desc == "Best 
Seller" diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py index 5c2d64a58..51b346821 100644 --- a/tests/metagpt/roles/ui_role.py +++ b/tests/metagpt/roles/ui_role.py @@ -8,6 +8,7 @@ from functools import wraps from importlib import import_module from metagpt.actions import Action, ActionOutput, WritePRD +from metagpt.actions.action_node import ActionNode from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.roles import Role @@ -15,44 +16,38 @@ from metagpt.schema import Message from metagpt.tools.sd_engine import SDEngine PROMPT_TEMPLATE = """ -# Context {context} -## Format example -{format_example} ------ -Role: You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. -Requirements: Based on the context, fill in the following missing information, provide detailed HTML and CSS code -Attention: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the code and triple quote. - -## UI Design Description:Provide as Plain text, place the design objective here -## Selected Elements:Provide as Plain text, up to 5 specified elements, clear and simple -## HTML Layout:Provide as Plain text, use standard HTML code -## CSS Styles (styles.css):Provide as Plain text,use standard css code -## Anything UNCLEAR:Provide as Plain text. Try to clarify it. - +## Role +You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. """ -FORMAT_EXAMPLE = """ +UI_DESIGN_DESC = ActionNode( + key="UI Design Desc", + expected_type=str, + instruction="place the design objective here", + example="Snake games are classic and addictive games with simple yet engaging elements. 
Here are the main elements" + " commonly found in snake games", +) -## UI Design Description -```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ``` +SELECTED_ELEMENTS = ActionNode( + key="Selected Elements", + expected_type=list[str], + instruction="up to 5 specified elements, clear and simple", + example=[ + "Game Grid: The game grid is a rectangular...", + "Snake: The player controls a snake that moves across the grid...", + "Food: Food items (often represented as small objects or differently colored blocks)", + "Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.", + "Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.", + ], +) -## Selected Elements - -Game Grid: The game grid is a rectangular... - -Snake: The player controls a snake that moves across the grid... - -Food: Food items (often represented as small objects or differently colored blocks) - -Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score. - -Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game. - - -## HTML Layout - +HTML_LAYOUT = ActionNode( + key="HTML Layout", + expected_type=str, + instruction="use standard HTML code", + example=""" @@ -69,9 +64,14 @@ Game Over: The game ends when the snake collides with itself or an obstacle. 
At +""", +) -## CSS Styles (styles.css) -body { +CSS_STYLES = ActionNode( + key="CSS Styles", + expected_type=str, + instruction="use standard css code", + example="""body { display: flex; justify-content: center; align-items: center; @@ -119,19 +119,25 @@ body { color: #ff0000; display: none; } +""", +) -## Anything UNCLEAR -There are no unclear points. +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any aspects of the project that are unclear and try to clarify them.", + example="...", +) -""" +NODES = [ + UI_DESIGN_DESC, + SELECTED_ELEMENTS, + HTML_LAYOUT, + CSS_STYLES, + ANYTHING_UNCLEAR, +] -OUTPUT_MAPPING = { - "UI Design Description": (str, ...), - "Selected Elements": (str, ...), - "HTML Layout": (str, ...), - "CSS Styles (styles.css)": (str, ...), - "Anything UNCLEAR": (str, ...), -} +UI_DESIGN_NODE = ActionNode.from_children("UI_DESIGN", NODES) def load_engine(func): @@ -221,10 +227,8 @@ class UIDesign(Action): css_file_path = save_dir / "ui_design.css" html_file_path = save_dir / "ui_design.html" - with open(css_file_path, "w") as css_file: - css_file.write(css_content) - with open(html_file_path, "w") as html_file: - html_file.write(html_content) + css_file_path.write_text(css_content) + html_file_path.write_text(html_content) async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: """Run the UI Design action.""" @@ -232,9 +236,9 @@ class UIDesign(Action): context = requirements[-1].content ui_design_draft = self.parse_requirement(context=context) # todo: parse requirements str - prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE) + prompt = PROMPT_TEMPLATE.format(context=ui_design_draft) logger.info(prompt) - ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING) + ui_describe = await UI_DESIGN_NODE.fill(prompt) logger.info(ui_describe.content) logger.info(ui_describe.instruct_content) css = 
self.parse_css_code(context=ui_describe.content) diff --git a/tests/metagpt/serialize_deserialize/__init__.py b/tests/metagpt/serialize_deserialize/__init__.py new file mode 100644 index 000000000..78f454fb5 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py new file mode 100644 index 000000000..14d558c13 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action +from metagpt.llm import LLM + + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" not in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + + new_action = Action(**serialized_data) + + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py new file mode 100644 index 000000000..b92eba8a1 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:04 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions.action import Action +from metagpt.roles.architect import Architect + + +def test_architect_serialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict 
+ assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_architect_deserialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + new_role = Architect(**ser_role_dict) + # new_role = Architect.deserialize(ser_role_dict) + assert new_role.name == "Bob" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py new file mode 100644 index 000000000..096c1dd68 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import shutil + +from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.actions.project_management import WriteTasks +from metagpt.environment import Environment +from metagpt.roles.project_manager import ProjectManager +from metagpt.schema import Message +from metagpt.utils.common import any_to_str +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + ActionOK, + RoleC, + serdeser_path, +) + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert len(new_env.history) == 25 + + +def test_environment_serdeser(): + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionNode.create_model_class("prd", out_mapping) + + message = Message( + content="prd", instruct_content=ic_obj(**out_data), role="product manager", cause_by=any_to_str(UserRequirement) + ) + + 
environment = Environment() + role_c = RoleC() + environment.add_role(role_c) + environment.publish_message(message) + + ser_data = environment.dict() + assert ser_data["roles"]["Role C"]["name"] == "RoleC" + + new_env: Environment = Environment(**ser_data) + assert len(new_env.roles) == 1 + + assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states + assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions + assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK) + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK + + +def test_environment_serdeser_v2(): + environment = Environment() + pm = ProjectManager() + environment.add_role(pm) + + ser_data = environment.dict() + + new_env: Environment = Environment(**ser_data) + role = new_env.get_role(pm.profile) + assert isinstance(role, ProjectManager) + assert isinstance(role._actions[0], WriteTasks) + assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks) + + +def test_environment_serdeser_save(): + environment = Environment() + role_c = RoleC() + + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + stg_path = serdeser_path.joinpath("team", "environment") + environment.add_role(role_c) + environment.serialize(stg_path) + + new_env: Environment = Environment.deserialize(stg_path) + assert len(new_env.roles) == 1 + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py new file mode 100644 index 000000000..5a40f5c3b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of memory + +from pydantic import BaseModel + +from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from 
metagpt.actions.design_api import WriteDesign +from metagpt.memory.memory import Memory +from metagpt.schema import Message +from metagpt.utils.common import any_to_str +from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path + + +def test_memory_serdeser(): + msg1 = Message(role="Boss", content="write a snake game", cause_by=UserRequirement) + + out_mapping = {"field2": (list[str], ...)} + out_data = {"field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionNode.create_model_class("system_design", out_mapping) + msg2 = Message( + role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign + ) + + memory = Memory() + memory.add_batch([msg1, msg2]) + ser_data = memory.dict() + + new_memory = Memory(**ser_data) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(2)[0] + assert isinstance(new_msg2, BaseModel) + assert isinstance(new_memory.storage[-1], BaseModel) + assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign) + assert new_msg2.role == "Boss" + + +def test_memory_serdeser_save(): + msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement) + + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionNode.create_model_class("system_design", out_mapping) + msg2 = Message( + role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign + ) + + memory = Memory() + memory.add_batch([msg1, msg2]) + + stg_path = serdeser_path.joinpath("team", "environment") + memory.serialize(stg_path) + assert stg_path.joinpath("memory.json").exists() + + new_memory = Memory.deserialize(stg_path) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(1)[0] + assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] + assert new_msg2.cause_by == any_to_str(WriteDesign) + assert len(new_memory.index) == 2 + + 
stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py new file mode 100644 index 000000000..b65e329d1 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:07 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions.action import Action +from metagpt.roles.product_manager import ProductManager +from metagpt.schema import Message + + +@pytest.mark.asyncio +async def test_product_manager_deserialize(): + role = ProductManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProductManager(**ser_role_dict) + + assert new_role.name == "Alice" + assert len(new_role._actions) == 2 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run([Message(content="write a cli snake game")]) diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py new file mode 100644 index 000000000..e52e3f247 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:06 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions.action import Action +from metagpt.actions.project_management import WriteTasks +from metagpt.roles.project_manager import ProjectManager + + +def test_project_manager_serialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_project_manager_deserialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + + new_role = ProjectManager(**ser_role_dict) + assert new_role.name == "Eve" + assert 
len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + assert isinstance(new_role._actions[0], WriteTasks) + # await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py new file mode 100644 index 000000000..72da8a6fc --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 4:49 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import shutil + +import pytest + +from metagpt.actions import WriteCode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.const import SERDESER_PATH +from metagpt.logs import logger +from metagpt.roles.engineer import Engineer +from metagpt.roles.product_manager import ProductManager +from metagpt.roles.role import Role +from metagpt.schema import Message +from metagpt.utils.common import format_trackback_info +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + RoleA, + RoleB, + RoleC, + serdeser_path, +) + + +def test_roles(): + role_a = RoleA() + assert len(role_a._rc.watch) == 1 + role_b = RoleB() + assert len(role_a._rc.watch) == 1 + assert len(role_b._rc.watch) == 1 + + +def test_role_serialize(): + role = Role() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +def test_engineer_serialize(): + role = Engineer() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_engineer_deserialize(): + role = Engineer(use_code_review=True) + ser_role_dict = role.dict(by_alias=True) + + new_role = Engineer(**ser_role_dict) + assert new_role.name == "Alex" + assert new_role.use_code_review is True + assert len(new_role._actions) == 1 + 
assert isinstance(new_role._actions[0], WriteCode) + # await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + + +def test_role_serdeser_save(): + stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles") + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +@pytest.mark.asyncio +async def test_role_serdeser_interrupt(): + role_c = RoleC() + shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) + + stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}") + try: + await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) + except Exception: + logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") + role_c.serialize(stg_path) + + assert role_c._rc.memory.count() == 1 + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + + with pytest.raises(Exception): + await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py new file mode 100644 index 000000000..0358265a9 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of schema ser&deser + +from metagpt.actions.action_node import ActionNode +from metagpt.actions.write_code import WriteCode +from metagpt.schema import Message +from metagpt.utils.common import any_to_str +from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage + + +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), 
"field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + + message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode) + ser_data = message.dict() + assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode" + assert ser_data["instruct_content"]["class"] == "code" + + new_message = Message(**ser_data) + assert new_message.cause_by == any_to_str(WriteCode) + assert new_message.cause_by in [any_to_str(WriteCode)] + assert new_message.instruct_content == ic_obj(**out_data) + + +def test_message_without_postprocess(): + """to explain `instruct_content` should be postprocessed""" + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + message = MockMessage(content="code", instruct_content=ic_obj(**out_data)) + ser_data = message.dict() + assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]} + + new_message = MockMessage(**ser_data) + assert new_message.instruct_content != ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py new file mode 100644 index 000000000..a66813489 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base test actions / roles used in unittest + +import asyncio +from pathlib import Path + +from pydantic import BaseModel, Field + +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.roles.role import Role, RoleReactMode + +serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", 
"serdeser_storage") + + +class MockMessage(BaseModel): + """to test normal dict without postprocess""" + + content: str = "" + instruct_content: BaseModel = Field(default=None) + + +class ActionPass(Action): + name: str = Field(default="ActionPass") + + async def run(self, messages: list["Message"]) -> ActionOutput: + await asyncio.sleep(5) # sleep to make other roles can watch the executed Message + output_mapping = {"result": (str, ...)} + pass_class = ActionNode.create_model_class("pass", output_mapping) + pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) + + return pass_output + + +class ActionOK(Action): + name: str = Field(default="ActionOK") + + async def run(self, messages: list["Message"]) -> str: + await asyncio.sleep(5) + return "ok" + + +class ActionRaise(Action): + name: str = Field(default="ActionRaise") + + async def run(self, messages: list["Message"]) -> str: + raise RuntimeError("parse error in ActionRaise") + + +class RoleA(Role): + name: str = Field(default="RoleA") + profile: str = Field(default="Role A") + goal: str = "RoleA's goal" + constraints: str = "RoleA's constraints" + + def __init__(self, **kwargs): + super(RoleA, self).__init__(**kwargs) + self._init_actions([ActionPass]) + self._watch([UserRequirement]) + + +class RoleB(Role): + name: str = Field(default="RoleB") + profile: str = Field(default="Role B") + goal: str = "RoleB's goal" + constraints: str = "RoleB's constraints" + + def __init__(self, **kwargs): + super(RoleB, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([ActionPass]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + +class RoleC(Role): + name: str = Field(default="RoleC") + profile: str = Field(default="Role C") + goal: str = "RoleC's goal" + constraints: str = "RoleC's constraints" + + def __init__(self, **kwargs): + super(RoleC, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([UserRequirement]) + 
self._rc.react_mode = RoleReactMode.BY_ORDER diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py new file mode 100644 index 000000000..dc41fa4ed --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# @Date : 11/27/2023 10:07 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import shutil + +import pytest + +from metagpt.const import SERDESER_PATH +from metagpt.logs import logger +from metagpt.roles import Architect, ProductManager, ProjectManager +from metagpt.team import Team +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + ActionOK, + RoleA, + RoleB, + RoleC, + serdeser_path, +) + + +def test_team_deserialize(): + company = Team() + + pm = ProductManager() + arch = Architect() + company.hire( + [ + pm, + arch, + ProjectManager(), + ] + ) + assert len(company.env.get_roles()) == 3 + ser_company = company.dict() + new_company = Team(**ser_company) + + assert len(new_company.env.get_roles()) == 3 + assert new_company.env.get_role(pm.profile) is not None + + new_pm = new_company.env.get_role(pm.profile) + assert type(new_pm) == ProductManager + assert new_company.env.get_role(pm.profile) is not None + assert new_company.env.get_role(arch.profile) is not None + + +def test_team_serdeser_save(): + company = Team() + company.hire([RoleC()]) + + stg_path = serdeser_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team.deserialize(stg_path) + + assert len(new_company.env.roles) == 1 + + +@pytest.mark.asyncio +async def test_team_recover(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + role_c = RoleC() + company.hire([role_c]) + company.run_project(idea) + await company.run(n_round=4) + + ser_data = company.dict() + new_company = Team(**ser_data) 
+ + new_role_c = new_company.env.get_role(role_c.profile) + # assert new_role_c._rc.memory == role_c._rc.memory # TODO + assert new_role_c._rc.env != role_c._rc.env # TODO + assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK + + new_company.run_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_save(): + idea = "write a 2048 web game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + role_c = RoleC() + company.hire([role_c]) + company.run_project(idea) + await company.run(n_round=4) + + new_company = Team.deserialize(stg_path) + new_role_c = new_company.env.get_role(role_c.profile) + # assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env + assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` + assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo` + assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news` + + new_company.run_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_multi_roles_save(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + role_a = RoleA() + role_b = RoleB() + + assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", "RoleA"} + assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", "RoleB"} + assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"} + + company = Team() + company.hire([role_a, role_b]) + company.run_project(idea) + await company.run(n_round=4) + + logger.info("Team recovered") + + new_company = Team.deserialize(stg_path) + new_company.run_project(idea) + + assert new_company.env.get_role(role_b.profile)._rc.state == 1 + + await 
new_company.run(n_round=4) diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py new file mode 100644 index 000000000..65b8f456a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 10:56 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import pytest + +from metagpt.actions import WriteCode +from metagpt.llm import LLM +from metagpt.schema import CodingContext, Document + + +def test_write_design_serialize(): + action = WriteCode() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCode" + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_write_code_deserialize(): + context = CodingContext( + filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers") + ) + doc = Document(content=context.json()) + action = WriteCode(context=doc) + serialized_data = action.dict() + new_action = WriteCode(**serialized_data) + + assert new_action.name == "WriteCode" + assert new_action.llm == LLM() + await action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py new file mode 100644 index 000000000..01026590c --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of WriteCodeReview SerDeser + +import pytest + +from metagpt.actions import WriteCodeReview +from metagpt.llm import LLM +from metagpt.schema import CodingContext, Document + + +@pytest.mark.asyncio +async def test_write_code_review_deserialize(): + code_content = """ +def div(a: int, b: int = 0): + return a / b +""" + context = CodingContext( + filename="test_op.py", + design_doc=Document(content="divide two numbers"), + 
code_doc=Document(content=code_content), + ) + + action = WriteCodeReview(context=context) + serialized_data = action.dict() + assert serialized_data["name"] == "WriteCodeReview" + + new_action = WriteCodeReview(**serialized_data) + + assert new_action.name == "WriteCodeReview" + assert new_action.llm == LLM() + await new_action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py new file mode 100644 index 000000000..4e768ddd7 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 8:19 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteDesign, WriteTasks +from metagpt.llm import LLM + + +def test_write_design_serialize(): + action = WriteDesign() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" in ser_action_dict # not export + + +def test_write_task_serialize(): + action = WriteTasks() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_write_design_deserialize(): + action = WriteDesign() + serialized_data = action.dict() + new_action = WriteDesign(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + await new_action.run(with_messages="write a cli snake game") + + +@pytest.mark.asyncio +async def test_write_task_deserialize(): + action = WriteTasks() + serialized_data = action.dict() + new_action = WriteTasks(**serialized_data) + assert new_action.name == "CreateTasks" + assert new_action.llm == LLM() + await new_action.run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_write_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py new file mode 100644 index 000000000..d6d14f99a --- /dev/null +++ 
b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 1:47 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import pytest + +from metagpt.actions import WritePRD +from metagpt.llm import LLM +from metagpt.schema import Message + + +def test_action_serialize(): + action = WritePRD() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = WritePRD() + serialized_data = action.dict() + new_action = WritePRD(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) + assert len(action_output.content) > 0 diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index bc88eb742..3a899d6ff 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -8,6 +8,8 @@ """ +from pathlib import Path + import pytest from metagpt.actions import UserRequirement @@ -17,6 +19,8 @@ from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message +serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") + @pytest.fixture def env(): @@ -52,6 +56,7 @@ async def test_publish_and_process_message(env: Environment): ) env.add_roles([product_manager, architect]) + env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement)) await env.run(k=2) logger.info(f"{env.history=}") diff --git a/tests/metagpt/test_message.py b/tests/metagpt/test_message.py index 04d85d9e4..8f267ba54 100644 --- a/tests/metagpt/test_message.py +++ b/tests/metagpt/test_message.py @@ -23,7 +23,7 @@ def test_all_messages(): UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - 
Message(test_content, role="QA"), + Message(content=test_content, role="QA"), ] for msg in msgs: assert msg.content == test_content diff --git a/tests/metagpt/test_prompt.py b/tests/metagpt/test_prompt.py new file mode 100644 index 000000000..f7b1cc68e --- /dev/null +++ b/tests/metagpt/test_prompt.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 14:45 +@Author : alexanderwu +@File : test_llm.py +""" + +import pytest + +from metagpt.llm import LLM + +CODE_REVIEW_SMALLEST_CONTEXT = """ +## game.js +```Code +// game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 
2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + +MOVE_DRAFT = """ +## move function draft + +```javascript +move(direction) { + let moved = false; + switch (direction) { + case 'up': + for (let c = 0; c < 4; c++) { + for (let r = 1; r < 4; r++) { + if (this.board[r][c] !== 0) { + let row = r; + while (row > 0 && this.board[row - 1][c] === 0) { + this.board[row - 1][c] = this.board[row][c]; + this.board[row][c] = 0; + row--; + moved = true; + } + if (row > 0 && this.board[row - 1][c] === this.board[row][c]) { + this.board[row - 1][c] *= 2; + this.board[row][c] = 0; + this.score += this.board[row - 1][c]; + moved = true; + } + } + } + } + break; + case 'down': + // Implement logic for moving tiles down + // Similar to the 'up' case but iterating in reverse order + // and checking for merging in the opposite direction + break; + case 'left': + // Implement logic for moving tiles left + // Similar to the 'up' case but iterating over columns first + // and checking for merging in the opposite direction + break; + case 'right': + // Implement logic for moving tiles right + // Similar to the 'up' case but iterating over columns in reverse order + // and checking for merging in the opposite direction + break; + } + + if (moved) { + this.addRandomTile(); + } +} 
+``` +""" + +FUNCTION_TO_MERMAID_CLASS = """ +## context +``` +class UIDesign(Action): + #Class representing the UI Design action. + def __init__(self, name, context=None, llm=None): + super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt + @parse + def parse_requirement(self, context: str): + #Parse UI Design draft from the context using regex. + pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR" + return context, pattern + @parse + def parse_ui_elements(self, context: str): + #Parse Selected Elements from the context using regex. + pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout" + return context, pattern + @parse + def parse_css_code(self, context: str): + pattern = r"```css.*?\n(.*?)## Anything UNCLEAR" + return context, pattern + @parse + def parse_html_code(self, context: str): + pattern = r"```html.*?\n(.*?)```" + return context, pattern + async def draw_icons(self, context, *args, **kwargs): + #Draw icons using SDEngine. + engine = SDEngine() + icon_prompts = self.parse_ui_elements(context) + icons = icon_prompts.split("\n") + icons = [s for s in icons if len(s.strip()) > 0] + prompts_batch = [] + for icon_prompt in icons: + # fixme: 添加icon lora + prompt = engine.construct_payload(icon_prompt + ".") + prompts_batch.append(prompt) + await engine.run_t2i(prompts_batch) + logger.info("Finish icon design using StableDiffusion API") + async def _save(self, css_content, html_content): + save_dir = CONFIG.workspace_path / "resources" / "codes" + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + # Save CSS and HTML content to files + css_file_path = save_dir / "ui_design.css" + html_file_path = save_dir / "ui_design.html" + with open(css_file_path, "w") as css_file: + css_file.write(css_content) + with open(html_file_path, "w") as html_file: + html_file.write(html_content) + async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: + #Run the UI Design action. 
+ # fixme: update prompt (根据需求细化prompt) + context = requirements[-1].content + ui_design_draft = self.parse_requirement(context=context) + # todo: parse requirements str + prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE) + logger.info(prompt) + ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING) + logger.info(ui_describe.content) + logger.info(ui_describe.instruct_content) + css = self.parse_css_code(context=ui_describe.content) + html = self.parse_html_code(context=ui_describe.content) + await self._save(css_content=css, html_content=html) + await self.draw_icons(ui_describe.content) + return ui_describe +``` +----- +## format example +[CONTENT] +{ + "ClassView": "classDiagram\n class A {\n -int x\n +int y\n -int speed\n -int direction\n +__init__(x: int, y: int, speed: int, direction: int)\n +change_direction(new_direction: int) None\n +move() None\n }\n " +} +[/CONTENT] +## nodes: ": # " +- ClassView: # Generate the mermaid class diagram corresponding to source code in "context." +## constraint +- Language: Please use the same language as the user input. +- Format: output wrapped inside [CONTENT][/CONTENT] as format example, nothing else. +## action +Fill in the above nodes(ClassView) based on the format example. 
+""" + +MOVE_FUNCTION = """ +## move function implementation + +```javascript +move(direction) { + let moved = false; + switch (direction) { + case 'up': + for (let c = 0; c < 4; c++) { + for (let r = 1; r < 4; r++) { + if (this.board[r][c] !== 0) { + let row = r; + while (row > 0 && this.board[row - 1][c] === 0) { + this.board[row - 1][c] = this.board[row][c]; + this.board[row][c] = 0; + row--; + moved = true; + } + if (row > 0 && this.board[row - 1][c] === this.board[row][c]) { + this.board[row - 1][c] *= 2; + this.board[row][c] = 0; + this.score += this.board[row - 1][c]; + moved = true; + } + } + } + } + break; + case 'down': + for (let c = 0; c < 4; c++) { + for (let r = 2; r >= 0; r--) { + if (this.board[r][c] !== 0) { + let row = r; + while (row < 3 && this.board[row + 1][c] === 0) { + this.board[row + 1][c] = this.board[row][c]; + this.board[row][c] = 0; + row++; + moved = true; + } + if (row < 3 && this.board[row + 1][c] === this.board[row][c]) { + this.board[row + 1][c] *= 2; + this.board[row][c] = 0; + this.score += this.board[row + 1][c]; + moved = true; + } + } + } + } + break; + case 'left': + for (let r = 0; r < 4; r++) { + for (let c = 1; c < 4; c++) { + if (this.board[r][c] !== 0) { + let col = c; + while (col > 0 && this.board[r][col - 1] === 0) { + this.board[r][col - 1] = this.board[r][col]; + this.board[r][col] = 0; + col--; + moved = true; + } + if (col > 0 && this.board[r][col - 1] === this.board[r][col]) { + this.board[r][col - 1] *= 2; + this.board[r][col] = 0; + this.score += this.board[r][col - 1]; + moved = true; + } + } + } + } + break; + case 'right': + for (let r = 0; r < 4; r++) { + for (let c = 2; c >= 0; c--) { + if (this.board[r][c] !== 0) { + let col = c; + while (col < 3 && this.board[r][col + 1] === 0) { + this.board[r][col + 1] = this.board[r][col]; + this.board[r][col] = 0; + col++; + moved = true; + } + if (col < 3 && this.board[r][col + 1] === this.board[r][col]) { + this.board[r][col + 1] *= 2; + this.board[r][col] = 0; + 
this.score += this.board[r][col + 1]; + moved = true; + } + } + } + } + break; + } + + if (moved) { + this.addRandomTile(); + } +} +``` +""" + + +@pytest.fixture() +def llm(): + return LLM() + + +@pytest.mark.asyncio +async def test_llm_code_review(llm): + choices = [ + "Please review the move function code above. Should it be refactor?", + "Please implement the move function", + "Please write a draft for the move function in order to implement it", + ] + # prompt = CODE_REVIEW_SMALLEST_CONTEXT+ "\n\n" + MOVE_DRAFT + "\n\n" + choices[1] + # rsp = await llm.aask(prompt) + + prompt = CODE_REVIEW_SMALLEST_CONTEXT + "\n\n" + MOVE_FUNCTION + "\n\n" + choices[0] + prompt = FUNCTION_TO_MERMAID_CLASS + + _ = await llm.aask(prompt) + + +# if __name__ == "__main__": +# pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index a6c84d32a..897d203c7 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -7,11 +7,12 @@ @Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for the utilization of the new feature of `Message` class. 
""" + import json -import pytest - from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode +from metagpt.actions.write_code import WriteCode from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.utils.common import any_to_str @@ -19,10 +20,10 @@ from metagpt.utils.common import any_to_str def test_messages(): test_content = "test_message" msgs = [ - UserMessage(test_content), - SystemMessage(test_content), - AIMessage(test_content), - Message(test_content, role="QA"), + UserMessage(content=test_content), + SystemMessage(content=test_content), + AIMessage(content=test_content), + Message(content=test_content, role="QA"), ] text = str(msgs) roles = ["user", "system", "assistant", "QA"] @@ -30,7 +31,7 @@ def test_messages(): def test_message(): - m = Message("a", role="v1") + m = Message(content="a", role="v1") v = m.dump() d = json.loads(v) assert d @@ -43,7 +44,7 @@ def test_message(): assert m.content == "a" assert m.role == "v2" - m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") assert m.content == "a" assert m.role == "b" assert m.send_to == {"c"} @@ -60,12 +61,35 @@ def test_message(): def test_routes(): - m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") m.send_to = "b" assert m.send_to == {"b"} m.send_to = {"e", Action} assert m.send_to == {"e", any_to_str(Action)} -if __name__ == "__main__": - pytest.main([__file__, "-s"]) +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + + message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode) + message_dict = message.dict() + assert 
message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode" + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": {"field3": "(, Ellipsis)", "field4": "(list[str], Ellipsis)"}, + "value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}, + } + + new_message = Message(**message_dict) + assert new_message.content == message.content + assert new_message.instruct_content == message.instruct_content + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = message.dict() + new_message = Message(**message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement" diff --git a/tests/metagpt/test_subscription.py b/tests/metagpt/test_subscription.py index 1399df7fe..b902d5416 100644 --- a/tests/metagpt/test_subscription.py +++ b/tests/metagpt/test_subscription.py @@ -13,12 +13,12 @@ async def test_subscription_run(): async def trigger(): while True: - yield Message("the latest news about OpenAI") + yield Message(content="the latest news about OpenAI") await asyncio.sleep(3600 * 24) class MockRole(Role): async def run(self, message=None): - return Message("") + return Message(content="") async def callback(message): nonlocal callback_done @@ -61,11 +61,11 @@ async def test_subscription_run(): async def test_subscription_run_error(loguru_caplog): async def trigger1(): while True: - yield Message("the latest news about OpenAI") + yield Message(content="the latest news about OpenAI") await asyncio.sleep(3600 * 24) async def trigger2(): - yield Message("the latest news about OpenAI") + yield Message(content="the latest news about OpenAI") class MockRole1(Role): async def run(self, message=None): @@ -73,7 +73,7 @@ async def test_subscription_run_error(loguru_caplog): class MockRole2(Role): async def run(self, message=None): - return 
Message("") + return Message(content="") async def callback(msg: Message): print(msg) diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py new file mode 100644 index 000000000..930306b5e --- /dev/null +++ b/tests/metagpt/test_team.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of team + +from metagpt.roles.project_manager import ProjectManager +from metagpt.team import Team + + +def test_team(): + company = Team() + company.hire([ProjectManager()]) + + assert len(company.environment.roles) == 1 diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 4bd38db63..0ab34437d 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -47,7 +47,7 @@ class TestGetProjectRoot: Input(x=RunCode, want="metagpt.actions.run_code.RunCode"), Input(x=RunCode(), want="metagpt.actions.run_code.RunCode"), Input(x=Message, want="metagpt.schema.Message"), - Input(x=Message(""), want="metagpt.schema.Message"), + Input(x=Message(content=""), want="metagpt.schema.Message"), Input(x="A", want="A"), ] for i in inputs: diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index ffa34866c..f027d53f8 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -7,7 +7,7 @@ from typing import List, Tuple from metagpt.actions import WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, @@ -54,7 +54,7 @@ def test_actionoutout_schema_to_mapping(): def test_serialize_and_deserialize_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = 
ActionNode.create_model_class("prd", out_mapping) message = Message( content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD