Merge branch 'dev' into code_intepreter

This commit is contained in:
yzlin 2024-01-31 00:08:09 +08:00
commit 2fcb2a1cfe
282 changed files with 6993 additions and 3210 deletions

34
.github/workflows/build-package.yaml vendored Normal file
View file

@ -0,0 +1,34 @@
name: Build and upload python package
on:
release:
types: [created]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -e.
pip install setuptools wheel twine
- name: Set package version
run: |
export VERSION="${GITHUB_REF#refs/tags/v}"
sed -i "s/version=.*/version=\"${VERSION}\",/" setup.py
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python setup.py bdist_wheel sdist
twine upload dist/*

View file

@ -50,6 +50,7 @@ jobs:
run: |
export ALLOW_OPENAI_API_CALL=0
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

1
.gitignore vendored
View file

@ -176,5 +176,6 @@ htmlcov.*
cov.xml
*.dot
*.pkl
*.faiss
*-structure.csv
*-structure.json

View file

@ -6,16 +6,16 @@ # MetaGPT: The Multi-Agent Framework
</p>
<p align="center">
<b>Assign different roles to GPTs to form a collaborative software entity for complex tasks.</b>
<b>Assign different roles to GPTs to form a collaborative entity for complex tasks.</b>
</p>
<p align="center">
<a href="docs/README_CN.md"><img src="https://img.shields.io/badge/文档-中文版-blue.svg" alt="CN doc"></a>
<a href="README.md"><img src="https://img.shields.io/badge/document-English-blue.svg" alt="EN doc"></a>
<a href="docs/README_JA.md"><img src="https://img.shields.io/badge/ドキュメント-日本語-blue.svg" alt="JA doc"></a>
<a href="https://discord.gg/DYn29wFk9z"><img src="https://dcbadge.vercel.app/api/server/DYn29wFk9z?style=flat" alt="Discord Follow"></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg" alt="License: MIT"></a>
<a href="docs/ROADMAP.md"><img src="https://img.shields.io/badge/ROADMAP-路线图-blue" alt="roadmap"></a>
<a href="https://discord.gg/DYn29wFk9z"><img src="https://dcbadge.vercel.app/api/server/DYn29wFk9z?style=flat" alt="Discord Follow"></a>
<a href="https://twitter.com/MetaGPT_"><img src="https://img.shields.io/twitter/follow/MetaGPT?style=social" alt="Twitter Follow"></a>
</p>
@ -25,19 +25,31 @@ # MetaGPT: The Multi-Agent Framework
<a href="https://huggingface.co/spaces/deepwisdom/MetaGPT" target="_blank"><img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20-Hugging%20Face-blue?color=blue&logoColor=white" /></a>
</p>
## News
🚀 Jan. 16, 2024: Our paper [MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework
](https://arxiv.org/abs/2308.00352) accepted for oral presentation **(top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.
🚀 Jan. 03, 2024: [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0) released, new features include serialization, upgraded OpenAI package and supported multiple LLM, provided [minimal example for debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py) etc.
🚀 Dec. 15, 2023: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) released, introducing some experimental features such as **incremental development**, **multilingual**, **multiple programming languages**, etc.
🔥 Nov. 08, 2023: MetaGPT is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html).
🔥 Sep. 01, 2023: MetaGPT tops GitHub Trending Monthly for the **17th time** in August 2023.
🌟 Jun. 30, 2023: MetaGPT is now open source.
🌟 Apr. 24, 2023: First line of MetaGPT code committed.
## Software Company as Multi-Agent System
1. MetaGPT takes a **one line requirement** as input and outputs **user stories / competitive analysis / requirements / data structures / APIs / documents, etc.**
2. Internally, MetaGPT includes **product managers / architects / project managers / engineers.** It provides the entire process of a **software company along with carefully orchestrated SOPs.**
1. `Code = SOP(Team)` is the core philosophy. We materialize SOP and apply it to teams composed of LLMs.
![A software company consists of LLM-based roles](docs/resources/software_company_cd.jpeg)
<p align="center">Software Company Multi-Role Schematic (Gradually Implementing)</p>
## News
🚀 Jan 03: Here comes [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)! In this version, we added serialization and deserialization of important objects and enabled breakpoint recovery. We upgraded OpenAI package to v1.6.0 and supported Gemini, ZhipuAI, Ollama, OpenLLM, etc. Moreover, we provided extremely simple examples where you need only 7 lines to implement a general election [debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py). Check out more details [here](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)!
🚀 Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduced **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or existing codebase. We also launched a whole collection of important features, including **multilingual support** (experimental), multiple **programming languages support** (experimental), **incremental development** (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism!
<p align="center">Software Company Multi-Agent Schematic (Gradually Implementing)</p>
## Install

View file

@ -1,149 +0,0 @@
# DO NOT MODIFY THIS FILE, create a new key.yaml, define OPENAI_API_KEY.
# The configuration of key.yaml has a higher priority and will not enter git
#### Project Path Setting
# WORKSPACE_PATH: "Path for placing output files"
#### if OpenAI
## The official OPENAI_BASE_URL is https://api.openai.com/v1
## If the official OPENAI_BASE_URL is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward).
## Or, you can configure OPENAI_PROXY to access official OPENAI_BASE_URL.
OPENAI_BASE_URL: "https://api.openai.com/v1"
#OPENAI_PROXY: "http://127.0.0.1:8118"
#OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model
OPENAI_API_MODEL: "gpt-4-1106-preview"
MAX_TOKENS: 4096
RPM: 10
TIMEOUT: 60 # Timeout for llm invocation
#DEFAULT_PROVIDER: openai
#### if Spark
#SPARK_APPID : "YOUR_APPID"
#SPARK_API_SECRET : "YOUR_APISecret"
#SPARK_API_KEY : "YOUR_APIKey"
#DOMAIN : "generalv2"
#SPARK_URL : "ws://spark-api.xf-yun.com/v2.1/chat"
#### if Anthropic
#ANTHROPIC_API_KEY: "YOUR_API_KEY"
#### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb
#OPENAI_API_TYPE: "azure"
#OPENAI_BASE_URL: "YOUR_AZURE_ENDPOINT"
#OPENAI_API_KEY: "YOUR_AZURE_API_KEY"
#OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION"
#DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME"
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY"
# GEMINI_API_KEY: "YOUR_API_KEY"
#### if use self-host open llm model with openai-compatible interface
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
#OPEN_LLM_API_MODEL: "llama2-13b"
#
##### if use Fireworks api
#FIREWORKS_API_KEY: "YOUR_API_KEY"
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
#### if use self-host open llm model by ollama
# OLLAMA_API_BASE: http://127.0.0.1:11434/api
# OLLAMA_API_MODEL: llama2
#### for Search
## Supported values: serpapi/google/serper/ddg
#SEARCH_ENGINE: serpapi
## Visit https://serpapi.com/ to get key.
#SERPAPI_API_KEY: "YOUR_API_KEY"
## Visit https://console.cloud.google.com/apis/credentials to get key.
#GOOGLE_API_KEY: "YOUR_API_KEY"
## Visit https://programmablesearchengine.google.com/controlpanel/create to get id.
#GOOGLE_CSE_ID: "YOUR_CSE_ID"
## Visit https://serper.dev/ to get key.
#SERPER_API_KEY: "YOUR_API_KEY"
#### for web access
## Supported values: playwright/selenium
#WEB_BROWSER_ENGINE: playwright
## Supported values: chromium/firefox/webkit, visit https://playwright.dev/python/docs/api/class-browsertype
##PLAYWRIGHT_BROWSER_TYPE: chromium
## Supported values: chrome/firefox/edge/ie, visit https://www.selenium.dev/documentation/webdriver/browsers/
# SELENIUM_BROWSER_TYPE: chrome
#### for TTS
#AZURE_TTS_SUBSCRIPTION_KEY: "YOUR_API_KEY"
#AZURE_TTS_REGION: "eastus"
#### for OPENAI VISION
#OPENAI_VISION_MODEL: "YOUR_VISION_MODEL_NAME"
#VISION_MAX_TOKENS: 4096
#### for Stable Diffusion
## Use SD service, based on https://github.com/AUTOMATIC1111/stable-diffusion-webui
#SD_URL: "YOUR_SD_URL"
#SD_T2I_API: "/sdapi/v1/txt2img"
#### for Execution
#LONG_TERM_MEMORY: false
#### for Mermaid CLI
## If you installed mmdc (Mermaid CLI) only for metagpt then enable the following configuration.
#PUPPETEER_CONFIG: "./config/puppeteer-config.json"
#MMDC: "./node_modules/.bin/mmdc"
### for calc_usage
# CALC_USAGE: false
### for Research
# MODEL_FOR_RESEARCHER_SUMMARY: gpt-3.5-turbo
# MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
### choose the engine for mermaid conversion,
# default is nodejs, you can change it to playwright,pyppeteer or ink
# MERMAID_ENGINE: nodejs
### browser path for pyppeteer engine, support Chrome, Chromium,MS Edge
#PYPPETEER_EXECUTABLE_PATH: "/usr/bin/google-chrome-stable"
### for repair non-openai LLM's output when parse json-text if PROMPT_FORMAT=json
### due to non-openai LLM's output will not always follow the instruction, so here activate a post-process
### repair operation on the content extracted from LLM's raw output. Warning, it improves the result but not fix all cases.
# REPAIR_LLM_OUTPUT: false
# PROMPT_FORMAT: json #json or markdown
### Agent configurations
# RAISE_NOT_CONFIG_ERROR: true # "true" if the LLM key is not configured, throw a NotConfiguredException, else "false".
# WORKSPACE_PATH_WITH_UID: false # "true" if using `{workspace}/{uid}` as the workspace path; "false" use `{workspace}`.
### Meta Models
#METAGPT_TEXT_TO_IMAGE_MODEL: MODEL_URL
### S3 config
#S3_ACCESS_KEY: "YOUR_S3_ACCESS_KEY"
#S3_SECRET_KEY: "YOUR_S3_SECRET_KEY"
#S3_ENDPOINT_URL: "YOUR_S3_ENDPOINT_URL"
#S3_SECURE: true # true/false
#S3_BUCKET: "YOUR_S3_BUCKET"
### Redis config
#REDIS_HOST: "YOUR_REDIS_HOST"
#REDIS_PORT: "YOUR_REDIS_PORT"
#REDIS_PASSWORD: "YOUR_REDIS_PASSWORD"
#REDIS_DB: "YOUR_REDIS_DB_INDEX, str, 0-based"
# DISABLE_LLM_PROVIDER_CHECK: false

3
config/config2.yaml Normal file
View file

@ -0,0 +1,3 @@
llm:
api_key: "YOUR_API_KEY"
model: "gpt-3.5-turbo-1106"

View file

@ -0,0 +1,42 @@
llm:
api_type: "openai"
base_url: "YOUR_BASE_URL"
api_key: "YOUR_API_KEY"
model: "gpt-3.5-turbo-1106" # or gpt-4-1106-preview
proxy: "YOUR_PROXY"
search:
api_type: "google"
api_key: "YOUR_API_KEY"
cse_id: "YOUR_CSE_ID"
mermaid:
engine: "pyppeteer"
path: "/Applications/Google Chrome.app"
redis:
host: "YOUR_HOST"
port: 32582
password: "YOUR_PASSWORD"
db: "0"
s3:
access_key: "YOUR_ACCESS_KEY"
secret_key: "YOUR_SECRET_KEY"
endpoint: "YOUR_ENDPOINT"
secure: false
bucket: "test"
AZURE_TTS_SUBSCRIPTION_KEY: "YOUR_SUBSCRIPTION_KEY"
AZURE_TTS_REGION: "eastus"
IFLYTEK_APP_ID: "YOUR_APP_ID"
IFLYTEK_API_KEY: "YOUR_API_KEY"
IFLYTEK_API_SECRET: "YOUR_API_SECRET"
METAGPT_TEXT_TO_IMAGE_MODEL_URL: "YOUR_MODEL_URL"
PYPPETEER_EXECUTABLE_PATH: "/Applications/Google Chrome.app"

View file

@ -9,24 +9,22 @@ ### Short-term Objective
1. Become the multi-agent framework with the highest ROI.
2. Support fully automatic implementation of medium-sized projects (around 2000 lines of code).
3. Implement most identified tasks, reaching version 0.5.
3. Implement most identified tasks, reaching version 1.0.
### Tasks
To reach version v0.5, approximately 70% of the following tasks need to be completed.
1. Usability
1. ~~Release v0.01 pip package to try to solve issues like npm installation (though not necessarily successfully)~~ (v0.3.0)
2. Support for overall save and recovery of software companies
2. ~~Support for overall save and recovery of software companies~~ (v0.6.0)
3. ~~Support human confirmation and modification during the process~~ (v0.3.0) New: Support human confirmation and modification with fewer constrainsts and a more user-friendly interface
4. Support process caching: Consider carefully whether to add server caching mechanism
5. ~~Resolve occasional failure to follow instruction under current prompts, causing code parsing errors, through stricter system prompts~~ (v0.4.0, with function call)
6. Write documentation, describing the current features and usage at all levels (ongoing, continuously adding contents to [documentation site](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html))
7. ~~Support Docker~~
2. Features
1. Support a more standard and stable parser (need to analyze the format that the current LLM is better at)
2. ~~Establish a separate output queue, differentiated from the message queue~~
3. Attempt to atomize all role work, but this may significantly increase token overhead
1. ~~Support a more standard and stable parser (need to analyze the format that the current LLM is better at)~~ (v0.5.0)
2. ~~Establish a separate output queue, differentiated from the message queue~~ (v0.5.0)
3. ~~Attempt to atomize all role work, but this may significantly increase token overhead~~ (v0.5.0)
4. Complete the design and implementation of module breakdown
5. Support various modes of memory: clearly distinguish between long-term and short-term memory
6. Perfect the test role, and carry out necessary interactions with humans
@ -43,10 +41,10 @@ ### Tasks
4. Actions
1. ~~Implementation: Search~~ (v0.2.1)
2. Implementation: Knowledge search, supporting 10+ data formats
3. Implementation: Data EDA (expected v0.6.0)
4. Implementation: Review
5. ~~Implementation~~: Add Document (v0.5.0)
6. ~~Implementation~~: Delete Document (v0.5.0)
3. Implementation: Data EDA (expected v0.7.0)
4. Implementation: Review & Revise (expected v0.7.0)
5. ~~Implementation: Add Document~~ (v0.5.0)
6. ~~Implementation: Delete Document~~ (v0.5.0)
7. Implementation: Self-training
8. ~~Implementation: DebugError~~ (v0.2.1)
9. Implementation: Generate reliable unit tests based on YAPI
@ -64,15 +62,14 @@ ### Tasks
3. ~~Support Playwright apis~~
7. Roles
1. Perfect the action pool/skill pool for each role
2. Red Book blogger
3. E-commerce seller
4. Data analyst (expected v0.6.0)
5. News observer
6. ~~Institutional researcher~~ (v0.2.1)
2. E-commerce seller
3. Data analyst (expected v0.7.0)
4. News observer
5. ~~Institutional researcher~~ (v0.2.1)
8. Evaluation
1. Support an evaluation on a game dataset (experimentation done with game agents)
2. Reproduce papers, implement full skill acquisition for a single game role, achieving SOTA results (experimentation done with game agents)
3. Support an evaluation on a math dataset (expected v0.6.0)
3. Support an evaluation on a math dataset (expected v0.7.0)
4. Reproduce papers, achieving SOTA results for current mathematical problem solving process
9. LLM
1. Support Claude underlying API
@ -80,7 +77,7 @@ ### Tasks
3. Support streaming version of all APIs
4. ~~Make gpt-3.5-turbo available (HARD)~~
10. Other
1. Clean up existing unused code
2. Unify all code styles and establish contribution standards
3. Multi-language support
4. Multi-programming-language support
1. ~~Clean up existing unused code~~
2. ~~Unify all code styles and establish contribution standards~~
3. ~~Multi-language support~~
4. ~~Multi-programming-language support~~

View file

@ -6,7 +6,7 @@ Author: garylin2099
import re
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
@ -48,8 +48,8 @@ class CreateAgent(Action):
pattern = r"```python(.*)```"
match = re.search(pattern, rsp, re.DOTALL)
code_text = match.group(1) if match else ""
CONFIG.workspace_path.mkdir(parents=True, exist_ok=True)
new_file = CONFIG.workspace_path / "agent_created_agent.py"
config.workspace.path.mkdir(parents=True, exist_ok=True)
new_file = config.workspace.path / "agent_created_agent.py"
new_file.write_text(code_text)
return code_text
@ -61,7 +61,7 @@ class AgentCreator(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([CreateAgent])
self.set_actions([CreateAgent])
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")

View file

@ -57,7 +57,7 @@ class SimpleCoder(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([SimpleWriteCode])
self.set_actions([SimpleWriteCode])
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
@ -76,7 +76,7 @@ class RunnableCoder(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([SimpleWriteCode, SimpleRunCode])
self.set_actions([SimpleWriteCode, SimpleRunCode])
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
async def _act(self) -> Message:

View file

@ -46,7 +46,7 @@ class SimpleCoder(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._watch([UserRequirement])
self._init_actions([SimpleWriteCode])
self.set_actions([SimpleWriteCode])
class SimpleWriteTest(Action):
@ -75,7 +75,7 @@ class SimpleTester(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([SimpleWriteTest])
self.set_actions([SimpleWriteTest])
# self._watch([SimpleWriteCode])
self._watch([SimpleWriteCode, SimpleWriteReview]) # feel free to try this too
@ -114,7 +114,7 @@ class SimpleReviewer(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([SimpleWriteReview])
self.set_actions([SimpleWriteReview])
self._watch([SimpleWriteTest])

View file

@ -49,7 +49,7 @@ class Debator(Role):
def __init__(self, **data: Any):
super().__init__(**data)
self._init_actions([SpeakAloud])
self.set_actions([SpeakAloud])
self._watch([UserRequirement, SpeakAloud])
async def _observe(self) -> int:

View file

@ -13,7 +13,9 @@ from metagpt.roles import Role
from metagpt.team import Team
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action1.llm.model = "gpt-4-1106-preview"
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
action2.llm.model = "gpt-3.5-turbo-1106"
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
env = Environment(desc="US election live broadcast")

Binary file not shown.

Binary file not shown.

View file

@ -23,6 +23,10 @@ async def main():
# streaming mode, much slower
await llm.acompletion_text(hello_msg, stream=True)
# check completion if exist to test llm complete functions
if hasattr(llm, "completion"):
logger.info(llm.completion(hello_msg))
if __name__ == "__main__":
asyncio.run(main())

View file

@ -8,7 +8,7 @@ import asyncio
from langchain.embeddings import OpenAIEmbeddings
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import DATA_PATH, EXAMPLE_PATH
from metagpt.document_store import FaissStore
from metagpt.logs import logger
@ -16,7 +16,8 @@ from metagpt.roles import Sales
def get_store():
embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url)
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
return FaissStore(DATA_PATH / "example.json", embedding=embedding)

View file

@ -10,41 +10,67 @@ from __future__ import annotations
from typing import Optional, Union
from pydantic import ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.context_mixin import ContextMixin
from metagpt.schema import (
CodePlanAndChangeContext,
CodeSummarizeContext,
CodingContext,
RunCodeContext,
SerializationMixin,
TestingContext,
)
from metagpt.utils.project_repo import ProjectRepo
class Action(SerializationMixin, is_polymorphic_base=True):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
class Action(SerializationMixin, ContextMixin, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = ""
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
i_context: Union[
dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, CodePlanAndChangeContext, str, None
] = ""
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
@property
def repo(self) -> ProjectRepo:
if not self.context.repo:
self.context.repo = ProjectRepo(self.context.git_repo)
return self.context.repo
@property
def prompt_schema(self):
return self.config.prompt_schema
@property
def project_name(self):
return self.config.project_name
@project_name.setter
def project_name(self, value):
self.config.project_name = value
@property
def project_path(self):
return self.config.project_path
@model_validator(mode="before")
@classmethod
def set_name_if_empty(cls, values):
if "name" not in values or not values["name"]:
values["name"] = cls.__name__
return values
@model_validator(mode="before")
@classmethod
def _init_with_instruction(cls, values):
if "instruction" in values:
name = values["name"]
i = values["instruction"]
i = values.pop("instruction")
values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw")
return values

View file

@ -9,16 +9,30 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
"""
import json
from typing import Any, Dict, List, Optional, Tuple, Type
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model, model_validator
from pydantic import BaseModel, Field, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from metagpt.utils.common import OutputParser, general_after_log
from metagpt.utils.human_interaction import HumanInteraction
class ReviewMode(Enum):
HUMAN = "human"
AUTO = "auto"
class ReviseMode(Enum):
HUMAN = "human" # human revise
HUMAN_REVIEW = "human_review" # human-review and auto-revise
AUTO = "auto" # auto-review and auto-revise
TAG = "CONTENT"
@ -45,6 +59,58 @@ SIMPLE_TEMPLATE = """
Follow instructions of nodes, generate output and make sure it follows the format example.
"""
REVIEW_TEMPLATE = """
## context
Compare the key's value of nodes_output and the corresponding requirements one by one. If a key's value that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
### nodes_output
{nodes_output}
-----
## format example
[{tag}]
{{
"key1": "comment1",
"key2": "comment2",
"keyn": "commentn"
}}
[/{tag}]
## nodes: "<node>: <type> # <instruction>"
- key1: <class \'str\'> # the first key name of mismatch key
- key2: <class \'str\'> # the second key name of mismatch key
- keyn: <class \'str\'> # the last key name of mismatch key
## constraint
{constraint}
## action
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
"""
REVISE_TEMPLATE = """
## context
change the nodes_output key's value to meet its comment and no need to add extra comment.
### nodes_output
{nodes_output}
-----
## format example
{example}
## nodes: "<node>: <type> # <instruction>"
{instruction}
## constraint
{constraint}
## action
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
"""
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
markdown_str = ""
@ -105,6 +171,9 @@ class ActionNode:
"""增加子ActionNode"""
self.children[node.key] = node
def get_child(self, key: str) -> Union["ActionNode", None]:
return self.children.get(key, None)
def add_children(self, nodes: List["ActionNode"]):
"""批量增加子ActionNode"""
for node in nodes:
@ -117,11 +186,27 @@ class ActionNode:
obj.add_children(nodes)
return obj
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
def get_children_mapping_old(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""获得子ActionNode的字典以key索引"""
exclude = exclude or []
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""获得子ActionNode的字典以key索引支持多级结构"""
exclude = exclude or []
mapping = {}
def _get_mapping(node: "ActionNode", prefix: str = ""):
for key, child in node.children.items():
if key in exclude:
continue
full_key = f"{prefix}{key}"
mapping[full_key] = (child.expected_type, ...)
_get_mapping(child, prefix=f"{full_key}.")
_get_mapping(self)
return mapping
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
"""get self key: type mapping"""
return {self.key: (self.expected_type, ...)}
@ -133,6 +218,7 @@ class ActionNode:
return {} if exclude and self.key in exclude else self.get_self_mapping()
@classmethod
@register_action_outcls
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""
@ -152,6 +238,11 @@ class ActionNode:
new_class = create_model(class_name, __validators__=validators, **mapping)
return new_class
def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
class_name = class_name if class_name else f"{self.key}_AN"
mapping = self.get_mapping(mode=mode, exclude=exclude)
return self.create_model_class(class_name, mapping)
def create_children_class(self, exclude=None):
"""使用object内有的字段直接生成model_class"""
class_name = f"{self.key}_AN"
@ -186,6 +277,25 @@ class ActionNode:
return node_dict
def update_instruct_content(self, incre_data: dict[str, Any]):
assert self.instruct_content
origin_sc_dict = self.instruct_content.model_dump()
origin_sc_dict.update(incre_data)
output_class = self.create_class()
self.instruct_content = output_class(**origin_sc_dict)
def keys(self, mode: str = "auto") -> list:
if mode == "children" or (mode == "auto" and self.children):
keys = []
else:
keys = [self.key]
if mode == "root":
return keys
for _, child_node in self.children.items():
keys.append(child_node.key)
return keys
def compile_to(self, i: Dict, schema, kv_sep) -> str:
if schema == "json":
return json.dumps(i, indent=4)
@ -262,7 +372,7 @@ class ActionNode:
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=CONFIG.timeout,
timeout=3,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
@ -294,7 +404,7 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
async def simple_fill(self, schema, mode, timeout=3, exclude=None):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
@ -309,7 +419,7 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
@ -343,7 +453,241 @@ class ActionNode:
if exclude and i.key in exclude:
continue
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
tmp.update(child.instruct_content.dict())
tmp.update(child.instruct_content.model_dump())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
return self
async def human_review(self) -> dict[str, str]:
review_comments = HumanInteraction().interact_with_instruct_content(
instruct_content=self.instruct_content, interact_type="review"
)
return review_comments
def _makeup_nodes_output_with_req(self) -> dict[str, str]:
instruct_content_dict = self.instruct_content.model_dump()
nodes_output = {}
for key, value in instruct_content_dict.items():
child = self.get_child(key)
nodes_output[key] = {"value": value, "requirement": child.instruction if child else self.instruction}
return nodes_output
async def auto_review(self, template: str = REVIEW_TEMPLATE) -> dict[str, str]:
"""use key's output value and its instruction to review the modification comment"""
nodes_output = self._makeup_nodes_output_with_req()
"""nodes_output format:
{
"key": {"value": "output value", "requirement": "key instruction"}
}
"""
if not nodes_output:
return dict()
prompt = template.format(
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
tag=TAG,
constraint=FORMAT_CONSTRAINT,
prompt_schema="json",
)
content = await self.llm.aask(prompt)
# Extract the dict of mismatch key and its comment. Due to the mismatch keys are unknown, here use the keys
# of ActionNode to judge if exist in `content` and then follow the `data_mapping` method to create model class.
keys = self.keys()
include_keys = []
for key in keys:
if f'"{key}":' in content:
include_keys.append(key)
if not include_keys:
return dict()
exclude_keys = list(set(keys).difference(include_keys))
output_class_name = f"{self.key}_AN_REVIEW"
output_class = self.create_class(class_name=output_class_name, exclude=exclude_keys)
parsed_data = llm_output_postprocess(
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
)
instruct_content = output_class(**parsed_data)
return instruct_content.model_dump()
async def simple_review(self, review_mode: ReviewMode = ReviewMode.AUTO):
# generate review comments
if review_mode == ReviewMode.HUMAN:
review_comments = await self.human_review()
else:
review_comments = await self.auto_review()
if not review_comments:
logger.warning("There are no review comments")
return review_comments
async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO):
"""only give the review comment of each exist and mismatch key
:param strgy: simple/complex
- simple: run only once
- complex: run each node
"""
if not hasattr(self, "llm"):
raise RuntimeError("use `review` after `fill`")
assert review_mode in ReviewMode
assert self.instruct_content, 'review only support with `schema != "raw"`'
if strgy == "simple":
review_comments = await self.simple_review(review_mode)
elif strgy == "complex":
# review each child node one-by-one
review_comments = {}
for _, child in self.children.items():
child_review_comment = await child.simple_review(review_mode)
review_comments.update(child_review_comment)
return review_comments
async def human_revise(self) -> dict[str, str]:
review_contents = HumanInteraction().interact_with_instruct_content(
instruct_content=self.instruct_content, mapping=self.get_mapping(mode="auto"), interact_type="revise"
)
# re-fill the ActionNode
self.update_instruct_content(review_contents)
return review_contents
def _makeup_nodes_output_with_comment(self, review_comments: dict[str, str]) -> dict[str, str]:
instruct_content_dict = self.instruct_content.model_dump()
nodes_output = {}
for key, value in instruct_content_dict.items():
if key in review_comments:
nodes_output[key] = {"value": value, "comment": review_comments[key]}
return nodes_output
async def auto_revise(
self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE
) -> dict[str, str]:
"""revise the value of incorrect keys"""
# generate review comments
if revise_mode == ReviseMode.AUTO:
review_comments: dict = await self.auto_review()
elif revise_mode == ReviseMode.HUMAN_REVIEW:
review_comments: dict = await self.human_review()
include_keys = list(review_comments.keys())
# generate revise content, two-steps
# step1, find the needed revise keys from review comments to makeup prompt template
nodes_output = self._makeup_nodes_output_with_comment(review_comments)
keys = self.keys()
exclude_keys = list(set(keys).difference(include_keys))
example = self.compile_example(schema="json", mode="auto", tag=TAG, exclude=exclude_keys)
instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
prompt = template.format(
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
example=example,
instruction=instruction,
constraint=FORMAT_CONSTRAINT,
prompt_schema="json",
)
# step2, use `_aask_v1` to get revise structure result
output_mapping = self.get_mapping(mode="auto", exclude=exclude_keys)
output_class_name = f"{self.key}_AN_REVISE"
content, scontent = await self._aask_v1(
prompt=prompt, output_class_name=output_class_name, output_data_mapping=output_mapping, schema="json"
)
# re-fill the ActionNode
sc_dict = scontent.model_dump()
self.update_instruct_content(sc_dict)
return sc_dict
async def simple_revise(self, revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
if revise_mode == ReviseMode.HUMAN:
revise_contents = await self.human_revise()
else:
revise_contents = await self.auto_revise(revise_mode)
return revise_contents
async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
"""revise the content of ActionNode and update the instruct_content
:param strgy: simple/complex
- simple: run only once
- complex: run each node
"""
if not hasattr(self, "llm"):
raise RuntimeError("use `revise` after `fill`")
assert revise_mode in ReviseMode
assert self.instruct_content, 'revise only support with `schema != "raw"`'
if strgy == "simple":
revise_contents = await self.simple_revise(revise_mode)
elif strgy == "complex":
# revise each child node one-by-one
revise_contents = {}
for _, child in self.children.items():
child_revise_content = await child.simple_revise(revise_mode)
revise_contents.update(child_revise_content)
self.update_instruct_content(revise_contents)
return revise_contents
@classmethod
def from_pydantic(cls, model: Type[BaseModel], key: str = None):
"""
Creates an ActionNode tree from a Pydantic model.
Args:
model (Type[BaseModel]): The Pydantic model to convert.
Returns:
ActionNode: The root node of the created ActionNode tree.
"""
key = key or model.__name__
root_node = cls(key=model.__name__, expected_type=Type[model], instruction="", example="")
for field_name, field_model in model.model_fields.items():
# Extracting field details
expected_type = field_model.annotation
instruction = field_model.description or ""
example = field_model.default
# Check if the field is a Pydantic model itself.
# Use isinstance to avoid typing.List, typing.Dict, etc. (they are instances of type, not subclasses)
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
# Recursively process the nested model
child_node = cls.from_pydantic(expected_type, key=field_name)
else:
child_node = cls(key=field_name, expected_type=expected_type, instruction=instruction, example=example)
root_node.add_child(child_node)
return root_node
class ToolUse(BaseModel):
tool_name: str = Field(default="a", description="tool name", examples=[])
class Task(BaseModel):
task_id: int = Field(default="1", description="task id", examples=[1, 2, 3])
name: str = Field(default="Get data from ...", description="task name", examples=[])
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
tool: ToolUse = Field(default=ToolUse(), description="tool use", examples=[])
class Tasks(BaseModel):
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
if __name__ == "__main__":
node = ActionNode.from_pydantic(Tasks)
print("Tasks")
print(Tasks.model_json_schema())
print("Task")
print(Task.model_json_schema())
print(node)
prompt = node.compile(context="")
node.create_children_class()
print(prompt)

View file

@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : registry to store Dynamic Model from ActionNode.create_model_class to keep it as same Class
# with same class name and mapping
from functools import wraps
action_outcls_registry = dict()
def register_action_outcls(func):
"""
Due to `create_model` return different Class even they have same class name and mapping.
In order to do a comparison, use outcls_id to identify same Class with same class name and field definition
"""
@wraps(func)
def decorater(*args, **kwargs):
"""
arr example
[<class 'metagpt.actions.action_node.ActionNode'>, 'test', {'field': (str, Ellipsis)}]
"""
arr = list(args) + list(kwargs.values())
"""
outcls_id example
"<class 'metagpt.actions.action_node.ActionNode'>_test_{'field': (str, Ellipsis)}"
"""
for idx, item in enumerate(arr):
if isinstance(item, dict):
arr[idx] = dict(sorted(item.items()))
outcls_id = "_".join([str(i) for i in arr])
# eliminate typing influence
outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict")
if outcls_id in action_outcls_registry:
return action_outcls_registry[outcls_id]
out_cls = func(*args, **kwargs)
action_outcls_registry[outcls_id] = out_cls
return out_cls
return decorater

View file

@ -13,12 +13,9 @@ import re
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
PROMPT_TEMPLATE = """
NOTICE
@ -49,13 +46,10 @@ Now you should start rewriting the code:
class DebugError(Action):
name: str = "DebugError"
context: RunCodeContext = Field(default_factory=RunCodeContext)
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
async def run(self, *args, **kwargs) -> str:
output_doc = await FileRepository.get_file(
filename=self.context.output_filename, relative_path=TEST_OUTPUTS_FILE_REPO
)
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
if not output_doc:
return ""
output_detail = RunCodeResult.loads(output_doc.content)
@ -64,15 +58,13 @@ class DebugError(Action):
if matches:
return ""
logger.info(f"Debug and rewrite {self.context.test_filename}")
code_doc = await FileRepository.get_file(
filename=self.context.code_filename, relative_path=CONFIG.src_workspace
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
filename=self.i_context.code_filename
)
if not code_doc:
return ""
test_doc = await FileRepository.get_file(
filename=self.context.test_filename, relative_path=TEST_CODES_FILE_REPO
)
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
if not test_doc:
return ""
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)

View file

@ -14,18 +14,17 @@ from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.design_api_an import DESIGN_API_NODE
from metagpt.config import CONFIG
from metagpt.const import (
DATA_API_DESIGN_FILE_REPO,
PRDS_FILE_REPO,
SEQ_FLOW_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
SYSTEM_DESIGN_PDF_FILE_REPO,
from metagpt.actions.design_api_an import (
DATA_STRUCTURES_AND_INTERFACES,
DESIGN_API_NODE,
PROGRAM_CALL_FLOW,
REFINED_DATA_STRUCTURES_AND_INTERFACES,
REFINED_DESIGN_NODE,
REFINED_PROGRAM_CALL_FLOW,
)
from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Document, Documents, Message
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.mermaid import mermaid_to_file
NEW_REQ_TEMPLATE = """
@ -39,36 +38,30 @@ NEW_REQ_TEMPLATE = """
class WriteDesign(Action):
name: str = ""
context: Optional[str] = None
i_context: Optional[str] = None
desc: str = (
"Based on the PRD, think about the system design, and design the corresponding APIs, "
"data structures, library tables, processes, and paths. Please provide your design, feedback "
"clearly and in detail."
)
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
changed_prds = prds_file_repo.changed_files
async def run(self, with_messages: Message, schema: str = None):
# Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory.
changed_prds = self.repo.docs.prd.changed_files
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
# changes.
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
# For those PRDs and design documents that have undergone changes, regenerate the design content.
changed_files = Documents()
for filename in changed_prds.keys():
doc = await self._update_system_design(
filename=filename, prds_file_repo=prds_file_repo, system_design_file_repo=system_design_file_repo
)
doc = await self._update_system_design(filename=filename)
changed_files.docs[filename] = doc
for filename in changed_system_designs.keys():
if filename in changed_files.docs:
continue
doc = await self._update_system_design(
filename=filename, prds_file_repo=prds_file_repo, system_design_file_repo=system_design_file_repo
)
doc = await self._update_system_design(filename=filename)
changed_files.docs[filename] = doc
if not changed_files.docs:
logger.info("Nothing has changed.")
@ -76,61 +69,52 @@ class WriteDesign(Action):
# leaving room for global optimization in subsequent steps.
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
async def _new_system_design(self, context):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm)
return node
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
async def _merge(self, prd_doc, system_design_doc):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm)
system_design_doc.content = node.instruct_content.model_dump_json()
return system_design_doc
async def _update_system_design(self, filename, prds_file_repo, system_design_file_repo) -> Document:
prd = await prds_file_repo.get(filename)
old_system_design_doc = await system_design_file_repo.get(filename)
async def _update_system_design(self, filename) -> Document:
prd = await self.repo.docs.prd.get(filename)
old_system_design_doc = await self.repo.docs.system_design.get(filename)
if not old_system_design_doc:
system_design = await self._new_system_design(context=prd.content)
doc = Document(
root_path=SYSTEM_DESIGN_FILE_REPO,
doc = await self.repo.docs.system_design.save(
filename=filename,
content=system_design.instruct_content.model_dump_json(),
dependencies={prd.root_relative_path},
)
else:
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
await system_design_file_repo.save(
filename=filename, content=doc.content, dependencies={prd.root_relative_path}
)
await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self._save_data_api_design(doc)
await self._save_seq_flow(doc)
await self._save_pdf(doc)
await self.repo.resources.system_design.save_pdf(doc=doc)
return doc
@staticmethod
async def _save_data_api_design(design_doc):
async def _save_data_api_design(self, design_doc):
m = json.loads(design_doc.content)
data_api_design = m.get("Data structures and interfaces")
data_api_design = m.get(DATA_STRUCTURES_AND_INTERFACES.key) or m.get(REFINED_DATA_STRUCTURES_AND_INTERFACES.key)
if not data_api_design:
return
pathname = CONFIG.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await WriteDesign._save_mermaid_file(data_api_design, pathname)
pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(data_api_design, pathname)
logger.info(f"Save class view to {str(pathname)}")
@staticmethod
async def _save_seq_flow(design_doc):
async def _save_seq_flow(self, design_doc):
m = json.loads(design_doc.content)
seq_flow = m.get("Program call flow")
seq_flow = m.get(PROGRAM_CALL_FLOW.key) or m.get(REFINED_PROGRAM_CALL_FLOW.key)
if not seq_flow:
return
pathname = CONFIG.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await WriteDesign._save_mermaid_file(seq_flow, pathname)
pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")
@staticmethod
async def _save_pdf(design_doc):
await FileRepository.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO)
@staticmethod
async def _save_mermaid_file(data: str, pathname: Path):
async def _save_mermaid_file(self, data: str, pathname: Path):
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(data, pathname)
await mermaid_to_file(self.config.mermaid_engine, data, pathname)

View file

@ -8,6 +8,7 @@
from typing import List
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.utils.mermaid import MMC1, MMC2
IMPLEMENTATION_APPROACH = ActionNode(
@ -17,6 +18,15 @@ IMPLEMENTATION_APPROACH = ActionNode(
example="We will ...",
)
REFINED_IMPLEMENTATION_APPROACH = ActionNode(
key="Refined Implementation Approach",
expected_type=str,
instruction="Update and extend the original implementation approach to reflect the evolving challenges and "
"requirements due to incremental development. Outline the steps involved in the implementation process with the "
"detailed strategies.",
example="We will refine ...",
)
PROJECT_NAME = ActionNode(
key="Project name", expected_type=str, instruction="The project name with underline", example="game_2048"
)
@ -28,6 +38,14 @@ FILE_LIST = ActionNode(
example=["main.py", "game.py"],
)
REFINED_FILE_LIST = ActionNode(
key="Refined File list",
expected_type=List[str],
instruction="Update and expand the original file list including only relative paths. Up to 2 files can be added."
"Ensure that the refined file list reflects the evolving structure of the project.",
example=["main.py", "game.py", "new_feature.py"],
)
DATA_STRUCTURES_AND_INTERFACES = ActionNode(
key="Data structures and interfaces",
expected_type=str,
@ -37,6 +55,16 @@ DATA_STRUCTURES_AND_INTERFACES = ActionNode(
example=MMC1,
)
REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(
key="Refined Data structures and interfaces",
expected_type=str,
instruction="Update and extend the existing mermaid classDiagram code syntax to incorporate new classes, "
"methods (including __init__), and functions with precise type annotations. Delineate additional "
"relationships between classes, ensuring clarity and adherence to PEP8 standards."
"Retain content that is not related to incremental development but important for consistency and clarity.",
example=MMC1,
)
PROGRAM_CALL_FLOW = ActionNode(
key="Program call flow",
expected_type=str,
@ -45,6 +73,16 @@ PROGRAM_CALL_FLOW = ActionNode(
example=MMC2,
)
REFINED_PROGRAM_CALL_FLOW = ActionNode(
key="Refined Program call flow",
expected_type=str,
instruction="Extend the existing sequenceDiagram code syntax with detailed information, accurately covering the"
"CRUD and initialization of each object. Ensure correct syntax usage and reflect the incremental changes introduced"
"in the classes and API defined above. "
"Retain content that is not related to incremental development but important for consistency and clarity.",
example=MMC2,
)
ANYTHING_UNCLEAR = ActionNode(
key="Anything UNCLEAR",
expected_type=str,
@ -61,4 +99,24 @@ NODES = [
ANYTHING_UNCLEAR,
]
REFINED_NODES = [
REFINED_IMPLEMENTATION_APPROACH,
REFINED_FILE_LIST,
REFINED_DATA_STRUCTURES_AND_INTERFACES,
REFINED_PROGRAM_CALL_FLOW,
ANYTHING_UNCLEAR,
]
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES)
def main():
prompt = DESIGN_API_NODE.compile(context="")
logger.info(prompt)
prompt = REFINED_DESIGN_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":
main()

View file

@ -13,7 +13,7 @@ from metagpt.actions.action import Action
class DesignReview(Action):
name: str = "DesignReview"
context: Optional[str] = None
i_context: Optional[str] = None
async def run(self, prd, api_design):
prompt = (

View file

@ -13,7 +13,7 @@ from metagpt.schema import Message
class ExecuteTask(Action):
name: str = "ExecuteTask"
context: list[Message] = []
i_context: list[Message] = []
async def run(self, *args, **kwargs):
pass

View file

@ -1,8 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/12 17:45
@Author : fisherdeng
@File : generate_questions.py
"""
from metagpt.actions import Action
@ -23,5 +21,5 @@ class GenerateQuestions(Action):
name: str = "GenerateQuestions"
async def run(self, context):
async def run(self, context) -> ActionNode:
return await QUESTIONS.fill(context=context, llm=self.llm)

View file

@ -16,17 +16,14 @@ from typing import Optional
import pandas as pd
from paddleocr import PaddleOCR
from pydantic import Field
from metagpt.actions import Action
from metagpt.const import INVOICE_OCR_TABLE_PATH
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.prompts.invoice_ocr import (
EXTRACT_OCR_MAIN_INFO_PROMPT,
REPLY_OCR_QUESTION_PROMPT,
)
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
@ -41,7 +38,7 @@ class InvoiceOCR(Action):
"""
name: str = "InvoiceOCR"
context: Optional[str] = None
i_context: Optional[str] = None
@staticmethod
async def _check_file_type(file_path: Path) -> str:
@ -132,8 +129,7 @@ class GenerateTable(Action):
"""
name: str = "GenerateTable"
context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
i_context: Optional[str] = None
language: str = "ch"
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
@ -176,9 +172,6 @@ class ReplyQuestion(Action):
"""
name: str = "ReplyQuestion"
context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
language: str = "ch"
async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:

View file

@ -12,39 +12,41 @@ from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.config import CONFIG
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.schema import Document
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class PrepareDocuments(Action):
"""PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt."""
name: str = "PrepareDocuments"
context: Optional[str] = None
i_context: Optional[str] = None
@property
def config(self):
return self.context.config
def _init_repo(self):
"""Initialize the Git environment."""
if not CONFIG.project_path:
name = CONFIG.project_name or FileRepository.new_filename()
path = Path(CONFIG.workspace_path) / name
if not self.config.project_path:
name = self.config.project_name or FileRepository.new_filename()
path = Path(self.config.workspace.path) / name
else:
path = Path(CONFIG.project_path)
if path.exists() and not CONFIG.inc:
path = Path(self.config.project_path)
if path.exists() and not self.config.inc:
shutil.rmtree(path)
CONFIG.project_path = path
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
self.config.project_path = path
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.repo = ProjectRepo(self.context.git_repo)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""
self._init_repo()
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO)
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
# Send a Message notification to the WritePRD action, instructing it to process requirements using
# `docs/requirement.txt` and `docs/prds/`.
# `docs/requirement.txt` and `docs/prd/`.
return ActionOutput(content=doc.content, instruct_content=doc)

View file

@ -13,23 +13,16 @@
import json
from typing import Optional
from metagpt.actions import ActionOutput
from metagpt.actions.action import Action
from metagpt.actions.project_management_an import PM_NODE
from metagpt.config import CONFIG
from metagpt.const import (
PACKAGE_REQUIREMENTS_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
TASK_PDF_FILE_REPO,
)
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
from metagpt.logs import logger
from metagpt.schema import Document, Documents
from metagpt.utils.file_repository import FileRepository
NEW_REQ_TEMPLATE = """
### Legacy Content
{old_tasks}
{old_task}
### New Requirements
{context}
@ -38,30 +31,23 @@ NEW_REQ_TEMPLATE = """
class WriteTasks(Action):
name: str = "CreateTasks"
context: Optional[str] = None
i_context: Optional[str] = None
async def run(self, with_messages, schema=CONFIG.prompt_schema):
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
tasks_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
changed_tasks = tasks_file_repo.changed_files
async def run(self, with_messages):
changed_system_designs = self.repo.docs.system_design.changed_files
changed_tasks = self.repo.docs.task.changed_files
change_files = Documents()
# Rewrite the system designs that have undergone changes based on the git head diff under
# `docs/system_designs/`.
for filename in changed_system_designs:
task_doc = await self._update_tasks(
filename=filename, system_design_file_repo=system_design_file_repo, tasks_file_repo=tasks_file_repo
)
task_doc = await self._update_tasks(filename=filename)
change_files.docs[filename] = task_doc
# Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
for filename in changed_tasks:
if filename in change_files.docs:
continue
task_doc = await self._update_tasks(
filename=filename, system_design_file_repo=system_design_file_repo, tasks_file_repo=tasks_file_repo
)
task_doc = await self._update_tasks(filename=filename)
change_files.docs[filename] = task_doc
if not change_files.docs:
@ -70,39 +56,36 @@ class WriteTasks(Action):
# global optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo):
system_design_doc = await system_design_file_repo.get(filename)
task_doc = await tasks_file_repo.get(filename)
async def _update_tasks(self, filename):
system_design_doc = await self.repo.docs.system_design.get(filename)
task_doc = await self.repo.docs.task.get(filename)
if task_doc:
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
else:
rsp = await self._run_new_tasks(context=system_design_doc.content)
task_doc = Document(
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.model_dump_json()
task_doc = await self.repo.docs.task.save(
filename=filename,
content=rsp.instruct_content.model_dump_json(),
dependencies={system_design_doc.root_relative_path},
)
await tasks_file_repo.save(
filename=filename, content=task_doc.content, dependencies={system_design_doc.root_relative_path}
)
await self._update_requirements(task_doc)
await self._save_pdf(task_doc=task_doc)
return task_doc
async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
node = await PM_NODE.fill(context, self.llm, schema)
async def _run_new_tasks(self, context):
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
return node
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
node = await PM_NODE.fill(context, self.llm, schema)
async def _merge(self, system_design_doc, task_doc) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
task_doc.content = node.instruct_content.model_dump_json()
return task_doc
@staticmethod
async def _update_requirements(doc):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python third-party packages", set()))
file_repo = CONFIG.git_repo.new_file_repository()
requirement_doc = await file_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
packages = set(m.get("Required Python packages", set()))
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
lines = requirement_doc.content.splitlines()
@ -110,8 +93,4 @@ class WriteTasks(Action):
if pkg == "":
continue
packages.add(pkg)
await file_repo.save(PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
@staticmethod
async def _save_pdf(task_doc):
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))

View file

@ -35,6 +35,20 @@ LOGIC_ANALYSIS = ActionNode(
],
)
REFINED_LOGIC_ANALYSIS = ActionNode(
key="Refined Logic Analysis",
expected_type=List[List[str]],
instruction="Review and refine the logic analysis by merging the Legacy Content and Incremental Content. "
"Provide a comprehensive list of files with classes/methods/functions to be implemented or modified incrementally. "
"Include dependency analysis, consider potential impacts on existing code, and document necessary imports.",
example=[
["game.py", "Contains Game class and ... functions"],
["main.py", "Contains main function, from game import Game"],
["new_feature.py", "Introduces NewFeature class and related functions"],
["utils.py", "Modifies existing utility functions to support incremental changes"],
],
)
TASK_LIST = ActionNode(
key="Task list",
expected_type=List[str],
@ -42,6 +56,15 @@ TASK_LIST = ActionNode(
example=["game.py", "main.py"],
)
REFINED_TASK_LIST = ActionNode(
key="Refined Task list",
expected_type=List[str],
instruction="Review and refine the combined task list after the merger of Legacy Content and Incremental Content, "
"and consistent with Refined File List. Ensure that tasks are organized in a logical and prioritized order, "
"considering dependencies for a streamlined and efficient development process. ",
example=["new_feature.py", "utils", "game.py", "main.py"],
)
FULL_API_SPEC = ActionNode(
key="Full API spec",
expected_type=str,
@ -54,9 +77,19 @@ SHARED_KNOWLEDGE = ActionNode(
key="Shared Knowledge",
expected_type=str,
instruction="Detail any shared knowledge, like common utility functions or configuration variables.",
example="'game.py' contains functions shared across the project.",
example="`game.py` contains functions shared across the project.",
)
REFINED_SHARED_KNOWLEDGE = ActionNode(
key="Refined Shared Knowledge",
expected_type=str,
instruction="Update and expand shared knowledge to reflect any new elements introduced. This includes common "
"utility functions, configuration variables for team collaboration. Retain content that is not related to "
"incremental development but important for consistency and clarity.",
example="`new_module.py` enhances shared utility functions for improved code reusability and collaboration.",
)
ANYTHING_UNCLEAR_PM = ActionNode(
key="Anything UNCLEAR",
expected_type=str,
@ -74,13 +107,25 @@ NODES = [
ANYTHING_UNCLEAR_PM,
]
REFINED_NODES = [
REQUIRED_PYTHON_PACKAGES,
REQUIRED_OTHER_LANGUAGE_PACKAGES,
REFINED_LOGIC_ANALYSIS,
REFINED_TASK_LIST,
FULL_API_SPEC,
REFINED_SHARED_KNOWLEDGE,
ANYTHING_UNCLEAR_PM,
]
PM_NODE = ActionNode.from_children("PM_NODE", NODES)
REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES)
def main():
prompt = PM_NODE.compile(context="")
logger.info(prompt)
prompt = REFINED_PM_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":

View file

@ -12,7 +12,7 @@ from pathlib import Path
import aiofiles
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import (
AGGREGATION,
COMPOSITION,
@ -29,16 +29,16 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.context))
repo_parser = RepoParser(base_directory=Path(self.i_context))
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.context))
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.context).resolve(), package_root=package_root)
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.
@ -48,9 +48,9 @@ class RebuildClassView(Action):
await graph_db.save()
async def _create_mermaid_class_views(self, graph_db):
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / CONFIG.git_repo.workdir.name
pathname = path / self.context.git_repo.workdir.name
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)

View file

@ -12,7 +12,7 @@ from pathlib import Path
from typing import List
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
@ -21,8 +21,8 @@ from metagpt.utils.graph_repository import GraphKeyword
class RebuildSequenceView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
for entry in entries:
@ -41,7 +41,9 @@ class RebuildSequenceView(Action):
async def _rebuild_sequence_view(self, entry, graph_db):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.context, pathname=filename)
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(

View file

@ -8,10 +8,8 @@ from typing import Callable, Optional, Union
from pydantic import Field, parse_obj_as
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.llm import LLM
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
from metagpt.utils.common import OutputParser
@ -81,7 +79,7 @@ class CollectLinks(Action):
"""Action class to collect links from a search engine."""
name: str = "CollectLinks"
context: Optional[str] = None
i_context: Optional[str] = None
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
@ -129,8 +127,8 @@ class CollectLinks(Action):
if len(remove) == 0:
break
model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
model_name = config.get_openai_llm().model
prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
@ -177,19 +175,16 @@ class WebBrowseAndSummarize(Action):
"""Action class to explore the web and provide summaries of articles and webpages."""
name: str = "WebBrowseAndSummarize"
context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
i_context: Optional[str] = None
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: Optional[WebBrowserEngine] = None
web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
run_func=self.browse_func,
)
@ -220,9 +215,7 @@ class WebBrowseAndSummarize(Action):
for u, content in zip([url, *urls], contents):
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(
content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp
):
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
@ -247,14 +240,8 @@ class WebBrowseAndSummarize(Action):
class ConductResearch(Action):
"""Action class to conduct research and generate a research report."""
name: str = "ConductResearch"
context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_report:
self.llm.model = CONFIG.model_for_researcher_report
async def run(
self,

View file

@ -16,12 +16,12 @@
class.
"""
import subprocess
from pathlib import Path
from typing import Tuple
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.exceptions import handle_exception
@ -48,7 +48,7 @@ WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.
You should fill in necessary instruction, status, send to, and finally return all content between the --- segment line.
"""
CONTEXT = """
TEMPLATE_CONTEXT = """
## Development Code File Name
{code_file_name}
## Development Code
@ -77,7 +77,7 @@ standard errors:
class RunCode(Action):
name: str = "RunCode"
context: RunCodeContext = Field(default_factory=RunCodeContext)
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
@classmethod
async def run_text(cls, code) -> Tuple[str, str]:
@ -89,13 +89,12 @@ class RunCode(Action):
return "", str(e)
return namespace.get("result", ""), ""
@classmethod
async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
working_directory = str(working_directory)
additional_python_paths = [str(path) for path in additional_python_paths]
# Copy the current environment variables
env = CONFIG.new_environ()
env = self.context.new_environ()
# Modify the PYTHONPATH environment variable
additional_python_paths = [working_directory] + additional_python_paths
@ -119,25 +118,25 @@ class RunCode(Action):
return stdout.decode("utf-8"), stderr.decode("utf-8")
async def run(self, *args, **kwargs) -> RunCodeResult:
logger.info(f"Running {' '.join(self.context.command)}")
if self.context.mode == "script":
logger.info(f"Running {' '.join(self.i_context.command)}")
if self.i_context.mode == "script":
outs, errs = await self.run_script(
command=self.context.command,
working_directory=self.context.working_directory,
additional_python_paths=self.context.additional_python_paths,
command=self.i_context.command,
working_directory=self.i_context.working_directory,
additional_python_paths=self.i_context.additional_python_paths,
)
elif self.context.mode == "text":
outs, errs = await self.run_text(code=self.context.code)
elif self.i_context.mode == "text":
outs, errs = await self.run_text(code=self.i_context.code)
logger.info(f"{outs=}")
logger.info(f"{errs=}")
context = CONTEXT.format(
code=self.context.code,
code_file_name=self.context.code_filename,
test_code=self.context.test_code,
test_file_name=self.context.test_filename,
command=" ".join(self.context.command),
context = TEMPLATE_CONTEXT.format(
code=self.i_context.code,
code_file_name=self.i_context.code_filename,
test_code=self.i_context.test_code,
test_file_name=self.i_context.test_filename,
command=" ".join(self.i_context.command),
outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
errs=errs[:10000], # truncate errors to avoid token overflow
)
@ -152,11 +151,23 @@ class RunCode(Action):
return subprocess.run(cmd, check=check, cwd=cwd, env=env)
@staticmethod
def _install_dependencies(working_directory, env):
def _install_requirements(working_directory, env):
file_path = Path(working_directory) / "requirements.txt"
if not file_path.exists():
return
if file_path.stat().st_size == 0:
return
install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"]
logger.info(" ".join(install_command))
RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env)
@staticmethod
def _install_pytest(working_directory, env):
install_pytest_command = ["python", "-m", "pip", "install", "pytest"]
logger.info(" ".join(install_pytest_command))
RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env)
@staticmethod
def _install_dependencies(working_directory, env):
RunCode._install_requirements(working_directory, env)
RunCode._install_pytest(working_directory, env)

View file

@ -8,10 +8,9 @@
from typing import Any, Optional
import pydantic
from pydantic import Field, model_validator
from pydantic import model_validator
from metagpt.actions import Action
from metagpt.config import CONFIG, Config
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
@ -103,32 +102,25 @@ You are a member of a professional butler team and will provide helpful suggesti
"""
# TOTEST
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None
config: None = Field(default_factory=Config)
engine: Optional[SearchEngineType] = CONFIG.search_engine
engine: Optional[SearchEngineType] = None
search_func: Optional[Any] = None
search_engine: SearchEngine = None
result: str = ""
@model_validator(mode="before")
@classmethod
def validate_engine_and_run_func(cls, values):
engine = values.get("engine")
search_func = values.get("search_func")
config = Config()
if engine is None:
engine = config.search_engine
@model_validator(mode="after")
def validate_engine_and_run_func(self):
if self.engine is None:
self.engine = self.config.search_engine
try:
search_engine = SearchEngine(engine=engine, run_func=search_func)
search_engine = SearchEngine(engine=self.engine, run_func=self.search_func)
except pydantic.ValidationError:
search_engine = None
values["search_engine"] = search_engine
return values
self.search_engine = search_engine
return self
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
if self.search_engine is None:

View file

@ -29,9 +29,7 @@ class ArgumentsParingAction(Action):
@property
def prompt(self):
prompt = "You are a function parser. You can convert spoken words into function parameters.\n"
prompt += "\n---\n"
prompt += f"{self.skill.name} function parameters description:\n"
prompt = f"{self.skill.name} function parameters description:\n"
for k, v in self.skill.arguments.items():
prompt += f"parameter `{k}`: {v}\n"
prompt += "\n---\n"
@ -49,7 +47,10 @@ class ArgumentsParingAction(Action):
async def run(self, with_message=None, **kwargs) -> Message:
prompt = self.prompt
rsp = await self.llm.aask(msg=prompt, system_msgs=[])
rsp = await self.llm.aask(
msg=prompt,
system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
)
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)

View file

@ -11,11 +11,8 @@ from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from metagpt.utils.file_repository import FileRepository
PROMPT_TEMPLATE = """
NOTICE
@ -29,9 +26,9 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
{system_design}
```
-----
# Tasks
# Task
```text
{tasks}
{task}
```
-----
{code_blocks}
@ -90,10 +87,9 @@ flowchart TB
"""
# TOTEST
class SummarizeCode(Action):
name: str = "SummarizeCode"
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
async def summarize_code(self, prompt):
@ -101,20 +97,20 @@ class SummarizeCode(Action):
return code_rsp
async def run(self):
design_pathname = Path(self.context.design_filename)
design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
task_pathname = Path(self.context.task_filename)
task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
design_pathname = Path(self.i_context.design_filename)
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
task_pathname = Path(self.i_context.task_filename)
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
code_blocks = []
for filename in self.context.codes_filenames:
for filename in self.i_context.codes_filenames:
code_doc = await src_file_repo.get(filename)
code_block = f"```python\n{code_doc.content}\n```\n-----"
code_blocks.append(code_block)
format_example = FORMAT_EXAMPLE
prompt = PROMPT_TEMPLATE.format(
system_design=design_doc.content,
tasks=task_doc.content,
task=task_doc.content,
code_blocks="\n".join(code_blocks),
format_example=format_example,
)

View file

@ -9,25 +9,31 @@
from typing import Optional
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.schema import Message
# TOTEST
class TalkAction(Action):
context: str
i_context: str
history_summary: str = ""
knowledge: str = ""
rsp: Optional[Message] = None
@property
def agent_description(self):
return self.context.kwargs.agent_description
@property
def language(self):
return self.context.kwargs.language or config.language
@property
def prompt(self):
prompt = ""
if CONFIG.agent_description:
if self.agent_description:
prompt = (
f"You are {CONFIG.agent_description}. Your responses should align with the role-play agreement, "
f"You are {self.agent_description}. Your responses should align with the role-play agreement, "
f"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
f"decline to answer without revealing your AI nature to preserve the character's image.\n\n"
)
@ -36,10 +42,10 @@ class TalkAction(Action):
prompt += (
"If the information is insufficient, you can search in the historical conversation or knowledge above.\n"
)
language = CONFIG.language or DEFAULT_LANGUAGE
language = self.language
prompt += (
f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
f"{self.context}"
f"{self.i_context}"
)
logger.debug(f"PROMPT: {prompt}")
return prompt
@ -47,11 +53,11 @@ class TalkAction(Action):
@property
def prompt_gpt4(self):
kvs = {
"{role}": CONFIG.agent_description or "",
"{role}": self.agent_description or "",
"{history}": self.history_summary or "",
"{knowledge}": self.knowledge or "",
"{language}": CONFIG.language or DEFAULT_LANGUAGE,
"{ask}": self.context,
"{language}": self.language,
"{ask}": self.i_context,
}
prompt = TalkActionPrompt.FORMATION_LOOSE
for k, v in kvs.items():
@ -68,9 +74,9 @@ class TalkAction(Action):
@property
def aask_args(self):
language = CONFIG.language or DEFAULT_LANGUAGE
language = self.language
system_msgs = [
f"You are {CONFIG.agent_description}.",
f"You are {self.agent_description}.",
"Your responses should align with the role-play agreement, "
"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
"decline to answer without revealing your AI nature to preserve the character's image.",
@ -82,7 +88,7 @@ class TalkAction(Action):
format_msgs.append({"role": "assistant", "content": self.knowledge})
if self.history_summary:
format_msgs.append({"role": "assistant", "content": self.history_summary})
return self.context, format_msgs, system_msgs
return self.i_context, format_msgs, system_msgs
async def run(self, with_message=None, **kwargs) -> Message:
msg, format_msgs, system_msgs = self.aask_args

View file

@ -21,18 +21,17 @@ from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE
from metagpt.const import (
BUGFIX_FILENAME,
CODE_SUMMARIES_FILE_REPO,
DOCS_FILE_REPO,
TASK_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
CODE_PLAN_AND_CHANGE_FILENAME,
REQUIREMENT_FILENAME,
)
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document, RunCodeResult
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.project_repo import ProjectRepo
PROMPT_TEMPLATE = """
NOTICE
@ -44,8 +43,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
## Design
{design}
## Tasks
{tasks}
## Task
{task}
## Legacy Code
```Code
@ -87,7 +86,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
class WriteCode(Action):
name: str = "WriteCode"
context: Document = Field(default_factory=Document)
i_context: Document = Field(default_factory=Document)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code(self, prompt) -> str:
@ -96,16 +95,15 @@ class WriteCode(Action):
return code
async def run(self, *args, **kwargs) -> CodingContext:
bug_feedback = await FileRepository.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO)
coding_context = CodingContext.loads(self.context.content)
test_doc = await FileRepository.get_file(
filename="test_" + coding_context.filename + ".json", relative_path=TEST_OUTPUTS_FILE_REPO
)
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
coding_context = CodingContext.loads(self.i_context.content)
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
code_plan_and_change = code_plan_and_change_doc.content if code_plan_and_change_doc else ""
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
summary_doc = await FileRepository.get_file(
filename=coding_context.design_doc.filename, relative_path=CODE_SUMMARIES_FILE_REPO
)
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
logs = ""
if test_doc:
test_detail = RunCodeResult.loads(test_doc.content)
@ -113,42 +111,109 @@ class WriteCode(Action):
if bug_feedback:
code_context = coding_context.code_doc.content
elif code_plan_and_change:
code_context = await self.get_codes(
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
)
else:
code_context = await self.get_codes(coding_context.task_doc, exclude=self.context.filename)
code_context = await self.get_codes(
coding_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
prompt = PROMPT_TEMPLATE.format(
design=coding_context.design_doc.content if coding_context.design_doc else "",
tasks=coding_context.task_doc.content if coding_context.task_doc else "",
code=code_context,
logs=logs,
feedback=bug_feedback.content if bug_feedback else "",
filename=self.context.filename,
summary_log=summary_doc.content if summary_doc else "",
)
if code_plan_and_change:
prompt = REFINED_TEMPLATE.format(
user_requirement=requirement_doc.content if requirement_doc else "",
code_plan_and_change=code_plan_and_change,
design=coding_context.design_doc.content if coding_context.design_doc else "",
task=coding_context.task_doc.content if coding_context.task_doc else "",
code=code_context,
logs=logs,
feedback=bug_feedback.content if bug_feedback else "",
filename=self.i_context.filename,
summary_log=summary_doc.content if summary_doc else "",
)
else:
prompt = PROMPT_TEMPLATE.format(
design=coding_context.design_doc.content if coding_context.design_doc else "",
task=coding_context.task_doc.content if coding_context.task_doc else "",
code=code_context,
logs=logs,
feedback=bug_feedback.content if bug_feedback else "",
filename=self.i_context.filename,
summary_log=summary_doc.content if summary_doc else "",
)
logger.info(f"Writing {coding_context.filename}..")
code = await self.write_code(prompt)
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
root_path = self.context.src_workspace if self.context.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
return coding_context
@staticmethod
async def get_codes(task_doc, exclude) -> str:
async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, use_inc: bool = False) -> str:
"""
Get codes for generating the exclude file in various scenarios.
Attributes:
task_doc (Document): Document object of the task file.
exclude (str): The file to be generated. Specifies the filename to be excluded from the code snippets.
project_repo (ProjectRepo): ProjectRepo object of the project.
use_inc (bool): Indicates whether the scenario involves incremental development. Defaults to False.
Returns:
str: Codes for generating the exclude file.
"""
if not task_doc:
return ""
if not task_doc.content:
task_doc.content = FileRepository.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO)
task_doc = project_repo.docs.task.get(filename=task_doc.filename)
m = json.loads(task_doc.content)
code_filenames = m.get("Task list", [])
code_filenames = m.get(TASK_LIST.key, []) if use_inc else m.get(REFINED_TASK_LIST.key, [])
codes = []
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
for filename in code_filenames:
if filename == exclude:
continue
doc = await src_file_repo.get(filename=filename)
if not doc:
continue
codes.append(f"----- {filename}\n" + doc.content)
src_file_repo = project_repo.srcs
# Incremental development scenario
if use_inc:
src_files = src_file_repo.all_files
# Get the old workspace contained the old codes and old workspace are created in previous CodePlanAndChange
old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace)
old_files = old_file_repo.all_files
# Get the union of the files in the src and old workspaces
union_files_list = list(set(src_files) | set(old_files))
for filename in union_files_list:
# Exclude the current file from the all code snippets
if filename == exclude:
# If the file is in the old workspace, use the old code
# Exclude unnecessary code to maintain a clean and focused main.py file, ensuring only relevant and
# essential functionality is included for the projects requirements
if filename in old_files and filename != "main.py":
# Use old code
doc = await old_file_repo.get(filename=filename)
# If the file is in the src workspace, skip it
else:
continue
codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====")
# The code snippets are generated from the src workspace
else:
doc = await src_file_repo.get(filename=filename)
# If the file does not exist in the src workspace, skip it
if not doc:
continue
codes.append(f"----- {filename}\n```{doc.content}```")
# Normal scenario
else:
for filename in code_filenames:
# Exclude the current file to get the code snippets for generating the current file
if filename == exclude:
continue
doc = await src_file_repo.get(filename=filename)
if not doc:
continue
codes.append(f"----- {filename}\n```{doc.content}```")
return "\n".join(codes)

View file

@ -5,7 +5,7 @@
@File : write_review.py
"""
import asyncio
from typing import List
from typing import List, Literal
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
@ -21,16 +21,15 @@ REVIEW = ActionNode(
],
)
LGTM = ActionNode(
key="LGTM",
expected_type=str,
instruction="LGTM/LBTM. If the code is fully implemented, "
"give a LGTM (Looks Good To Me), otherwise provide a LBTM (Looks Bad To Me).",
REVIEW_RESULT = ActionNode(
key="ReviewResult",
expected_type=Literal["LGTM", "LBTM"],
instruction="LGTM/LBTM. If the code is fully implemented, " "give a LGTM, otherwise provide a LBTM.",
example="LBTM",
)
ACTIONS = ActionNode(
key="Actions",
NEXT_STEPS = ActionNode(
key="NextSteps",
expected_type=str,
instruction="Based on the code review outcome, suggest actionable steps. This can include code changes, "
"refactoring suggestions, or any follow-up tasks.",
@ -69,7 +68,7 @@ WRITE_DRAFT = ActionNode(
)
WRITE_MOVE_FUNCTION = ActionNode(
WRITE_FUNCTION = ActionNode(
key="WriteFunction",
expected_type=str,
instruction="write code for the function not implemented.",
@ -555,8 +554,8 @@ LBTM
"""
WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM, ACTIONS])
WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_MOVE_FUNCTION])
WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, REVIEW_RESULT, NEXT_STEPS])
WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_FUNCTION])
CR_FOR_MOVE_FUNCTION_BY_3 = """
@ -579,8 +578,7 @@ class WriteCodeAN(Action):
async def run(self, context):
self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
return await WRITE_MOVE_FUNCTION.fill(context=context, llm=self.llm, schema="json")
# return await WRITE_CODE_NODE.fill(context=context, llm=self.llm, schema="markdown")
return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
async def main():

View file

@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/26
@Author : mannaandpoem
@File : write_code_plan_and_change_an.py
"""
import os
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.actions.action_node import ActionNode
from metagpt.schema import CodePlanAndChangeContext
CODE_PLAN_AND_CHANGE = ActionNode(
key="Code Plan And Change",
expected_type=str,
instruction="Developing comprehensive and step-by-step incremental development plan, and write Incremental "
"Change by making a code draft that how to implement incremental development including detailed steps based on the "
"context. Note: Track incremental changes using mark of '+' or '-' for add/modify/delete code, and conforms to the "
"output format of git diff",
example="""
1. Plan for calculator.py: Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, multiplication, and division. Additionally, implement robust error handling for the division operation to mitigate potential issues related to division by zero.
```python
class Calculator:
self.result = number1 + number2
return self.result
- def sub(self, number1, number2) -> float:
+ def subtract(self, number1: float, number2: float) -> float:
+ '''
+ Subtracts the second number from the first and returns the result.
+
+ Args:
+ number1 (float): The number to be subtracted from.
+ number2 (float): The number to subtract.
+
+ Returns:
+ float: The difference of number1 and number2.
+ '''
+ self.result = number1 - number2
+ return self.result
+
def multiply(self, number1: float, number2: float) -> float:
- pass
+ '''
+ Multiplies two numbers and returns the result.
+
+ Args:
+ number1 (float): The first number to multiply.
+ number2 (float): The second number to multiply.
+
+ Returns:
+ float: The product of number1 and number2.
+ '''
+ self.result = number1 * number2
+ return self.result
+
def divide(self, number1: float, number2: float) -> float:
- pass
+ '''
+ ValueError: If the second number is zero.
+ '''
+ if number2 == 0:
+ raise ValueError('Cannot divide by zero')
+ self.result = number1 / number2
+ return self.result
+
- def reset_result(self):
+ def clear(self):
+ if self.result != 0.0:
+ print("Result is not zero, clearing...")
+ else:
+ print("Result is already zero, no need to clear.")
+
self.result = 0.0
```
2. Plan for main.py: Integrate new API endpoints for subtraction, multiplication, and division into the existing codebase of `main.py`. Then, ensure seamless integration with the overall application architecture and maintain consistency with coding standards.
```python
def add_numbers():
result = calculator.add_numbers(num1, num2)
return jsonify({'result': result}), 200
-# TODO: Implement subtraction, multiplication, and division operations
+@app.route('/subtract_numbers', methods=['POST'])
+def subtract_numbers():
+ data = request.get_json()
+ num1 = data.get('num1', 0)
+ num2 = data.get('num2', 0)
+ result = calculator.subtract_numbers(num1, num2)
+ return jsonify({'result': result}), 200
+
+@app.route('/multiply_numbers', methods=['POST'])
+def multiply_numbers():
+ data = request.get_json()
+ num1 = data.get('num1', 0)
+ num2 = data.get('num2', 0)
+ try:
+ result = calculator.divide_numbers(num1, num2)
+ except ValueError as e:
+ return jsonify({'error': str(e)}), 400
+ return jsonify({'result': result}), 200
+
if __name__ == '__main__':
app.run()
```""",
)
CODE_PLAN_AND_CHANGE_CONTEXT = """
## User New Requirements
{requirement}
## PRD
{prd}
## Design
{design}
## Task
{task}
## Legacy Code
{code}
"""
REFINED_TEMPLATE = """
NOTICE
Role: You are a professional engineer; The main goal is to complete incremental development by combining legacy code and plan and Incremental Change, ensuring the integration of new features.
# Context
## User New Requirements
{user_requirement}
## Code Plan And Change
{code_plan_and_change}
## Design
{design}
## Task
{task}
## Legacy Code
```Code
{code}
```
## Debug logs
```text
{logs}
{summary_log}
```
## Bug Feedback logs
```text
{feedback}
```
# Format example
## Code: {filename}
```python
## {filename}
...
```
# Instruction: Based on the context, follow "Format example", write or rewrite code.
## Write/Rewrite Code: Only write one file {filename}, write or rewrite complete code using triple quotes based on the following attentions and context.
1. Only One file: do your best to implement THIS ONLY ONE FILE.
2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
5. Follow Code Plan And Change: If there is any Incremental Change that is marked by the git diff format using '+' and '-' for add/modify/delete code, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the plan.
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
7. Before using a external variable/module, make sure you import it first.
8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code.
"""
WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", [CODE_PLAN_AND_CHANGE])
class WriteCodePlanAndChange(Action):
name: str = "WriteCodePlanAndChange"
i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext)
async def run(self, *args, **kwargs):
self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to "
"meticulously craft comprehensive incremental development plan and deliver detailed incremental change"
prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
code_text = await self.get_old_codes()
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
requirement=self.i_context.requirement,
prd=prd_doc.content,
design=design_doc.content,
task=task_doc.content,
code=code_text,
)
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
async def get_old_codes(self) -> str:
self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path)
old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace)
old_codes = await old_file_repo.get_all()
codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes]
return "\n".join(codes)

View file

@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import WriteCode
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import CODE_PLAN_AND_CHANGE_FILENAME, REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser
@ -120,7 +120,7 @@ REWRITE_CODE_TEMPLATE = """
class WriteCodeReview(Action):
name: str = "WriteCodeReview"
context: CodingContext = Field(default_factory=CodingContext)
i_context: CodingContext = Field(default_factory=CodingContext)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
@ -136,41 +136,64 @@ class WriteCodeReview(Action):
return result, code
async def run(self, *args, **kwargs) -> CodingContext:
iterative_code = self.context.code_doc.content
k = CONFIG.code_review_k_times or 1
iterative_code = self.i_context.code_doc.content
k = self.context.config.code_review_k_times or 1
for i in range(k):
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
task_content = self.context.task_doc.content if self.context.task_doc else ""
code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename)
context = "\n".join(
[
"## System Design\n" + str(self.context.design_doc) + "\n",
"## Tasks\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",
]
format_example = FORMAT_EXAMPLE.format(filename=self.i_context.code_doc.filename)
task_content = self.i_context.task_doc.content if self.i_context.task_doc else ""
code_context = await WriteCode.get_codes(
self.i_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.repo.with_src_path(self.context.src_workspace),
use_inc=self.config.inc,
)
if not self.config.inc:
context = "\n".join(
[
"## System Design\n" + str(self.i_context.design_doc) + "\n",
"## Task\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",
]
)
else:
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
context = "\n".join(
[
"## User New Requirements\n" + str(requirement_doc) + "\n",
"## Code Plan And Change\n" + str(code_plan_and_change_doc) + "\n",
"## System Design\n" + str(self.i_context.design_doc) + "\n",
"## Task\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",
]
)
context_prompt = PROMPT_TEMPLATE.format(
context=context,
code=iterative_code,
filename=self.context.code_doc.filename,
filename=self.i_context.code_doc.filename,
)
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
format_example=format_example,
)
len1 = len(iterative_code) if iterative_code else 0
len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
logger.info(
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
f"{len(self.context.code_doc.content)=}"
f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
f"len(self.i_context.code_doc.content)={len2}"
)
result, rewrited_code = await self.write_code_review_and_rewrite(
context_prompt, cr_prompt, self.context.code_doc.filename
context_prompt, cr_prompt, self.i_context.code_doc.filename
)
if "LBTM" in result:
iterative_code = rewrited_code
elif "LGTM" in result:
self.context.code_doc.content = iterative_code
return self.context
self.i_context.code_doc.content = iterative_code
return self.i_context
# code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING)
# self._save(context, filename, code)
# 如果rewrited_code是None原code perfect那么直接返回code
self.context.code_doc.content = iterative_code
return self.context
self.i_context.code_doc.content = iterative_code
return self.i_context

View file

@ -161,7 +161,7 @@ class WriteDocstring(Action):
"""
desc: str = "Write docstring for code."
context: Optional[str] = None
i_context: Optional[str] = None
async def run(
self,

View file

@ -14,26 +14,22 @@
from __future__ import annotations
import json
import uuid
from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.fix_bug import FixBug
from metagpt.actions.write_prd_an import (
COMPETITIVE_QUADRANT_CHART,
PROJECT_NAME,
REFINED_PRD_NODE,
WP_IS_RELATIVE_NODE,
WP_ISSUE_TYPE_NODE,
WRITE_PRD_NODE,
)
from metagpt.config import CONFIG
from metagpt.const import (
BUGFIX_FILENAME,
COMPETITIVE_ANALYSIS_FILE_REPO,
DOCS_FILE_REPO,
PRD_PDF_FILE_REPO,
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
)
from metagpt.logs import logger
@ -63,135 +59,114 @@ NEW_REQ_TEMPLATE = """
class WritePRD(Action):
name: str = "WritePRD"
content: Optional[str] = None
"""WritePRD deal with the following situations:
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
2. New requirement: If the requirement is a new requirement, the PRD document will be generated.
3. Requirement update: If the requirement is an update, the PRD document will be updated.
"""
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME)
if requirement_doc and await self._is_bugfix(requirement_doc.content):
await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
"""Run the action."""
req: Document = await self.repo.requirement
docs: list[Document] = await self.repo.docs.prd.get_all()
if not req:
raise FileNotFoundError("No requirement document found.")
if await self._is_bugfix(req.content):
logger.info(f"Bugfix detected: {req.content}")
return await self._handle_bugfix(req)
# remove bugfix file from last round in case of conflict
await self.repo.docs.delete(filename=BUGFIX_FILENAME)
# if requirement is related to other documents, update them, otherwise create a new one
if related_docs := await self.get_related_docs(req, docs):
logger.info(f"Requirement update detected: {req.content}")
return await self._handle_requirement_update(req, related_docs)
else:
await docs_file_repo.delete(filename=BUGFIX_FILENAME)
logger.info(f"New requirement detected: {req.content}")
return await self._handle_new_requirement(req)
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
prd_docs = await prds_file_repo.get_all()
change_files = Documents()
for prd_doc in prd_docs:
prd_doc = await self._update_prd(
requirement_doc=requirement_doc, prd_doc=prd_doc, prds_file_repo=prds_file_repo, *args, **kwargs
)
if not prd_doc:
continue
change_files.docs[prd_doc.filename] = prd_doc
logger.info(f"rewrite prd: {prd_doc.filename}")
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
if not change_files.docs:
prd_doc = await self._update_prd(
requirement_doc=requirement_doc, prd_doc=None, prds_file_repo=prds_file_repo, *args, **kwargs
)
if prd_doc:
change_files.docs[prd_doc.filename] = prd_doc
logger.debug(f"new prd: {prd_doc.filename}")
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
# 'publish' message to transition the workflow to the next stage. This design allows room for global
# optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _handle_bugfix(self, req: Document) -> Message:
# ... bugfix logic ...
await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
# sas = SearchAndSummarize()
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
# rsp = ""
# info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}"
# if sas.result:
# logger.info(sas.result)
# logger.info(rsp)
project_name = CONFIG.project_name or ""
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
"""handle new requirement"""
project_name = self.project_name
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
await self._rename_workspace(node)
return node
new_prd_doc = await self.repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
)
await self._save_competitive_analysis(new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd_doc.content, requirements=new_requirement_doc.content)
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
# ... requirement update logic ...
for doc in related_docs:
await self._update_prd(req, doc)
return Documents.from_iterable(documents=related_docs).to_action_output()
async def _is_bugfix(self, context: str) -> bool:
if not self.repo.code_files_exists():
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
"""get the related documents"""
# refine: use gather to speed up
return [i for i in docs if await self._is_related(req, i)]
async def _is_related(self, req: Document, old_prd: Document) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
prd_doc.content = node.instruct_content.model_dump_json()
async def _merge(self, req: Document, related_doc: Document) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
related_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return prd_doc
return related_doc
async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None:
if not prd_doc:
prd = await self._run_new_requirement(
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
)
new_prd_doc = Document(
root_path=PRDS_FILE_REPO,
filename=FileRepository.new_filename() + ".json",
content=prd.instruct_content.model_dump_json(),
)
elif await self._is_relative(requirement_doc, prd_doc):
new_prd_doc = await self._merge(requirement_doc, prd_doc)
else:
return None
await prds_file_repo.save(filename=new_prd_doc.filename, content=new_prd_doc.content)
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
new_prd_doc: Document = await self._merge(req, prd_doc)
await self.repo.docs.prd.save_doc(doc=new_prd_doc)
await self._save_competitive_analysis(new_prd_doc)
await self._save_pdf(new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return new_prd_doc
@staticmethod
async def _save_competitive_analysis(prd_doc):
async def _save_competitive_analysis(self, prd_doc: Document):
m = json.loads(prd_doc.content)
quadrant_chart = m.get("Competitive Quadrant Chart")
quadrant_chart = m.get(COMPETITIVE_QUADRANT_CHART.key)
if not quadrant_chart:
return
pathname = (
CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
)
if not pathname.parent.exists():
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(quadrant_chart, pathname)
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
@staticmethod
async def _save_pdf(prd_doc):
await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
@staticmethod
async def _rename_workspace(prd):
if not CONFIG.project_name:
async def _rename_workspace(self, prd):
if not self.project_name:
if isinstance(prd, (ActionOutput, ActionNode)):
ws_name = prd.instruct_content.model_dump()["Project Name"]
else:
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
CONFIG.project_name = ws_name
if not CONFIG.project_name: # The LLM failed to provide a project name, and the user didn't provide one either.
CONFIG.project_name = "app" + uuid.uuid4().hex[:16]
CONFIG.git_repo.rename_root(CONFIG.project_name)
async def _is_bugfix(self, context) -> bool:
src_workspace_path = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
code_files = CONFIG.git_repo.get_files(relative_path=src_workspace_path)
if not code_files:
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
self.project_name = ws_name
self.repo.git_repo.rename_root(self.project_name)

View file

@ -8,7 +8,6 @@
from typing import List
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
LANGUAGE = ActionNode(
key="Language",
@ -31,10 +30,18 @@ ORIGINAL_REQUIREMENTS = ActionNode(
example="Create a 2048 game",
)
REFINED_REQUIREMENTS = ActionNode(
key="Refined Requirements",
expected_type=str,
instruction="Place the New user's original requirements here.",
example="Create a 2048 game with a new feature that ...",
)
PROJECT_NAME = ActionNode(
key="Project Name",
expected_type=str,
instruction="According to the content of \"Original Requirements,\" name the project using snake case style , like 'game_2048' or 'simple_crm.",
instruction='According to the content of "Original Requirements," name the project using snake case style , '
"like 'game_2048' or 'simple_crm.",
example="game_2048",
)
@ -45,6 +52,18 @@ PRODUCT_GOALS = ActionNode(
example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"],
)
REFINED_PRODUCT_GOALS = ActionNode(
key="Refined Product Goals",
expected_type=List[str],
instruction="Update and expand the original product goals to reflect the evolving needs due to incremental "
"development.Ensure that the refined goals align with the current project direction and contribute to its success.",
example=[
"Enhance user engagement through new features",
"Optimize performance for scalability",
"Integrate innovative UI enhancements",
],
)
USER_STORIES = ActionNode(
key="User Stories",
expected_type=List[str],
@ -58,6 +77,20 @@ USER_STORIES = ActionNode(
],
)
REFINED_USER_STORIES = ActionNode(
key="Refined User Stories",
expected_type=List[str],
instruction="Update and expand the original scenario-based user stories to reflect the evolving needs due to "
"incremental development. Ensure that the refined user stories capture incremental features and improvements. ",
example=[
"As a player, I want to choose difficulty levels to challenge my skills",
"As a player, I want a visually appealing score display after each game for a better gaming experience",
"As a player, I want a convenient restart button displayed when I lose to quickly start a new game",
"As a player, I want an enhanced and aesthetically pleasing UI to elevate the overall gaming experience",
"As a player, I want the ability to play the game seamlessly on my mobile phone for on-the-go entertainment",
],
)
COMPETITIVE_ANALYSIS = ActionNode(
key="Competitive Analysis",
expected_type=List[str],
@ -97,6 +130,15 @@ REQUIREMENT_ANALYSIS = ActionNode(
example="",
)
REFINED_REQUIREMENT_ANALYSIS = ActionNode(
key="Refined Requirement Analysis",
expected_type=List[str],
instruction="Review and refine the existing requirement analysis to align with the evolving needs of the project "
"due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements "
"required for the refined project scope.",
example=["Require add/update/modify ..."],
)
REQUIREMENT_POOL = ActionNode(
key="Requirement Pool",
expected_type=List[List[str]],
@ -104,6 +146,14 @@ REQUIREMENT_POOL = ActionNode(
example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
)
REFINED_REQUIREMENT_POOL = ActionNode(
key="Refined Requirement Pool",
expected_type=List[List[str]],
instruction="List down the top 5 to 7 requirements with their priority (P0, P1, P2). "
"Cover both legacy content and incremental content. Retain content unrelated to incremental development",
example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
)
UI_DESIGN_DRAFT = ActionNode(
key="UI Design draft",
expected_type=str,
@ -152,15 +202,22 @@ NODES = [
ANYTHING_UNCLEAR,
]
REFINED_NODES = [
LANGUAGE,
PROGRAMMING_LANGUAGE,
REFINED_REQUIREMENTS,
PROJECT_NAME,
REFINED_PRODUCT_GOALS,
REFINED_USER_STORIES,
COMPETITIVE_ANALYSIS,
COMPETITIVE_QUADRANT_CHART,
REFINED_REQUIREMENT_ANALYSIS,
REFINED_REQUIREMENT_POOL,
UI_DESIGN_DRAFT,
ANYTHING_UNCLEAR,
]
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES)
REFINED_PRD_NODE = ActionNode.from_children("RefinedPRD", REFINED_NODES)
WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON])
WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON])
def main():
prompt = WRITE_PRD_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":
main()

View file

@ -13,7 +13,7 @@ from metagpt.actions.action import Action
class WritePRDReview(Action):
name: str = ""
context: Optional[str] = None
i_context: Optional[str] = None
prd: Optional[str] = None
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"

View file

@ -8,14 +8,14 @@
from typing import Optional
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.context import Context
from metagpt.logs import logger
class WriteTeachingPlanPart(Action):
"""Write Teaching Plan Part"""
context: Optional[str] = None
i_context: Optional[str] = None
topic: str = ""
language: str = "Chinese"
rsp: Optional[str] = None
@ -24,7 +24,7 @@ class WriteTeachingPlanPart(Action):
statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
statements = []
for p in statement_patterns:
s = self.format_value(p)
s = self.format_value(p, context=self.context)
statements.append(s)
formatter = (
TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
@ -35,7 +35,7 @@ class WriteTeachingPlanPart(Action):
formation=TeachingPlanBlock.FORMATION,
role=self.prefix,
statements="\n".join(statements),
lesson=self.context,
lesson=self.i_context,
topic=self.topic,
language=self.language,
)
@ -68,20 +68,23 @@ class WriteTeachingPlanPart(Action):
return self.topic
@staticmethod
def format_value(value):
def format_value(value, context: Context):
"""Fill parameters inside `value` with `options`."""
if not isinstance(value, str):
return value
if "{" not in value:
return value
merged_opts = CONFIG.options or {}
options = context.config.model_dump()
for k, v in context.kwargs:
options[k] = v # None value is allowed to override and disable the value from config.
opts = {k: v for k, v in options.items() if v is not None}
try:
return value.format(**merged_opts)
return value.format(**opts)
except KeyError as e:
logger.warning(f"Parameter is missing:{e}")
for k, v in merged_opts.items():
for k, v in opts.items():
value = value.replace("{" + f"{k}" + "}", str(v))
return value

View file

@ -39,7 +39,7 @@ you should correctly import the necessary classes based on these file locations!
class WriteTest(Action):
name: str = "WriteTest"
context: Optional[TestingContext] = None
i_context: Optional[TestingContext] = None
async def write_code(self, prompt):
code_rsp = await self._aask(prompt)
@ -55,16 +55,16 @@ class WriteTest(Action):
return code
async def run(self, *args, **kwargs) -> TestingContext:
if not self.context.test_doc:
self.context.test_doc = Document(
filename="test_" + self.context.code_doc.filename, root_path=TEST_CODES_FILE_REPO
if not self.i_context.test_doc:
self.i_context.test_doc = Document(
filename="test_" + self.i_context.code_doc.filename, root_path=TEST_CODES_FILE_REPO
)
fake_root = "/data"
prompt = PROMPT_TEMPLATE.format(
code_to_test=self.context.code_doc.content,
test_file_name=self.context.test_doc.filename,
source_file_path=fake_root + "/" + self.context.code_doc.root_relative_path,
code_to_test=self.i_context.code_doc.content,
test_file_name=self.i_context.test_doc.filename,
source_file_path=fake_root + "/" + self.i_context.code_doc.root_relative_path,
workspace=fake_root,
)
self.context.test_doc.content = await self.write_code(prompt)
return self.context
self.i_context.test_doc.content = await self.write_code(prompt)
return self.i_context

View file

@ -1,290 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Provide configuration, singleton
@Modified By: mashenquan, 2023/11/27.
1. According to Section 2.2.3.11 of RFC 135, add git repository support.
2. Add the parameter `src_workspace` for the old version project path.
"""
import datetime
import json
import os
import warnings
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any
from uuid import uuid4
import yaml
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
from metagpt.logs import logger
from metagpt.tools import SearchEngineType, WebBrowserEngineType
from metagpt.utils.common import require_python_version
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.singleton import Singleton
class NotConfiguredException(Exception):
"""Exception raised for errors in the configuration.
Attributes:
message -- explanation of the error
"""
def __init__(self, message="The required configuration is not set"):
self.message = message
super().__init__(self.message)
class LLMProviderEnum(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
SPARK = "spark"
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE_OPENAI = "azure_openai"
OLLAMA = "ollama"
def __missing__(self, key):
return self.OPENAI
class Config(metaclass=Singleton):
"""
Regular usage method:
config = Config("config.yaml")
secret_key = config.get_key("MY_SECRET_KEY")
print("Secret key:", secret_key)
"""
_instance = None
home_yaml_file = Path.home() / ".metagpt/config.yaml"
key_yaml_file = METAGPT_ROOT / "config/key.yaml"
default_yaml_file = METAGPT_ROOT / "config/config.yaml"
def __init__(self, yaml_file=default_yaml_file, cost_data=""):
global_options = OPTIONS.get()
# cli paras
self.project_path = ""
self.project_name = ""
self.inc = False
self.reqa_file = ""
self.max_auto_summarize_code = 0
self.git_reinit = False
self._init_with_config_files_and_env(yaml_file)
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager()
self._update()
global_options.update(OPTIONS.get())
logger.debug("Config loading done.")
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
"""Get first valid LLM provider enum"""
mappings = {
LLMProviderEnum.OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
),
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
LLMProviderEnum.METAGPT: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
),
LLMProviderEnum.AZURE_OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY)
and self.OPENAI_API_TYPE == "azure"
and self.DEPLOYMENT_NAME
and self.OPENAI_API_VERSION
),
LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
}
provider = None
for k, v in mappings.items():
if v:
provider = k
break
if provider is None:
if self.DEFAULT_PROVIDER:
provider = LLMProviderEnum(self.DEFAULT_PROVIDER)
else:
raise NotConfiguredException("You should config a LLM configuration first")
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
model_name = self.get_model_name(provider=provider)
if model_name:
logger.info(f"{provider} Model: {model_name}")
if provider:
logger.info(f"API: {provider}")
return provider
def get_model_name(self, provider=None) -> str:
provider = provider or self.get_default_llm_provider_enum()
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
return model_mappings.get(provider, "")
@staticmethod
def _is_valid_llm_key(k: str) -> bool:
return bool(k and k != "YOUR_API_KEY")
def _update(self):
self.global_proxy = self._get("GLOBAL_PROXY")
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
self.gemini_api_key = self._get("GEMINI_API_KEY")
self.ollama_api_base = self._get("OLLAMA_API_BASE")
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")
self.openai_api_rpm = self._get("RPM", 3)
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4-1106-preview")
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4")
self.spark_appid = self._get("SPARK_APPID")
self.spark_api_secret = self._get("SPARK_API_SECRET")
self.spark_api_key = self._get("SPARK_API_KEY")
self.domain = self._get("DOMAIN")
self.spark_url = self._get("SPARK_URL")
self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
self.claude_api_key = self._get("ANTHROPIC_API_KEY")
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
self.serper_api_key = self._get("SERPER_API_KEY")
self.google_api_key = self._get("GOOGLE_API_KEY")
self.google_cse_id = self._get("GOOGLE_CSE_ID")
self.search_engine = SearchEngineType(self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE))
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", WebBrowserEngineType.PLAYWRIGHT))
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
self.code_review_k_times = 2
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
self.mmdc = self._get("MMDC", "mmdc")
self.calc_usage = self._get("CALC_USAGE", True)
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")
workspace_uid = (
self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
)
self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
self.prompt_schema = self._get("PROMPT_FORMAT", "json")
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
val = self._get("WORKSPACE_PATH_WITH_UID")
if val and val.lower() == "true": # for agent
self.workspace_path = self.workspace_path / workspace_uid
self._ensure_workspace_exists()
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
self.timeout = int(self._get("TIMEOUT", 3))
self.kaggle_username = self._get("KAGGLE_USERNAME", "")
self.kaggle_key = self._get("KAGGLE_KEY", "")
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
if project_path:
inc = True
project_name = project_name or Path(project_path).name
self.project_path = project_path
self.project_name = project_name
self.inc = inc
self.reqa_file = reqa_file
self.max_auto_summarize_code = max_auto_summarize_code
def _ensure_workspace_exists(self):
self.workspace_path.mkdir(parents=True, exist_ok=True)
logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}")
def _init_with_config_files_and_env(self, yaml_file):
"""Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority"""
configs = dict(os.environ)
for _yaml_file in [yaml_file, self.key_yaml_file, self.home_yaml_file]:
if not _yaml_file.exists():
continue
# Load local YAML file
with open(_yaml_file, "r", encoding="utf-8") as file:
yaml_data = yaml.safe_load(file)
if not yaml_data:
continue
configs.update(yaml_data)
OPTIONS.set(configs)
@staticmethod
def _get(*args, **kwargs):
i = OPTIONS.get()
return i.get(*args, **kwargs)
def get(self, key, *args, **kwargs):
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
Throw an error if not found."""
value = self._get(key, *args, **kwargs)
if value is None:
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
return value
def __setattr__(self, name: str, value: Any) -> None:
OPTIONS.get()[name] = value
def __getattr__(self, name: str) -> Any:
i = OPTIONS.get()
return i.get(name)
def set_context(self, options: dict):
"""Update current config"""
if not options:
return
opts = deepcopy(OPTIONS.get())
opts.update(options)
OPTIONS.set(opts)
self._update()
@property
def options(self):
"""Return all key-values"""
return OPTIONS.get()
def new_environ(self):
"""Return a new os.environ object"""
env = os.environ.copy()
i = self.options
env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
CONFIG = Config()

143
metagpt/config2.py Normal file
View file

@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 01:25
@Author : alexanderwu
@File : config2.py
"""
import os
from pathlib import Path
from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
from metagpt.configs.s3_config import S3Config
from metagpt.configs.search_config import SearchConfig
from metagpt.configs.workspace_config import WorkspaceConfig
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.utils.yaml_model import YamlModel
class CLIParams(BaseModel):
"""CLI parameters"""
project_path: str = ""
project_name: str = ""
inc: bool = False
reqa_file: str = ""
max_auto_summarize_code: int = 0
git_reinit: bool = False
@model_validator(mode="after")
def check_project_path(self):
"""Check project_path and project_name"""
if self.project_path:
self.inc = True
self.project_name = self.project_name or Path(self.project_path).name
return self
class Config(CLIParams, YamlModel):
"""Configurations for MetaGPT"""
# Key Parameters
llm: LLMConfig
# Global Proxy. Will be used if llm.proxy is not set
proxy: str = ""
# Tool Parameters
search: Optional[SearchConfig] = None
browser: BrowserConfig = BrowserConfig()
mermaid: MermaidConfig = MermaidConfig()
# Storage Parameters
s3: Optional[S3Config] = None
redis: Optional[RedisConfig] = None
# Misc Parameters
repair_llm_output: bool = False
prompt_schema: Literal["json", "markdown", "raw"] = "json"
workspace: WorkspaceConfig = WorkspaceConfig()
enable_longterm_memory: bool = False
code_review_k_times: int = 2
# Will be removed in the future
llm_for_researcher_summary: str = "gpt3"
llm_for_researcher_report: str = "gpt3"
METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = ""
language: str = "English"
redis_key: str = "placeholder"
mmdc: str = "mmdc"
puppeteer_config: str = ""
pyppeteer_executable_path: str = ""
IFLYTEK_APP_ID: str = ""
IFLYTEK_API_SECRET: str = ""
IFLYTEK_API_KEY: str = ""
AZURE_TTS_SUBSCRIPTION_KEY: str = ""
AZURE_TTS_REGION: str = ""
mermaid_engine: str = "nodejs"
@classmethod
def from_home(cls, path):
"""Load config from ~/.metagpt/config.yaml"""
pathname = CONFIG_ROOT / path
if not pathname.exists():
return None
return Config.from_yaml_file(pathname)
@classmethod
def default(cls):
"""Load default config
- Priority: env < default_config_paths
- Inside default_config_paths, the latter one overwrites the former one
"""
default_config_paths: List[Path] = [
METAGPT_ROOT / "config/config2.yaml",
Path.home() / ".metagpt/config2.yaml",
]
dicts = [dict(os.environ)]
dicts += [Config.read_yaml(path) for path in default_config_paths]
final = merge_dict(dicts)
return Config(**final)
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
if project_path:
inc = True
project_name = project_name or Path(project_path).name
self.project_path = project_path
self.project_name = project_name
self.inc = inc
self.reqa_file = reqa_file
self.max_auto_summarize_code = max_auto_summarize_code
def get_openai_llm(self) -> Optional[LLMConfig]:
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
if self.llm.api_type == LLMType.OPENAI:
return self.llm
return None
def get_azure_llm(self) -> Optional[LLMConfig]:
"""Get Azure LLMConfig by name. If no Azure, raise Exception"""
if self.llm.api_type == LLMType.AZURE:
return self.llm
return None
def merge_dict(dicts: Iterable[Dict]) -> Dict:
"""Merge multiple dicts into one, with the latter dict overwriting the former"""
result = {}
for dictionary in dicts:
result.update(dictionary)
return result
config = Config.default()

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/11 14:44
@Time : 2024/1/4 16:33
@Author : alexanderwu
@File : test_action.py
@File : __init__.py
"""

View file

@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:06
@Author : alexanderwu
@File : browser_config.py
"""
from typing import Literal
from metagpt.tools import WebBrowserEngineType
from metagpt.utils.yaml_model import YamlModel
class BrowserConfig(YamlModel):
"""Config for Browser"""
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
path: str = ""

View file

@ -0,0 +1,78 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 16:33
@Author : alexanderwu
@File : llm_config.py
"""
from enum import Enum
from typing import Optional
from pydantic import field_validator
from metagpt.utils.yaml_model import YamlModel
class LLMType(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
SPARK = "spark"
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
METAGPT = "metagpt"
AZURE = "azure"
OLLAMA = "ollama"
def __missing__(self, key):
return self.OPENAI
class LLMConfig(YamlModel):
"""Config for LLM
OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions.py#L681
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None
api_secret: Optional[str] = None
domain: Optional[str] = None
# For Chat Completion
max_token: int = 4096
temperature: float = 0.0
top_p: float = 1.0
top_k: int = 0
repetition_penalty: float = 1.0
stop: Optional[str] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
best_of: Optional[int] = None
n: Optional[int] = None
stream: bool = False
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
top_logprobs: Optional[int] = None
timeout: int = 60
# For Network
proxy: Optional[str] = None
# Cost Control
calc_usage: bool = True
@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):
if v in ["", None, "YOUR_API_KEY"]:
raise ValueError("Please set your API key in config.yaml")
return v

View file

@ -0,0 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:07
@Author : alexanderwu
@File : mermaid_config.py
"""
from typing import Literal
from metagpt.utils.yaml_model import YamlModel
class MermaidConfig(YamlModel):
"""Config for Mermaid"""
engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs"
path: str = ""
puppeteer_config: str = "" # Only for nodejs engine

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:06
@Author : alexanderwu
@File : redis_config.py
"""
from metagpt.utils.yaml_model import YamlModelWithoutDefault
class RedisConfig(YamlModelWithoutDefault):
host: str
port: int
username: str = ""
password: str
db: str
def to_url(self):
return f"redis://{self.host}:{self.port}"
def to_kwargs(self):
return {
"username": self.username,
"password": self.password,
"db": self.db,
}

View file

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:07
@Author : alexanderwu
@File : s3_config.py
"""
from metagpt.utils.yaml_model import YamlModelWithoutDefault
class S3Config(YamlModelWithoutDefault):
access_key: str
secret_key: str
endpoint: str
bucket: str

View file

@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:06
@Author : alexanderwu
@File : search_config.py
"""
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
api_key: str
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
cse_id: str = "" # for google

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 19:09
@Author : alexanderwu
@File : workspace_config.py
"""
from datetime import datetime
from pathlib import Path
from uuid import uuid4
from pydantic import field_validator, model_validator
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.utils.yaml_model import YamlModel
class WorkspaceConfig(YamlModel):
path: Path = DEFAULT_WORKSPACE_ROOT
use_uid: bool = False
uid: str = ""
@field_validator("path")
@classmethod
def check_workspace_path(cls, v):
if isinstance(v, str):
v = Path(v)
return v
@model_validator(mode="after")
def check_uid_and_update_path(self):
if self.use_uid and not self.uid:
self.uid = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
self.path = self.path / self.uid
# Create workspace path if not exists
self.path.mkdir(parents=True, exist_ok=True)
return self

View file

@ -9,7 +9,6 @@
@Modified By: mashenquan, 2023-11-27. Defines file repository paths according to Section 2.2.3.4 of RFC 135.
@Modified By: mashenquan, 2023/12/5. Add directories for code summarization..
"""
import contextvars
import os
from pathlib import Path
@ -17,8 +16,6 @@ from loguru import logger
import metagpt
OPTIONS = contextvars.ContextVar("OPTIONS", default={})
def get_metagpt_package_root():
"""Get the root directory of the installed package."""
@ -47,7 +44,7 @@ def get_metagpt_root():
# METAGPT PROJECT ROOT AND VARS
CONFIG_ROOT = Path.home() / ".metagpt"
METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
@ -73,12 +70,10 @@ SKILL_DIRECTORY = SOURCE_ROOT / "skills"
TOOL_SCHEMA_PATH = METAGPT_ROOT / "metagpt/tools/schemas"
TOOL_LIBS_PATH = METAGPT_ROOT / "metagpt/tools/libs"
# REAL CONSTS
MEM_TTL = 24 * 30 * 3600
MESSAGE_ROUTE_FROM = "sent_from"
MESSAGE_ROUTE_TO = "send_to"
MESSAGE_ROUTE_CAUSE_BY = "cause_by"
@ -89,25 +84,28 @@ MESSAGE_ROUTE_TO_NONE = "<none>"
REQUIREMENT_FILENAME = "requirement.txt"
BUGFIX_FILENAME = "bugfix.txt"
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"
CODE_PLAN_AND_CHANGE_FILENAME = "code_plan_and_change.json"
DOCS_FILE_REPO = "docs"
PRDS_FILE_REPO = "docs/prds"
PRDS_FILE_REPO = "docs/prd"
SYSTEM_DESIGN_FILE_REPO = "docs/system_design"
TASK_FILE_REPO = "docs/tasks"
TASK_FILE_REPO = "docs/task"
CODE_PLAN_AND_CHANGE_FILE_REPO = "docs/code_plan_and_change"
COMPETITIVE_ANALYSIS_FILE_REPO = "resources/competitive_analysis"
DATA_API_DESIGN_FILE_REPO = "resources/data_api_design"
SEQ_FLOW_FILE_REPO = "resources/seq_flow"
SYSTEM_DESIGN_PDF_FILE_REPO = "resources/system_design"
PRD_PDF_FILE_REPO = "resources/prd"
TASK_PDF_FILE_REPO = "resources/api_spec_and_tasks"
TASK_PDF_FILE_REPO = "resources/api_spec_and_task"
CODE_PLAN_AND_CHANGE_PDF_FILE_REPO = "resources/code_plan_and_change"
TEST_CODES_FILE_REPO = "tests"
TEST_OUTPUTS_FILE_REPO = "test_outputs"
CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
CODE_SUMMARIES_FILE_REPO = "docs/code_summary"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
CLASS_VIEW_FILE_REPO = "docs/class_views"
CLASS_VIEW_FILE_REPO = "docs/class_view"
YAPI_URL = "http://yapi.deepwisdomai.com/"

97
metagpt/context.py Normal file
View file

@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 16:32
@Author : alexanderwu
@File : context.py
"""
import os
from pathlib import Path
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class AttrDict(BaseModel):
"""A dict-like object that allows access to keys as attributes, compatible with Pydantic."""
model_config = ConfigDict(extra="allow")
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.__dict__.update(kwargs)
def __getattr__(self, key):
return self.__dict__.get(key, None)
def __setattr__(self, key, value):
self.__dict__[key] = value
def __delattr__(self, key):
if key in self.__dict__:
del self.__dict__[key]
else:
raise AttributeError(f"No such attribute: {key}")
def set(self, key, val: Any):
self.__dict__[key] = val
def get(self, key, default: Any = None):
return self.__dict__.get(key, default)
def remove(self, key):
if key in self.__dict__:
self.__delattr__(key)
class Context(BaseModel):
"""Env context for MetaGPT"""
model_config = ConfigDict(arbitrary_types_allowed=True)
kwargs: AttrDict = AttrDict()
config: Config = Config.default()
repo: Optional[ProjectRepo] = None
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
_llm: Optional[BaseLLM] = None
def new_environ(self):
"""Return a new os.environ object"""
env = os.environ.copy()
# i = self.options
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
# """Use a LLM instance"""
# self._llm_config = self.config.get_llm_config(name, provider)
# self._llm = None
# return self._llm
def llm(self) -> BaseLLM:
"""Return a LLM instance, fixme: support cache"""
# if self._llm is None:
self._llm = create_llm_instance(self.config.llm)
if self._llm.cost_manager is None:
self._llm.cost_manager = self.cost_manager
return self._llm
def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
"""Return a LLM instance, fixme: support cache"""
# if self._llm is None:
llm = create_llm_instance(llm_config)
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
return llm

102
metagpt/context_mixin.py Normal file
View file

@ -0,0 +1,102 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/11 17:25
@Author : alexanderwu
@File : context_mixin.py
"""
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.context import Context
from metagpt.provider.base_llm import BaseLLM
class ContextMixin(BaseModel):
"""Mixin class for context and config"""
model_config = ConfigDict(arbitrary_types_allowed=True)
# Pydantic has bug on _private_attr when using inheritance, so we use private_* instead
# - https://github.com/pydantic/pydantic/issues/7142
# - https://github.com/pydantic/pydantic/issues/7083
# - https://github.com/pydantic/pydantic/issues/7091
# Env/Role/Action will use this context as private context, or use self.context as public context
private_context: Optional[Context] = Field(default=None, exclude=True)
# Env/Role/Action will use this config as private config, or use self.context.config as public config
private_config: Optional[Config] = Field(default=None, exclude=True)
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
def __init__(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
**kwargs,
):
"""Initialize with config"""
super().__init__(**kwargs)
self.set_context(context)
self.set_config(config)
self.set_llm(llm)
def set(self, k, v, override=False):
"""Set attribute"""
if override or not self.__dict__.get(k):
self.__dict__[k] = v
def set_context(self, context: Context, override=True):
"""Set context"""
self.set("private_context", context, override)
def set_config(self, config: Config, override=False):
"""Set config"""
self.set("private_config", config, override)
if config is not None:
_ = self.llm # init llm
def set_llm(self, llm: BaseLLM, override=False):
"""Set llm"""
self.set("private_llm", llm, override)
@property
def config(self) -> Config:
"""Role config: role config > context config"""
if self.private_config:
return self.private_config
return self.context.config
@config.setter
def config(self, config: Config) -> None:
"""Set config"""
self.set_config(config)
@property
def context(self) -> Context:
"""Role context: role context > context"""
if self.private_context:
return self.private_context
return Context()
@context.setter
def context(self, context: Context) -> None:
"""Set context"""
self.set_context(context)
@property
def llm(self) -> BaseLLM:
"""Role llm: if not existed, init from role.config"""
# print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}")
if not self.private_llm:
self.private_llm = self.context.llm_with_cost_manager_from_llm_config(self.config.llm)
return self.private_llm
@llm.setter
def llm(self, llm: BaseLLM) -> None:
"""Set llm"""
self.private_llm = llm

View file

@ -8,8 +8,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from metagpt.config import Config
class BaseStore(ABC):
"""FIXME: consider add_index, set_index and think about granularity."""
@ -31,7 +29,6 @@ class LocalStore(BaseStore, ABC):
def __init__(self, raw_data_path: Path, cache_dir: Path = None):
if not raw_data_path:
raise FileNotFoundError
self.config = Config()
self.raw_data_path = raw_data_path
self.fname = self.raw_data_path.stem
if not cache_dir:

View file

@ -9,14 +9,13 @@ import asyncio
from pathlib import Path
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.config import CONFIG
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
from metagpt.logs import logger
from metagpt.utils.embedding import get_embedding
class FaissStore(LocalStore):
@ -25,9 +24,7 @@ class FaissStore(LocalStore):
):
self.meta_col = meta_col
self.content_col = content_col
self.embedding = embedding or OpenAIEmbeddings(
openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
)
self.embedding = embedding or get_embedding()
super().__init__(raw_data, cache_dir)
def _load(self) -> Optional["FaissStore"]:

View file

@ -12,16 +12,15 @@
functionality is to be consolidated into the `Environment` class.
"""
import asyncio
from pathlib import Path
from typing import Iterable, Set
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.config import CONFIG
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
from metagpt.utils.common import is_send_to
class Environment(BaseModel):
@ -33,58 +32,22 @@ class Environment(BaseModel):
desc: str = Field(default="") # 环境描述
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
member_addrs: dict[Role, Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
context: Context = Field(default_factory=Context, exclude=True)
@model_validator(mode="after")
def init_roles(self):
self.add_roles(self.roles.values())
return self
def serialize(self, stg_path: Path):
roles_path = stg_path.joinpath("roles.json")
roles_info = []
for role_key, role in self.roles.items():
roles_info.append(
{
"role_class": role.__class__.__name__,
"module_name": role.__module__,
"role_name": role.name,
"role_sub_tags": list(self.members.get(role)),
}
)
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
write_json_file(roles_path, roles_info)
history_path = stg_path.joinpath("history.json")
write_json_file(history_path, {"content": self.history})
@classmethod
def deserialize(cls, stg_path: Path) -> "Environment":
"""stg_path: ./storage/team/environment/"""
roles_path = stg_path.joinpath("roles.json")
roles_info = read_json_file(roles_path)
roles = []
for role_info in roles_info:
# role stored in ./environment/roles/{role_class}_{role_name}
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
role = Role.deserialize(role_path)
roles.append(role)
history = read_json_file(stg_path.joinpath("history.json"))
history = history.get("content")
environment = Environment(**{"history": history})
environment.add_roles(roles)
return environment
def add_role(self, role: Role):
"""增加一个在当前环境的角色
Add a role in the current environment
"""
self.roles[role.profile] = role
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable[Role]):
"""增加一批在当前环境的角色
@ -95,6 +58,7 @@ class Environment(BaseModel):
for role in roles: # setup system message with roles
role.set_env(self)
role.context = self.context
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""
@ -108,8 +72,8 @@ class Environment(BaseModel):
logger.debug(f"publish_message: {message.dump()}")
found = False
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
for role, subscription in self.members.items():
if is_subscribed(message, subscription):
for role, addrs in self.member_addrs.items():
if is_send_to(message, addrs):
role.put_message(message)
found = True
if not found:
@ -154,15 +118,14 @@ class Environment(BaseModel):
return False
return True
def get_subscription(self, obj):
"""Get the labels for messages to be consumed by the object."""
return self.members.get(obj, {})
def get_addresses(self, obj):
"""Get the addresses of the object."""
return self.member_addrs.get(obj, {})
def set_subscription(self, obj, tags):
"""Set the labels for message to be consumed by the object"""
self.members[obj] = tags
def set_addresses(self, obj, addresses):
"""Set the addresses of the object"""
self.member_addrs[obj] = addresses
@staticmethod
def archive(auto_archive=True):
if auto_archive and CONFIG.git_repo:
CONFIG.git_repo.archive()
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:
self.context.git_repo.archive()

View file

@ -13,7 +13,7 @@ import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.context import Context
class Example(BaseModel):
@ -73,14 +73,15 @@ class SkillsDeclaration(BaseModel):
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)
def get_skill_list(self, entity_name: str = "Assistant") -> Dict:
def get_skill_list(self, entity_name: str = "Assistant", context: Context = None) -> Dict:
"""Return the skill name based on the skill description."""
entity = self.entities.get(entity_name)
if not entity:
return {}
# List of skills that the agent chooses to activate.
agent_skills = CONFIG.agent_skills
ctx = context or Context()
agent_skills = ctx.kwargs.agent_skills
if not agent_skills:
return {}

View file

@ -6,19 +6,19 @@
@File : text_to_embedding.py
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
"""
from metagpt.config import CONFIG
import metagpt.config2
from metagpt.config2 import Config
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config):
"""Text to embedding
:param text: The text used for embedding.
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
if CONFIG.OPENAI_API_KEY or openai_api_key:
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
raise EnvironmentError
openai_api_key = config.get_openai_llm().api_key
proxy = config.get_openai_llm().proxy
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy)

View file

@ -8,33 +8,37 @@
"""
import base64
from metagpt.config import CONFIG
import metagpt.config2
from metagpt.config2 import Config
from metagpt.const import BASE64_FORMAT
from metagpt.llm import LLM
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
from metagpt.utils.s3 import S3
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config):
"""Text to image
:param text: The text used for image conversion.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
:param model_url: MetaGPT model url
:param config: Config
:return: The image data is returned in Base64 encoding.
"""
image_declaration = "data:image/png;base64,"
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
model_url = config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
if model_url:
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
elif CONFIG.OPENAI_API_KEY or openai_api_key:
binary_data = await oas3_openai_text_to_image(text, size_type)
elif config.get_openai_llm():
llm = LLM(llm_config=config.get_openai_llm())
binary_data = await oas3_openai_text_to_image(text, size_type, llm=llm)
else:
raise ValueError("Missing necessary parameters.")
base64_data = base64.b64encode(binary_data).decode("utf-8")
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
s3 = S3(config.s3)
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT)
if url:
return f"![{text}]({url})"
return image_declaration + base64_data if base64_data else ""

View file

@ -6,8 +6,8 @@
@File : text_to_speech.py
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
"""
from metagpt.config import CONFIG
import metagpt.config2
from metagpt.config2 import Config
from metagpt.const import BASE64_FORMAT
from metagpt.tools.azure_tts import oas3_azsure_tts
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
@ -20,12 +20,7 @@ async def text_to_speech(
voice="zh-CN-XiaomoNeural",
style="affectionate",
role="Girl",
subscription_key="",
region="",
iflytek_app_id="",
iflytek_api_key="",
iflytek_api_secret="",
**kwargs,
config: Config = metagpt.config2.config,
):
"""Text to speech
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
@ -44,23 +39,27 @@ async def text_to_speech(
"""
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
subscription_key = config.AZURE_TTS_SUBSCRIPTION_KEY
region = config.AZURE_TTS_REGION
if subscription_key and region:
audio_declaration = "data:audio/wav;base64,"
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else ""
s3 = S3(config.s3)
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT)
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data
if (CONFIG.IFLYTEK_APP_ID and CONFIG.IFLYTEK_API_KEY and CONFIG.IFLYTEK_API_SECRET) or (
iflytek_app_id and iflytek_api_key and iflytek_api_secret
):
iflytek_app_id = config.IFLYTEK_APP_ID
iflytek_api_key = config.IFLYTEK_API_KEY
iflytek_api_secret = config.IFLYTEK_API_SECRET
if iflytek_app_id and iflytek_api_key and iflytek_api_secret:
audio_declaration = "data:audio/mp3;base64,"
base64_data = await oas3_iflytek_tts(
text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
)
s3 = S3()
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else ""
s3 = S3(config.s3)
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT)
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data

View file

@ -5,20 +5,16 @@
@Author : alexanderwu
@File : llm.py
"""
from typing import Optional
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig
from metagpt.context import Context
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
_ = HumanProvider() # Avoid pre-commit error
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
"""get the default llm provider"""
if provider is None:
provider = CONFIG.get_default_llm_provider_enum()
return LLM_REGISTRY.get_provider(provider)
def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> BaseLLM:
"""get the default llm provider if name is None"""
ctx = context or Context()
if llm_config is not None:
ctx.llm_with_cost_manager_from_llm_config(llm_config)
return ctx.llm()

View file

@ -14,8 +14,8 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.config2 import config
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import logger
from metagpt.provider import MetaGPTLLM
from metagpt.provider.base_llm import BaseLLM
@ -29,9 +29,9 @@ class BrainMemory(BaseModel):
historical_summary: str = ""
last_history_id: str = ""
is_dirty: bool = False
last_talk: str = None
last_talk: Optional[str] = None
cacheable: bool = True
llm: Optional[BaseLLM] = None
llm: Optional[BaseLLM] = Field(default=None, exclude=True)
class Config:
arbitrary_types_allowed = True
@ -56,8 +56,8 @@ class BrainMemory(BaseModel):
@staticmethod
async def loads(redis_key: str) -> "BrainMemory":
redis = Redis()
if not redis.is_valid or not redis_key:
redis = Redis(config.redis)
if not redis_key:
return BrainMemory()
v = await redis.get(key=redis_key)
logger.debug(f"REDIS GET {redis_key} {v}")
@ -70,8 +70,8 @@ class BrainMemory(BaseModel):
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
if not self.is_dirty:
return
redis = Redis()
if not redis.is_valid or not redis_key:
redis = Redis(config.redis)
if not redis_key:
return False
v = self.model_dump_json()
if self.cacheable:
@ -83,7 +83,7 @@ class BrainMemory(BaseModel):
def to_redis_key(prefix: str, user_id: str, chat_id: str):
return f"{prefix}:{user_id}:{chat_id}"
async def set_history_summary(self, history_summary, redis_key, redis_conf):
async def set_history_summary(self, history_summary, redis_key):
if self.historical_summary == history_summary:
if self.is_dirty:
await self.dumps(redis_key=redis_key)
@ -140,7 +140,7 @@ class BrainMemory(BaseModel):
return text
summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
if summary:
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
await self.set_history_summary(history_summary=summary, redis_key=config.redis_key)
return summary
raise ValueError(f"text too long:{text_length}")
@ -164,7 +164,7 @@ class BrainMemory(BaseModel):
msgs.reverse()
self.history = msgs
self.is_dirty = True
await self.dumps(redis_key=CONFIG.REDIS_KEY)
await self.dumps(redis_key=config.redis.key)
self.is_dirty = False
return BrainMemory.to_metagpt_history_format(self.history)
@ -181,7 +181,7 @@ class BrainMemory(BaseModel):
summary = await self.summarize(llm=llm, max_words=500)
language = CONFIG.language or DEFAULT_LANGUAGE
language = config.language
command = f"Translate the above summary into a {language} title of less than {max_words} words."
summaries = [summary, command]
msg = "\n".join(summaries)

View file

@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
"""
@Desc : the implement of Long-term memory
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from typing import Optional

View file

@ -7,19 +7,13 @@
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
"""
from collections import defaultdict
from pathlib import Path
from typing import DefaultDict, Iterable, Set
from pydantic import BaseModel, Field, SerializeAsAny
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
from metagpt.utils.common import (
any_to_str,
any_to_str_set,
read_json_file,
write_json_file,
)
from metagpt.utils.common import any_to_str, any_to_str_set
class Memory(BaseModel):
@ -29,22 +23,6 @@ class Memory(BaseModel):
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False
def serialize(self, stg_path: Path):
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
storage = self.model_dump()
write_json_file(memory_path, storage)
@classmethod
def deserialize(cls, stg_path: Path) -> "Memory":
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
memory_dict = read_json_file(memory_path)
memory = Memory(**memory_dict)
return memory
def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if self.ignore_id:

View file

@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
"""
@Desc : the implement of memory storage
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from pathlib import Path

View file

@ -14,6 +14,8 @@ from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
__all__ = [
"FireworksLLM",
@ -24,4 +26,6 @@ __all__ = [
"AzureOpenAILLM",
"MetaGPTLLM",
"OllamaLLM",
"HumanProvider",
"SparkLLM",
]

View file

@ -9,12 +9,15 @@
import anthropic
from anthropic import Anthropic, AsyncAnthropic
from metagpt.config import CONFIG
from metagpt.configs.llm_config import LLMConfig
class Claude2:
def __init__(self, config: LLMConfig):
self.config = config
def ask(self, prompt: str) -> str:
client = Anthropic(api_key=CONFIG.anthropic_api_key)
client = Anthropic(api_key=self.config.api_key)
res = client.completions.create(
model="claude-2",
@ -24,7 +27,7 @@ class Claude2:
return res.completion
async def aask(self, prompt: str) -> str:
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
aclient = AsyncAnthropic(api_key=self.config.api_key)
res = await aclient.completions.create(
model="claude-2",

View file

@ -3,8 +3,6 @@
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : openai.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
Change cost control from global to company level.
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
@ -13,12 +11,12 @@
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMType
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
@register_provider(LLMProviderEnum.AZURE_OPENAI)
@register_provider(LLMType.AZURE)
class AzureOpenAILLM(OpenAILLM):
"""
Check https://platform.openai.com/examples for examples
@ -28,13 +26,13 @@ class AzureOpenAILLM(OpenAILLM):
kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> dict:
kwargs = dict(
api_key=self.config.OPENAI_API_KEY,
api_version=self.config.OPENAI_API_VERSION,
azure_endpoint=self.config.OPENAI_BASE_URL,
api_key=self.config.api_key,
api_version=self.config.api_version,
azure_endpoint=self.config.base_url,
)
# to use proxy, openai v1 needs http_client

View file

@ -8,15 +8,32 @@
"""
import json
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
from openai import AsyncOpenAI
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
class BaseLLM(ABC):
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
config: LLMConfig
use_system_prompt: bool = True
system_prompt = "You are a helpful assistant."
# OpenAI / Azure / Others
aclient: Optional[Union[AsyncOpenAI]] = None
cost_manager: Optional[CostManager] = None
model: Optional[str] = None
@abstractmethod
def __init__(self, config: LLMConfig):
pass
def _user_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "content": msg}
@ -43,10 +60,13 @@ class BaseLLM(ABC):
if system_msgs:
message = self._system_msgs(system_msgs)
else:
message = [self._default_system_msg()] if self.use_system_prompt else []
message = [self._default_system_msg()]
if not self.use_system_prompt:
message = []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
@ -63,10 +83,9 @@ class BaseLLM(ABC):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def aask_code(self, msgs: list[str], timeout=3) -> str:
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict:
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
rsp_text = await self.aask_batch(msgs, timeout=timeout)
return rsp_text
raise NotImplementedError
@abstractmethod
async def acompletion(self, messages: list[dict], timeout=3):
@ -87,6 +106,10 @@ class BaseLLM(ABC):
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]
def get_choice_delta_text(self, rsp: dict) -> str:
"""Required to provide the first text of stream choice"""
return rsp.get("choices")[0]["delta"]["content"]
def get_choice_function(self, rsp: dict) -> dict:
"""Required to provide the first function of choice
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",

View file

@ -15,7 +15,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
@ -64,44 +64,35 @@ class FireworksCostManager(CostManager):
token_costs = self.model_grade_token_costs(model)
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
self.total_cost += cost
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
logger.info(
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${max_budget:.3f} | "
f"Total running cost: ${self.total_cost:.4f}"
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.FIREWORKS)
@register_provider(LLMType.FIREWORKS)
class FireworksLLM(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_fireworks()
def __init__(self, config: LLMConfig):
super().__init__(config=config)
self.auto_max_tokens = False
self._cost_manager = FireworksCostManager()
def __init_fireworks(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._init_client()
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
self.cost_manager = FireworksCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)

View file

@ -79,10 +79,8 @@ class GeneralAPIRequestor(APIRequestor):
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
if stream and (
"text/event-stream" in result.headers.get("Content-Type", "")
or "application/x-ndjson" in result.headers.get("Content-Type", "")
):
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)

View file

@ -19,7 +19,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -41,21 +41,22 @@ class GeminiGenerativeModel(GenerativeModel):
return await self._async_client.count_tokens(model=self.model_name, contents=contents)
@register_provider(LLMProviderEnum.GEMINI)
@register_provider(LLMType.GEMINI)
class GeminiLLM(BaseLLM):
"""
Refs to `https://ai.google.dev/tutorials/python_quickstart`
"""
def __init__(self):
def __init__(self, config: LLMConfig):
self.use_system_prompt = False # google gemini has no system prompt when use api
self.__init_gemini(CONFIG)
self.__init_gemini(config)
self.config = config
self.model = "gemini-pro" # so far only one model
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: CONFIG):
genai.configure(api_key=config.gemini_api_key)
def __init_gemini(self, config: LLMConfig):
genai.configure(api_key=config.api_key)
def _user_msg(self, msg: str) -> dict[str, str]:
# Not to change BaseLLM default functions but update with Gemini's conversation format.
@ -71,11 +72,11 @@ class GeminiLLM(BaseLLM):
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if CONFIG.calc_usage:
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
@ -108,7 +109,7 @@ class GeminiLLM(BaseLLM):
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:

View file

@ -5,6 +5,7 @@ Author: garylin2099
"""
from typing import Optional
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
@ -14,6 +15,9 @@ class HumanProvider(BaseLLM):
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
"""
def __init__(self, config: LLMConfig):
pass
def ask(self, msg: str, timeout=3) -> str:
logger.info("It's your turn, please type in your response. You may also refer to the context below")
rsp = input(msg)

View file

@ -5,7 +5,8 @@
@Author : alexanderwu
@File : llm_provider_registry.py
"""
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
class LLMProviderRegistry:
@ -15,13 +16,9 @@ class LLMProviderRegistry:
def register(self, key, provider_cls):
self.providers[key] = provider_cls
def get_provider(self, enum: LLMProviderEnum):
def get_provider(self, enum: LLMType):
"""get provider instance according to the enum"""
return self.providers[enum]()
# Registry instance
LLM_REGISTRY = LLMProviderRegistry()
return self.providers[enum]
def register_provider(key):
@ -32,3 +29,12 @@ def register_provider(key):
return cls
return decorator
def create_llm_instance(config: LLMConfig) -> BaseLLM:
"""get the default llm provider"""
return LLM_REGISTRY.get_provider(config.api_type)(config)
# Registry instance
LLM_REGISTRY = LLMProviderRegistry()

View file

@ -5,12 +5,11 @@
@File : metagpt_api.py
@Desc : MetaGPT LLM provider.
"""
from metagpt.config import LLMProviderEnum
from metagpt.configs.llm_config import LLMType
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.METAGPT)
@register_provider(LLMType.METAGPT)
class MetaGPTLLM(OpenAILLM):
def __init__(self):
super().__init__()
pass

View file

@ -13,48 +13,34 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import TokenCostManager
class OllamaCostManager(CostManager):
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
logger.info(
f"Max budget: ${max_budget:.3f} | "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OLLAMA)
@register_provider(LLMType.OLLAMA)
class OllamaLLM(BaseLLM):
"""
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
"""
def __init__(self):
self.__init_ollama(CONFIG)
self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
def __init__(self, config: LLMConfig):
self.__init_ollama(config)
self.client = GeneralAPIRequestor(base_url=config.base_url)
self.config = config
self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self._cost_manager = OllamaCostManager()
self._cost_manager = TokenCostManager()
def __init_ollama(self, config: CONFIG):
assert config.ollama_api_base
self.model = config.ollama_api_model
def __init_ollama(self, config: LLMConfig):
assert config.base_url, "ollama base url is required!"
self.model = config.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
@ -62,7 +48,7 @@ class OllamaLLM(BaseLLM):
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if CONFIG.calc_usage:
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))

View file

@ -4,56 +4,27 @@
from openai.types import CompletionUsage
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
class OpenLLMCostManager(CostManager):
"""open llm model is self-host, it's free and without cost"""
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
logger.info(
f"Max budget: ${max_budget:.3f} | reference "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
@register_provider(LLMProviderEnum.OPEN_LLM)
@register_provider(LLMType.OPEN_LLM)
class OpenLLM(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_openllm()
self.auto_max_tokens = False
self._cost_manager = OpenLLMCostManager()
def __init_openllm(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._init_client()
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
def __init__(self, config: LLMConfig):
super().__init__(config)
self._cost_manager = TokenCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
return kwargs
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not CONFIG.calc_usage:
if not self.config.calc_usage:
return usage
try:

View file

@ -3,15 +3,13 @@
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : openai.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation;
Change cost control from global to company level.
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
import json
import re
from typing import AsyncIterator, Union
from typing import AsyncIterator, Optional, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
@ -25,14 +23,14 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser
from metagpt.utils.cost_manager import Costs
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -52,18 +50,19 @@ See FAQ 5.8
raise retry_state.outcome.exception()
@register_provider(LLMProviderEnum.OPENAI)
@register_provider(LLMType.OPENAI)
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_openai(self):
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
@ -71,7 +70,7 @@ class OpenAILLM(BaseLLM):
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
kwargs = {"api_key": self.config.api_key, "base_url": self.config.base_url}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
@ -81,10 +80,10 @@ class OpenAILLM(BaseLLM):
def _get_proxy_params(self) -> dict:
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
if self.config.proxy:
params = {"proxies": self.config.proxy}
if self.config.base_url:
params["base_url"] = self.config.base_url
return params
@ -105,7 +104,7 @@ class OpenAILLM(BaseLLM):
"stop": None,
"temperature": 0.3,
"model": self.model,
"timeout": max(CONFIG.timeout, timeout),
"timeout": max(self.config.timeout, timeout),
}
if extra_kwargs:
kwargs.update(extra_kwargs)
@ -266,7 +265,7 @@ class OpenAILLM(BaseLLM):
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not CONFIG.calc_usage:
if not self.config.calc_usage:
return usage
try:
@ -279,18 +278,28 @@ class OpenAILLM(BaseLLM):
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage and usage:
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
return CONFIG.cost_manager.get_costs()
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
return self.config.max_token
return get_max_completion_tokens(messages, self.model, self.config.max_tokens)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
"""Moderate content."""
return await self.aclient.moderations.create(input=content)
async def atext_to_speech(self, **kwargs):
"""text to speech"""
return await self.aclient.audio.speech.create(**kwargs)
async def aspeech_to_text(self, **kwargs):
"""speech to text"""
return await self.aclient.audio.transcriptions.create(**kwargs)

View file

@ -16,29 +16,30 @@ from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
@register_provider(LLMType.SPARK)
class SparkLLM(BaseLLM):
def __init__(self):
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def __init__(self, config: LLMConfig):
self.config = config
logger.warning("SparkLLM当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
async def acompletion(self, messages: list[dict], timeout=3):
# 不支持异步
w = GetMessageFromWeb(messages)
w = GetMessageFromWeb(messages, self.config)
return w.run()
@ -89,14 +90,14 @@ class GetMessageFromWeb:
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def __init__(self, text):
def __init__(self, text, config: LLMConfig):
self.text = text
self.ret = ""
self.spark_appid = CONFIG.spark_appid
self.spark_api_secret = CONFIG.spark_api_secret
self.spark_api_key = CONFIG.spark_api_key
self.domain = CONFIG.domain
self.spark_url = CONFIG.spark_url
self.spark_appid = config.app_id
self.spark_api_secret = config.api_secret
self.spark_api_key = config.api_key
self.domain = config.domain
self.spark_url = config.base_url
def on_message(self, ws, message):
data = json.loads(message)

View file

@ -1,75 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : async_sse_client to make keep the use of Event to access response
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
# refs to `zhipuai/core/_sse_client.py`
from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
import json
from typing import Any, Iterator
class AsyncSSEClient(SSEClient):
async def _aread(self):
data = b""
class AsyncSSEClient(object):
def __init__(self, event_source: Iterator[Any]):
self._event_source = event_source
async def stream(self) -> dict:
if isinstance(self._event_source, bytes):
raise RuntimeError(
f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
)
async for chunk in self._event_source:
for line in chunk.splitlines(True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
line = chunk.decode("utf-8")
if line.startswith(":") or not line:
return
async def async_events(self):
async for chunk in self._aread():
event = Event()
# Split before decoding so splitlines() only uses \r and \n
for line in chunk.splitlines():
# Decode the line.
line = line.decode(self._char_enc)
# Lines starting with a separator are comments and are to be
# ignored.
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
continue
data = line.split(_FIELD_SEPARATOR, 1)
field = data[0]
# Ignore unknown fields.
if field not in event.__dict__:
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
continue
if len(data) > 1:
# From the spec:
# "If value starts with a single U+0020 SPACE character,
# remove it from value."
if data[1].startswith(" "):
value = data[1][1:]
else:
value = data[1]
else:
# If no value is present after the separator,
# assume an empty value.
value = ""
# The data field may come over multiple lines and their values
# are concatenated with each other.
if field == "data":
event.__dict__[field] += value + "\n"
else:
event.__dict__[field] = value
# Events with no data are not dispatched.
if not event.data:
continue
# If the data field ends with a newline, remove it.
if event.data.endswith("\n"):
event.data = event.data[0:-1]
# Empty event names default to 'message'
event.event = event.event or "message"
# Dispatch the event
self._logger.debug("Dispatching %s...", event)
yield event
field, _p, value = line.partition(":")
if value.startswith(" "):
value = value[1:]
if field == "data":
if value.startswith("[DONE]"):
break
data = json.loads(value)
yield data

View file

@ -4,46 +4,27 @@
import json
import zhipuai
from zhipuai.model_api.api import InvokeType, ModelAPI
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from zhipuai import ZhipuAI
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
class ZhiPuModelAPI(ModelAPI):
@classmethod
def get_header(cls) -> dict:
token = cls._generate_token()
zhipuai_default_headers.update({"Authorization": token})
return zhipuai_default_headers
@classmethod
def get_sse_header(cls) -> dict:
token = cls._generate_token()
headers = {"Authorization": token}
return headers
@classmethod
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
class ZhiPuModelAPI(ZhipuAI):
def split_zhipu_api_url(self):
# use this method to prevent zhipu api upgrading to different version.
# and follow the GeneralAPIRequestor implemented based on openai sdk
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
"""
example:
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
# ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions")
return f"{arr[0]}/api", f"/{arr[1]}"
@classmethod
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
# TODO to make the async request to be more generic for models in http mode.
assert method in ["post", "get"]
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
base_url, url = self.split_zhipu_api_url()
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
@ -51,25 +32,23 @@ class ZhiPuModelAPI(ModelAPI):
headers=headers,
stream=stream,
params=kwargs,
request_timeout=zhipuai.api_timeout_seconds,
request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
)
return result
@classmethod
async def ainvoke(cls, **kwargs) -> dict:
async def acreate(self, **kwargs) -> dict:
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
headers = cls.get_header()
resp = await cls.arequest(
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
)
headers = self._default_headers
resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
resp = resp.decode("utf-8")
resp = json.loads(resp)
if "error" in resp:
raise RuntimeError(
f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
)
return resp
@classmethod
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
"""async sse_invoke"""
headers = cls.get_sse_header()
return AsyncSSEClient(
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
)
headers = self._default_headers
return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))

View file

@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
import json
from enum import Enum
import openai
@ -16,7 +15,7 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -31,57 +30,52 @@ class ZhiPuEvent(Enum):
FINISH = "finish"
@register_provider(LLMProviderEnum.ZHIPUAI)
@register_provider(LLMType.ZHIPUAI)
class ZhiPuAILLM(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
From now, support glm-3-turboglm-4, and also system_prompt.
"""
def __init__(self):
self.__init_zhipuai(CONFIG)
def __init__(self, config: LLMConfig):
self.__init_zhipuai(config)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
self.config = config
def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
def __init_zhipuai(self, config: LLMConfig):
assert config.api_key
zhipuai.api_key = config.api_key
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
if config.openai_proxy:
if config.proxy:
# FIXME: openai v1.x sdk has no proxy support
openai.proxy = config.openai_proxy
openai.proxy = config.proxy
def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if CONFIG.calc_usage:
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp
return resp.model_dump()
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp = await self.llm.acreate(**self._const_kwargs(messages))
usage = resp.get("usage", {})
self._update_costs(usage)
return resp
@ -89,35 +83,18 @@ class ZhiPuAILLM(BaseLLM):
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
async for event in response.async_events():
if event.event == ZhiPuEvent.ADD.value:
content = event.data
async for chunk in response.stream():
finish_reason = chunk.get("choices")[0].get("finish_reason")
if finish_reason == "stop":
usage = chunk.get("usage", {})
else:
content = self.get_choice_delta_text(chunk)
collected_content.append(content)
log_llm_stream(content)
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
content = event.data
logger.error(f"event error: {content}", end="")
elif event.event == ZhiPuEvent.FINISH.value:
"""
event.meta
{
"task_status":"SUCCESS",
"usage":{
"completion_tokens":351,
"prompt_tokens":595,
"total_tokens":946
},
"task_id":"xx",
"request_id":"xxx"
}
"""
meta = json.loads(event.meta)
usage = meta.get("usage")
else:
print(f"zhipuapi else event: {event.data}", end="")
log_llm_stream("\n")
self._update_costs(usage)

View file

@ -33,7 +33,7 @@ class Architect(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
# Initialize actions specific to the Architect role
self._init_actions([WriteDesign])
self.set_actions([WriteDesign])
# Set events or actions the Architect should watch or be aware of
self._watch({WritePRD})

View file

@ -22,7 +22,6 @@ from pydantic import Field
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
from metagpt.actions.talk_action import TalkAction
from metagpt.config import CONFIG
from metagpt.learn.skill_loader import SkillsDeclaration
from metagpt.logs import logger
from metagpt.memory.brain_memory import BrainMemory
@ -48,7 +47,8 @@ class Assistant(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.constraints = self.constraints.format(language=kwargs.get("language") or CONFIG.language or "Chinese")
language = kwargs.get("language") or self.context.kwargs.language
self.constraints = self.constraints.format(language=language)
async def think(self) -> bool:
"""Everything will be done part by part."""
@ -56,16 +56,16 @@ class Assistant(Role):
if not last_talk:
return False
if not self.skills:
skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None
skill_path = Path(self.context.kwargs.SKILL_PATH) if self.context.kwargs.SKILL_PATH else None
self.skills = await SkillsDeclaration.load(skill_yaml_file_name=skill_path)
prompt = ""
skills = self.skills.get_skill_list()
skills = self.skills.get_skill_list(context=self.context)
for desc, name in skills.items():
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
rsp = await self.llm.aask(prompt, [])
rsp = await self.llm.aask(prompt, ["You are an action classifier"])
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
return await self._plan(rsp, last_talk=last_talk)
@ -97,8 +97,8 @@ class Assistant(Role):
async def talk_handler(self, text, **kwargs) -> bool:
history = self.memory.history_text
text = kwargs.get("last_talk") or text
self.rc.todo = TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
self.set_todo(
TalkAction(i_context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm)
)
return True
@ -108,11 +108,11 @@ class Assistant(Role):
if not skill:
logger.info(f"skill not found: {text}")
return await self.talk_handler(text=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk)
await action.run(**kwargs)
if action.args is None:
return await self.talk_handler(text=last_talk, **kwargs)
self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)
self.set_todo(SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description))
return True
async def refine_memory(self) -> str:

View file

@ -20,23 +20,27 @@
from __future__ import annotations
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Set
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
from metagpt.actions.fix_bug import FixBug
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
CODE_SUMMARIES_PDF_FILE_REPO,
CODE_PLAN_AND_CHANGE_FILE_REPO,
CODE_PLAN_AND_CHANGE_FILENAME,
REQUIREMENT_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import (
CodePlanAndChangeContext,
CodeSummarizeContext,
CodingContext,
Document,
@ -80,12 +84,13 @@ class Engineer(Role):
code_todos: list = []
summarize_todos: list = []
next_todo_action: str = ""
n_summarize: int = 0
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._init_actions([WriteCode])
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug])
self.set_actions([WriteCode])
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug, WriteCodePlanAndChange])
self.code_todos = []
self.summarize_todos = []
self.next_todo_action = any_to_name(WriteCode)
@ -93,11 +98,10 @@ class Engineer(Role):
@staticmethod
def _parse_tasks(task_msg: Document) -> list[str]:
m = json.loads(task_msg.content)
return m.get("Task list")
return m.get(TASK_LIST.key) or m.get(REFINED_TASK_LIST.key)
async def _act_sp_with_cr(self, review=False) -> Set[str]:
changed_files = set()
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
for todo in self.code_todos:
"""
# Select essential information from the historical data to reduce the length of the prompt (summarized from human experience):
@ -109,12 +113,16 @@ class Engineer(Role):
coding_context = await todo.run()
# Code review
if review:
action = WriteCodeReview(context=coding_context, llm=self.llm)
self._init_action_system_message(action)
action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm)
self._init_action(action)
coding_context = await action.run()
await src_file_repo.save(
coding_context.filename,
dependencies={coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path},
dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
if self.config.inc:
dependencies.add(os.path.join(CODE_PLAN_AND_CHANGE_FILE_REPO, CODE_PLAN_AND_CHANGE_FILENAME))
await self.project_repo.srcs.save(
filename=coding_context.filename,
dependencies=dependencies,
content=coding_context.code_doc.content,
)
msg = Message(
@ -134,6 +142,9 @@ class Engineer(Role):
"""Determines the mode of action based on whether code review is used."""
if self.rc.todo is None:
return None
if isinstance(self.rc.todo, WriteCodePlanAndChange):
self.next_todo_action = any_to_name(WriteCode)
return await self._act_code_plan_and_change()
if isinstance(self.rc.todo, WriteCode):
self.next_todo_action = any_to_name(SummarizeCode)
return await self._act_write_code()
@ -153,34 +164,32 @@ class Engineer(Role):
)
async def _act_summarize(self):
code_summaries_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
code_summaries_pdf_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
tasks = []
src_relative_path = CONFIG.src_workspace.relative_to(CONFIG.git_repo.workdir)
for todo in self.summarize_todos:
summary = await todo.run()
summary_filename = Path(todo.context.design_filename).with_suffix(".md").name
dependencies = {todo.context.design_filename, todo.context.task_filename}
for filename in todo.context.codes_filenames:
rpath = src_relative_path / filename
summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name
dependencies = {todo.i_context.design_filename, todo.i_context.task_filename}
for filename in todo.i_context.codes_filenames:
rpath = self.project_repo.src_relative_path / filename
dependencies.add(str(rpath))
await code_summaries_pdf_file_repo.save(
await self.project_repo.resources.code_summary.save(
filename=summary_filename, content=summary, dependencies=dependencies
)
is_pass, reason = await self._is_pass(summary)
if not is_pass:
todo.context.reason = reason
tasks.append(todo.context.dict())
await code_summaries_file_repo.save(
filename=Path(todo.context.design_filename).name,
content=todo.context.model_dump_json(),
todo.i_context.reason = reason
tasks.append(todo.i_context.model_dump())
await self.project_repo.docs.code_summary.save(
filename=Path(todo.i_context.design_filename).name,
content=todo.i_context.model_dump_json(),
dependencies=dependencies,
)
else:
await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name)
await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name)
logger.info(f"--max-auto-summarize-code={CONFIG.max_auto_summarize_code}")
if not tasks or CONFIG.max_auto_summarize_code == 0:
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
if not tasks or self.config.max_auto_summarize_code == 0:
return Message(
content="",
role=self.profile,
@ -190,11 +199,39 @@ class Engineer(Role):
)
# The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited.
# This parameter is used for debugging the workflow.
CONFIG.max_auto_summarize_code -= 1 if CONFIG.max_auto_summarize_code > 0 else 0
self.n_summarize += 1 if self.config.max_auto_summarize_code > self.n_summarize else 0
return Message(
content=json.dumps(tasks), role=self.profile, cause_by=SummarizeCode, send_to=self, sent_from=self
)
async def _act_code_plan_and_change(self):
"""Write code plan and change that guides subsequent WriteCode and WriteCodeReview"""
logger.info("Writing code plan and change..")
node = await self.rc.todo.run()
code_plan_and_change = node.instruct_content.model_dump_json()
dependencies = {
REQUIREMENT_FILENAME,
self.rc.todo.i_context.prd_filename,
self.rc.todo.i_context.design_filename,
self.rc.todo.i_context.task_filename,
}
await self.project_repo.docs.code_plan_and_change.save(
filename=self.rc.todo.i_context.filename, content=code_plan_and_change, dependencies=dependencies
)
await self.project_repo.resources.code_plan_and_change.save(
filename=Path(self.rc.todo.i_context.filename).with_suffix(".md").name,
content=node.content,
dependencies=dependencies,
)
return Message(
content=code_plan_and_change,
role=self.profile,
cause_by=WriteCodePlanAndChange,
send_to=self,
sent_from=self,
)
async def _is_pass(self, summary) -> (str, str):
rsp = await self.llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
logger.info(rsp)
@ -203,13 +240,18 @@ class Engineer(Role):
return False, rsp
async def _think(self) -> Action | None:
if not CONFIG.src_workspace:
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
write_plan_and_change_filters = any_to_str_set([WriteTasks])
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode, FixBug])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
if not self.rc.news:
return None
msg = self.rc.news[0]
if self.config.inc and msg.cause_by in write_plan_and_change_filters:
logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}")
await self._new_code_plan_and_change_action()
return self.rc.todo
if msg.cause_by in write_code_filters:
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
@ -220,60 +262,54 @@ class Engineer(Role):
return self.rc.todo
return None
@staticmethod
async def _new_coding_context(
filename, src_file_repo, task_file_repo, design_file_repo, dependency
) -> CodingContext:
old_code_doc = await src_file_repo.get(filename)
async def _new_coding_context(self, filename, dependency) -> CodingContext:
old_code_doc = await self.project_repo.srcs.get(filename)
if not old_code_doc:
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content="")
old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="")
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
task_doc = None
design_doc = None
for i in dependencies:
if str(i.parent) == TASK_FILE_REPO:
task_doc = await task_file_repo.get(i.name)
task_doc = await self.project_repo.docs.task.get(i.name)
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
design_doc = await design_file_repo.get(i.name)
design_doc = await self.project_repo.docs.system_design.get(i.name)
if not task_doc or not design_doc:
logger.error(f'Detected source code "{filename}" from an unknown origin.')
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
return context
@staticmethod
async def _new_coding_doc(filename, src_file_repo, task_file_repo, design_file_repo, dependency):
context = await Engineer._new_coding_context(
filename, src_file_repo, task_file_repo, design_file_repo, dependency
)
async def _new_coding_doc(self, filename, dependency):
context = await self._new_coding_context(filename, dependency)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=filename, content=context.model_dump_json()
root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json()
)
return coding_doc
async def _new_code_actions(self, bug_fix=False):
# Prepare file repos
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
changed_src_files = src_file_repo.all_files if bug_fix else src_file_repo.changed_files
task_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
changed_task_files = task_file_repo.changed_files
design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_src_files = self.project_repo.srcs.all_files if bug_fix else self.project_repo.srcs.changed_files
changed_task_files = self.project_repo.docs.task.changed_files
changed_files = Documents()
# Recode caused by upstream changes.
for filename in changed_task_files:
design_doc = await design_file_repo.get(filename)
task_doc = await task_file_repo.get(filename)
design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
task_list = self._parse_tasks(task_doc)
for task_filename in task_list:
old_code_doc = await src_file_repo.get(task_filename)
old_code_doc = await self.project_repo.srcs.get(task_filename)
if not old_code_doc:
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=task_filename, content="")
old_code_doc = Document(
root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
)
context = CodingContext(
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.model_dump_json()
root_path=str(self.project_repo.src_relative_path),
filename=task_filename,
content=context.model_dump_json(),
)
if task_filename in changed_files.docs:
logger.warning(
@ -281,41 +317,44 @@ class Engineer(Role):
f"{changed_files.docs[task_filename].model_dump_json()}"
)
changed_files.docs[task_filename] = coding_doc
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
self.code_todos = [
WriteCode(i_context=i, context=self.context, llm=self.llm) for i in changed_files.docs.values()
]
# Code directly modified by the user.
dependency = await CONFIG.git_repo.get_dependency()
dependency = await self.git_repo.get_dependency()
for filename in changed_src_files:
if filename in changed_files.docs:
continue
coding_doc = await self._new_coding_doc(
filename=filename,
src_file_repo=src_file_repo,
task_file_repo=task_file_repo,
design_file_repo=design_file_repo,
dependency=dependency,
)
coding_doc = await self._new_coding_doc(filename=filename, dependency=dependency)
changed_files.docs[filename] = coding_doc
self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm))
self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm))
if self.code_todos:
self.rc.todo = self.code_todos[0]
self.set_todo(self.code_todos[0])
async def _new_summarize_actions(self):
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
src_files = src_file_repo.all_files
src_files = self.project_repo.srcs.all_files
# Generate a SummarizeCode action for each pair of (system_design_doc, task_doc).
summarizations = defaultdict(list)
for filename in src_files:
dependencies = await src_file_repo.get_dependency(filename=filename)
ctx = CodeSummarizeContext.loads(filenames=dependencies)
dependencies = await self.project_repo.srcs.get_dependency(filename=filename)
ctx = CodeSummarizeContext.loads(filenames=list(dependencies))
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
ctx.codes_filenames = filenames
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm))
self.summarize_todos.append(SummarizeCode(i_context=ctx, context=self.context, llm=self.llm))
if self.summarize_todos:
self.rc.todo = self.summarize_todos[0]
self.set_todo(self.summarize_todos[0])
async def _new_code_plan_and_change_action(self):
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""
files = self.project_repo.all_files
requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME)
requirement = requirement_doc.content if requirement_doc else ""
code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, requirement=requirement)
self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm)
@property
def todo(self) -> str:
def action_description(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.next_todo_action

View file

@ -60,7 +60,7 @@ class InvoiceOCRAssistant(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([InvoiceOCR])
self.set_actions([InvoiceOCR])
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
async def _act(self) -> Message:
@ -82,12 +82,12 @@ class InvoiceOCRAssistant(Role):
resp = await todo.run(file_path)
if len(resp) == 1:
# Single file support for questioning based on OCR recognition results
self._init_actions([GenerateTable, ReplyQuestion])
self.set_actions([GenerateTable, ReplyQuestion])
self.orc_data = resp[0]
else:
self._init_actions([GenerateTable])
self.set_actions([GenerateTable])
self.rc.todo = None
self.set_todo(None)
content = INVOICE_OCR_SUCCESS
resp = OCRResults(ocr_result=json.dumps(resp))
msg = Message(content=content, instruct_content=resp)

View file

@ -9,7 +9,6 @@
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
from metagpt.roles.role import Role
from metagpt.utils.common import any_to_name
@ -34,24 +33,19 @@ class ProductManager(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._init_actions([PrepareDocuments, WritePRD])
self.set_actions([PrepareDocuments, WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo and not CONFIG.git_reinit:
if self.git_repo and not self.config.git_reinit:
self._set_state(1)
else:
self._set_state(0)
CONFIG.git_reinit = False
self.config.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)
async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.todo_action

View file

@ -33,5 +33,5 @@ class ProjectManager(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._init_actions([WriteTasks])
self.set_actions([WriteTasks])
self._watch([WriteDesign])

View file

@ -15,20 +15,13 @@
of SummarizeCode.
"""
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.const import (
MESSAGE_ROUTE_TO_NONE,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.const import MESSAGE_ROUTE_TO_NONE
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Document, Message, RunCodeContext, TestingContext
from metagpt.utils.common import any_to_str_set, parse_recipient
from metagpt.utils.file_repository import FileRepository
class QaEngineer(Role):
@ -36,7 +29,8 @@ class QaEngineer(Role):
profile: str = "QaEngineer"
goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs"
constraints: str = (
"The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain"
"The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain."
"Use same language as user requirement"
)
test_round_allowed: int = 5
test_round: int = 0
@ -46,34 +40,35 @@ class QaEngineer(Role):
# FIXME: a bit hack here, only init one action to circumvent _think() logic,
# will overwrite _think() in future updates
self._init_actions([WriteTest])
self.set_actions([WriteTest])
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
self.test_round = 0
async def _write_test(self, message: Message) -> None:
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
changed_files = set(src_file_repo.changed_files.keys())
# Unit tests only.
if CONFIG.reqa_file and CONFIG.reqa_file not in changed_files:
changed_files.add(CONFIG.reqa_file)
tests_file_repo = CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
if self.config.reqa_file and self.config.reqa_file not in changed_files:
changed_files.add(self.config.reqa_file)
for filename in changed_files:
# write tests
if not filename or "test" in filename:
continue
code_doc = await src_file_repo.get(filename)
test_doc = await tests_file_repo.get("test_" + code_doc.filename)
if not code_doc:
continue
if not code_doc.filename.endswith(".py"):
continue
test_doc = await self.project_repo.tests.get("test_" + code_doc.filename)
if not test_doc:
test_doc = Document(
root_path=str(tests_file_repo.root_path), filename="test_" + code_doc.filename, content=""
root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content=""
)
logger.info(f"Writing {test_doc.filename}..")
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
context = await WriteTest(context=context, llm=self.llm).run()
await tests_file_repo.save(
filename=context.test_doc.filename,
content=context.test_doc.content,
dependencies={context.code_doc.root_relative_path},
context = await WriteTest(i_context=context, context=self.context, llm=self.llm).run()
await self.project_repo.tests.save_doc(
doc=context.test_doc, dependencies={context.code_doc.root_relative_path}
)
# prepare context for run tests in next round
@ -81,8 +76,8 @@ class QaEngineer(Role):
command=["python", context.test_doc.root_relative_path],
code_filename=context.code_doc.filename,
test_filename=context.test_doc.filename,
working_directory=str(CONFIG.git_repo.workdir),
additional_python_paths=[str(CONFIG.src_workspace)],
working_directory=str(self.project_repo.workdir),
additional_python_paths=[str(self.context.src_workspace)],
)
self.publish_message(
Message(
@ -94,21 +89,23 @@ class QaEngineer(Role):
)
)
logger.info(f"Done {str(tests_file_repo.workdir)} generating.")
logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.")
async def _run_code(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
src_doc = await CONFIG.git_repo.new_file_repository(CONFIG.src_workspace).get(run_code_context.code_filename)
src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
run_code_context.code_filename
)
if not src_doc:
return
test_doc = await CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(run_code_context.test_filename)
test_doc = await self.project_repo.tests.get(run_code_context.test_filename)
if not test_doc:
return
run_code_context.code = src_doc.content
run_code_context.test_code = test_doc.content
result = await RunCode(context=run_code_context, llm=self.llm).run()
result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run()
run_code_context.output_filename = run_code_context.test_filename + ".json"
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
await self.project_repo.test_outputs.save(
filename=run_code_context.output_filename,
content=result.model_dump_json(),
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
@ -130,10 +127,8 @@ class QaEngineer(Role):
async def _debug_error(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
code = await DebugError(context=run_code_context, llm=self.llm).run()
await FileRepository.save_file(
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
)
code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run()
await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code)
run_code_context.output = None
self.publish_message(
Message(

View file

@ -34,7 +34,7 @@ class Researcher(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions(
self.set_actions(
[CollectLinks(name=self.name), WebBrowseAndSummarize(name=self.name), ConductResearch(name=self.name)]
)
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
@ -49,7 +49,7 @@ class Researcher(Role):
if self.rc.state + 1 < len(self.states):
self._set_state(self.rc.state + 1)
else:
self.rc.todo = None
self.set_todo(None)
return False
async def _act(self) -> Message:

View file

@ -23,29 +23,21 @@
from __future__ import annotations
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Optional, Set, Type
from typing import Any, Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.llm import LLM, HumanProvider
from metagpt.context_mixin import ContextMixin
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.plan.planner import Planner
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue, SerializationMixin, Task, TaskResult
from metagpt.utils.common import (
any_to_name,
any_to_str,
import_class,
read_json_file,
role_raise_decorator,
write_json_file,
)
from metagpt.provider import HumanProvider
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.project_repo import ProjectRepo
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
@ -113,7 +105,7 @@ class RoleContext(BaseModel):
max_react_loop: int = 1
def check(self, role_id: str):
# if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
# self.long_term_memory.recover_memory(role_id, self)
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
pass
@ -128,10 +120,10 @@ class RoleContext(BaseModel):
return self.memory.get()
class Role(SerializationMixin, is_polymorphic_base=True):
class Role(SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
name: str = ""
profile: str = ""
@ -140,12 +132,18 @@ class Role(SerializationMixin, is_polymorphic_base=True):
desc: str = ""
is_human: bool = False
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
role_id: str = ""
states: list[str] = []
# scenarios to set action system_prompt:
# 1. `__init__` while using Role(actions=[...])
# 2. add action to role while using `role.set_action(action)`
# 3. set_todo while using `role.set_todo(action)`
# 4. when role.system_prompt is being updated (e.g. by `role.system_prompt = "..."`)
# Additional, if llm is not set, we will use role's llm
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
rc: RoleContext = Field(default_factory=RoleContext)
subscription: set[str] = set()
addresses: set[str] = set()
planner: Planner = None
# builtin variables
@ -154,27 +152,85 @@ class Role(SerializationMixin, is_polymorphic_base=True):
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
@model_validator(mode="after")
def check_subscription(self):
if not self.subscription:
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
return self
def __init__(self, **data: Any):
# --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined #
from metagpt.environment import Environment
Environment
# ------
Role.model_rebuild()
self.pydantic_rebuild_model()
super().__init__(**data)
if self.is_human:
self.llm = HumanProvider()
self.llm = HumanProvider(None)
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
if self.latest_observed_msg:
self.recovered = True
@staticmethod
def pydantic_rebuild_model():
"""Rebuild model to avoid `RecursionError: maximum recursion depth exceeded in comparison`"""
from metagpt.environment import Environment
Environment
Role.model_rebuild()
@property
def todo(self) -> Action:
"""Get action to do"""
return self.rc.todo
def set_todo(self, value: Optional[Action]):
"""Set action to do and update context"""
if value:
value.context = self.context
self.rc.todo = value
@property
def git_repo(self):
"""Git repo"""
return self.context.git_repo
@git_repo.setter
def git_repo(self, value):
self.context.git_repo = value
@property
def src_workspace(self):
"""Source workspace under git repo"""
return self.context.src_workspace
@src_workspace.setter
def src_workspace(self, value):
self.context.src_workspace = value
@property
def project_repo(self) -> ProjectRepo:
project_repo = ProjectRepo(self.context.git_repo)
return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo
@property
def prompt_schema(self):
"""Prompt schema: json/markdown"""
return self.config.prompt_schema
@property
def project_name(self):
return self.config.project_name
@project_name.setter
def project_name(self, value):
self.config.project_name = value
@property
def project_path(self):
return self.config.project_path
@model_validator(mode="after")
def check_addresses(self):
if not self.addresses:
self.addresses = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
return self
def _reset(self):
self.states = []
self.actions = []
@ -183,59 +239,32 @@ class Role(SerializationMixin, is_polymorphic_base=True):
def _setting(self):
return f"{self.name}({self.profile})"
def serialize(self, stg_path: Path = None):
stg_path = (
SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}")
if stg_path is None
else stg_path
)
def _check_actions(self):
"""Check actions and set llm and prefix for each action."""
self.set_actions(self.actions)
return self
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
role_info_path = stg_path.joinpath("role_info.json")
write_json_file(role_info_path, role_info)
self.rc.memory.serialize(stg_path) # serialize role's memory alone
@classmethod
def deserialize(cls, stg_path: Path) -> "Role":
"""stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
role_info_path = stg_path.joinpath("role_info.json")
role_info = read_json_file(role_info_path)
role_class_str = role_info.pop("role_class")
module_name = role_info.pop("module_name")
role_class = import_class(class_name=role_class_str, module_name=module_name)
role = role_class(**role_info) # initiate particular Role
role.set_recovered(True) # set True to make a tag
role_memory = Memory.deserialize(stg_path)
role.set_memory(role_memory)
return role
def _init_action_system_message(self, action: Action):
def _init_action(self, action: Action):
if not action.private_config:
action.set_llm(self.llm, override=True)
else:
action.set_llm(self.llm, override=False)
action.set_prefix(self._get_prefix())
def refresh_system_message(self):
self.llm.system_prompt = self._get_prefix()
def set_action(self, action: Action):
"""Add action to the role."""
self.set_actions([action])
def set_recovered(self, recovered: bool = False):
self.recovered = recovered
def set_actions(self, actions: list[Union[Action, Type[Action]]]):
"""Add actions to the role.
def set_memory(self, memory: Memory):
self.rc.memory = memory
def init_actions(self, actions):
self._init_actions(actions)
def _init_actions(self, actions):
Args:
actions: list of Action classes or instances
"""
self._reset()
for idx, action in enumerate(actions):
for action in actions:
if not isinstance(action, Action):
## 默认初始化
i = action(name="", llm=self.llm)
i = action(context=self.context)
else:
if self.is_human and not isinstance(action.llm, HumanProvider):
logger.warning(
@ -244,9 +273,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
f"try passing in Action classes instead of initialized instances"
)
i = action
self._init_action_system_message(i)
self._init_action(i)
self.actions.append(i)
self.states.append(f"{idx}. {action}")
self.states.append(f"{len(self.actions)}. {action}")
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
"""Set strategy of the Role reacting to observed Message. Variation lies in how
@ -284,33 +313,29 @@ class Role(SerializationMixin, is_polymorphic_base=True):
def is_watch(self, caused_by: str):
return caused_by in self.rc.watch
def subscribe(self, tags: Set[str]):
def set_addresses(self, addresses: Set[str]):
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name
or profile.
"""
self.subscription = tags
self.addresses = addresses
if self.rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self.rc.env.set_subscription(self, self.subscription)
self.rc.env.set_addresses(self, self.addresses)
def _set_state(self, state: int):
"""Update the current state."""
self.rc.state = state
logger.debug(f"actions={self.actions}, state={state}")
self.rc.todo = self.actions[self.rc.state] if state >= 0 else None
self.set_todo(self.actions[self.rc.state] if state >= 0 else None)
def set_env(self, env: "Environment"):
"""Set the environment in which the role works. The role can talk to the environment and can also receive
messages by observing."""
self.rc.env = env
if env:
env.set_subscription(self, self.subscription)
self.refresh_system_message() # add env message to system message
@property
def action_count(self):
"""Return number of action"""
return len(self.actions)
env.set_addresses(self, self.addresses)
self.llm.system_prompt = self._get_prefix()
self.set_actions(self.actions) # reset actions to update llm and prefix
def _get_prefix(self):
"""Get the role prefix"""
@ -323,7 +348,8 @@ class Role(SerializationMixin, is_polymorphic_base=True):
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})
if self.rc.env and self.rc.env.desc:
other_role_names = ", ".join(self.rc.env.role_names())
all_roles = self.rc.env.role_names()
other_role_names = ", ".join([r for r in all_roles if r != self.name])
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
prefix += env_desc
return prefix
@ -338,7 +364,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
if self.recovered and self.rc.state >= 0:
self._set_state(self.rc.state) # action to run from recovered state
self.set_recovered(False) # avoid max_react_loop out of work
self.recovered = False # avoid max_react_loop out of work
return True
prompt = self._get_prefix()
@ -436,7 +462,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
break
# act
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
rsp = await self._act() # 这个rsp是否需要publish_message
rsp = await self._act()
actions_taken += 1
return rsp # return output from the last action
@ -495,6 +521,8 @@ class Role(SerializationMixin, is_polymorphic_base=True):
rsp = await self._act_by_order()
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
rsp = await self._plan_and_act()
else:
raise ValueError(f"Unsupported react mode: {self.rc.react_mode}")
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
return rsp
@ -516,7 +544,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
if not msg.cause_by:
msg.cause_by = UserRequirement
self.put_message(msg)
if not await self._observe():
# If there is no new information, suspend and wait
logger.debug(f"{self._setting}: no news. waiting.")
@ -525,7 +552,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
rsp = await self.react()
# Reset the next action to be taken.
self.rc.todo = None
self.set_todo(None)
# Send the response message to the Environment object to have it relay the message to the subscribers.
self.publish_message(rsp)
return rsp
@ -536,18 +563,34 @@ class Role(SerializationMixin, is_polymorphic_base=True):
return not self.rc.news and not self.rc.todo and self.rc.msg_buffer.empty()
async def think(self) -> Action:
"""The exported `think` function"""
"""
Export SDK API, used by AgentStore RPC.
The exported `think` function
"""
await self._observe() # For compatibility with the old version of the Agent.
await self._think()
return self.rc.todo
async def act(self) -> ActionOutput:
"""The exported `act` function"""
"""
Export SDK API, used by AgentStore RPC.
The exported `act` function
"""
msg = await self._act()
return ActionOutput(content=msg.content, instruct_content=msg.instruct_content)
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
def action_description(self) -> str:
"""
Export SDK API, used by AgentStore RPC and Agent.
AgentStore uses this attribute to display to the user what actions the current role should take.
`Role` provides the default property, and this property should be overridden by children classes if necessary,
as demonstrated by the `Engineer` class.
"""
if self.rc.todo:
if self.rc.todo.desc:
return self.rc.todo.desc
return any_to_name(self.rc.todo)
if self.actions:
return any_to_name(self.actions[0])
return ""

View file

@ -38,5 +38,5 @@ class Sales(Role):
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch)
else:
action = SearchAndSummarize()
self._init_actions([action])
self.set_actions([action])
self._watch([UserRequirement])

Some files were not shown because too many files have changed in this diff Show more