mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge branch 'dev' into code_intepreter
This commit is contained in:
commit
2fcb2a1cfe
282 changed files with 6993 additions and 3210 deletions
34
.github/workflows/build-package.yaml
vendored
Normal file
34
.github/workflows/build-package.yaml
vendored
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
name: Build and upload python package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.9'
|
||||
cache: 'pip'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install -e.
|
||||
pip install setuptools wheel twine
|
||||
- name: Set package version
|
||||
run: |
|
||||
export VERSION="${GITHUB_REF#refs/tags/v}"
|
||||
sed -i "s/version=.*/version=\"${VERSION}\",/" setup.py
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
python setup.py bdist_wheel sdist
|
||||
twine upload dist/*
|
||||
1
.github/workflows/unittest.yaml
vendored
1
.github/workflows/unittest.yaml
vendored
|
|
@ -50,6 +50,7 @@ jobs:
|
|||
run: |
|
||||
export ALLOW_OPENAI_API_CALL=0
|
||||
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
|
||||
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
|
||||
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
|
||||
- name: Show coverage report
|
||||
run: |
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -176,5 +176,6 @@ htmlcov.*
|
|||
cov.xml
|
||||
*.dot
|
||||
*.pkl
|
||||
*.faiss
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
|
|
|
|||
30
README.md
30
README.md
|
|
@ -6,16 +6,16 @@ # MetaGPT: The Multi-Agent Framework
|
|||
</p>
|
||||
|
||||
<p align="center">
|
||||
<b>Assign different roles to GPTs to form a collaborative software entity for complex tasks.</b>
|
||||
<b>Assign different roles to GPTs to form a collaborative entity for complex tasks.</b>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="docs/README_CN.md"><img src="https://img.shields.io/badge/文档-中文版-blue.svg" alt="CN doc"></a>
|
||||
<a href="README.md"><img src="https://img.shields.io/badge/document-English-blue.svg" alt="EN doc"></a>
|
||||
<a href="docs/README_JA.md"><img src="https://img.shields.io/badge/ドキュメント-日本語-blue.svg" alt="JA doc"></a>
|
||||
<a href="https://discord.gg/DYn29wFk9z"><img src="https://dcbadge.vercel.app/api/server/DYn29wFk9z?style=flat" alt="Discord Follow"></a>
|
||||
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg" alt="License: MIT"></a>
|
||||
<a href="docs/ROADMAP.md"><img src="https://img.shields.io/badge/ROADMAP-路线图-blue" alt="roadmap"></a>
|
||||
<a href="https://discord.gg/DYn29wFk9z"><img src="https://dcbadge.vercel.app/api/server/DYn29wFk9z?style=flat" alt="Discord Follow"></a>
|
||||
<a href="https://twitter.com/MetaGPT_"><img src="https://img.shields.io/twitter/follow/MetaGPT?style=social" alt="Twitter Follow"></a>
|
||||
</p>
|
||||
|
||||
|
|
@ -25,19 +25,31 @@ # MetaGPT: The Multi-Agent Framework
|
|||
<a href="https://huggingface.co/spaces/deepwisdom/MetaGPT" target="_blank"><img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20-Hugging%20Face-blue?color=blue&logoColor=white" /></a>
|
||||
</p>
|
||||
|
||||
## News
|
||||
🚀 Jan. 16, 2024: Our paper [MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework
|
||||
](https://arxiv.org/abs/2308.00352) accepted for oral presentation **(top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.
|
||||
|
||||
🚀 Jan. 03, 2024: [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0) released, new features include serialization, upgraded OpenAI package and supported multiple LLM, provided [minimal example for debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py) etc.
|
||||
|
||||
🚀 Dec. 15, 2023: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) released, introducing some experimental features such as **incremental development**, **multilingual**, **multiple programming languages**, etc.
|
||||
|
||||
🔥 Nov. 08, 2023: MetaGPT is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html).
|
||||
|
||||
🔥 Sep. 01, 2023: MetaGPT tops GitHub Trending Monthly for the **17th time** in August 2023.
|
||||
|
||||
🌟 Jun. 30, 2023: MetaGPT is now open source.
|
||||
|
||||
🌟 Apr. 24, 2023: First line of MetaGPT code committed.
|
||||
|
||||
## Software Company as Multi-Agent System
|
||||
|
||||
1. MetaGPT takes a **one line requirement** as input and outputs **user stories / competitive analysis / requirements / data structures / APIs / documents, etc.**
|
||||
2. Internally, MetaGPT includes **product managers / architects / project managers / engineers.** It provides the entire process of a **software company along with carefully orchestrated SOPs.**
|
||||
1. `Code = SOP(Team)` is the core philosophy. We materialize SOP and apply it to teams composed of LLMs.
|
||||
|
||||

|
||||
|
||||
<p align="center">Software Company Multi-Role Schematic (Gradually Implementing)</p>
|
||||
|
||||
## News
|
||||
🚀 Jan 03: Here comes [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)! In this version, we added serialization and deserialization of important objects and enabled breakpoint recovery. We upgraded OpenAI package to v1.6.0 and supported Gemini, ZhipuAI, Ollama, OpenLLM, etc. Moreover, we provided extremely simple examples where you need only 7 lines to implement a general election [debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py). Check out more details [here](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)!
|
||||
|
||||
|
||||
🚀 Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduced **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or existing codebase. We also launched a whole collection of important features, including **multilingual support** (experimental), multiple **programming languages support** (experimental), **incremental development** (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism!
|
||||
<p align="center">Software Company Multi-Agent Schematic (Gradually Implementing)</p>
|
||||
|
||||
## Install
|
||||
|
||||
|
|
|
|||
|
|
@ -1,149 +0,0 @@
|
|||
# DO NOT MODIFY THIS FILE, create a new key.yaml, define OPENAI_API_KEY.
|
||||
# The configuration of key.yaml has a higher priority and will not enter git
|
||||
|
||||
#### Project Path Setting
|
||||
# WORKSPACE_PATH: "Path for placing output files"
|
||||
|
||||
#### if OpenAI
|
||||
## The official OPENAI_BASE_URL is https://api.openai.com/v1
|
||||
## If the official OPENAI_BASE_URL is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward).
|
||||
## Or, you can configure OPENAI_PROXY to access official OPENAI_BASE_URL.
|
||||
OPENAI_BASE_URL: "https://api.openai.com/v1"
|
||||
#OPENAI_PROXY: "http://127.0.0.1:8118"
|
||||
#OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model
|
||||
OPENAI_API_MODEL: "gpt-4-1106-preview"
|
||||
MAX_TOKENS: 4096
|
||||
RPM: 10
|
||||
TIMEOUT: 60 # Timeout for llm invocation
|
||||
#DEFAULT_PROVIDER: openai
|
||||
|
||||
#### if Spark
|
||||
#SPARK_APPID : "YOUR_APPID"
|
||||
#SPARK_API_SECRET : "YOUR_APISecret"
|
||||
#SPARK_API_KEY : "YOUR_APIKey"
|
||||
#DOMAIN : "generalv2"
|
||||
#SPARK_URL : "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
|
||||
#### if Anthropic
|
||||
#ANTHROPIC_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb
|
||||
#OPENAI_API_TYPE: "azure"
|
||||
#OPENAI_BASE_URL: "YOUR_AZURE_ENDPOINT"
|
||||
#OPENAI_API_KEY: "YOUR_AZURE_API_KEY"
|
||||
#OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION"
|
||||
#DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME"
|
||||
|
||||
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
|
||||
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
|
||||
#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY"
|
||||
# GEMINI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### if use self-host open llm model with openai-compatible interface
|
||||
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
|
||||
#OPEN_LLM_API_MODEL: "llama2-13b"
|
||||
#
|
||||
##### if use Fireworks api
|
||||
#FIREWORKS_API_KEY: "YOUR_API_KEY"
|
||||
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
|
||||
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
|
||||
|
||||
#### if use self-host open llm model by ollama
|
||||
# OLLAMA_API_BASE: http://127.0.0.1:11434/api
|
||||
# OLLAMA_API_MODEL: llama2
|
||||
|
||||
#### for Search
|
||||
|
||||
## Supported values: serpapi/google/serper/ddg
|
||||
#SEARCH_ENGINE: serpapi
|
||||
|
||||
## Visit https://serpapi.com/ to get key.
|
||||
#SERPAPI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
## Visit https://console.cloud.google.com/apis/credentials to get key.
|
||||
#GOOGLE_API_KEY: "YOUR_API_KEY"
|
||||
## Visit https://programmablesearchengine.google.com/controlpanel/create to get id.
|
||||
#GOOGLE_CSE_ID: "YOUR_CSE_ID"
|
||||
|
||||
## Visit https://serper.dev/ to get key.
|
||||
#SERPER_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### for web access
|
||||
|
||||
## Supported values: playwright/selenium
|
||||
#WEB_BROWSER_ENGINE: playwright
|
||||
|
||||
## Supported values: chromium/firefox/webkit, visit https://playwright.dev/python/docs/api/class-browsertype
|
||||
##PLAYWRIGHT_BROWSER_TYPE: chromium
|
||||
|
||||
## Supported values: chrome/firefox/edge/ie, visit https://www.selenium.dev/documentation/webdriver/browsers/
|
||||
# SELENIUM_BROWSER_TYPE: chrome
|
||||
|
||||
#### for TTS
|
||||
|
||||
#AZURE_TTS_SUBSCRIPTION_KEY: "YOUR_API_KEY"
|
||||
#AZURE_TTS_REGION: "eastus"
|
||||
|
||||
#### for OPENAI VISION
|
||||
|
||||
#OPENAI_VISION_MODEL: "YOUR_VISION_MODEL_NAME"
|
||||
#VISION_MAX_TOKENS: 4096
|
||||
|
||||
#### for Stable Diffusion
|
||||
## Use SD service, based on https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
#SD_URL: "YOUR_SD_URL"
|
||||
#SD_T2I_API: "/sdapi/v1/txt2img"
|
||||
|
||||
#### for Execution
|
||||
#LONG_TERM_MEMORY: false
|
||||
|
||||
#### for Mermaid CLI
|
||||
## If you installed mmdc (Mermaid CLI) only for metagpt then enable the following configuration.
|
||||
#PUPPETEER_CONFIG: "./config/puppeteer-config.json"
|
||||
#MMDC: "./node_modules/.bin/mmdc"
|
||||
|
||||
|
||||
### for calc_usage
|
||||
# CALC_USAGE: false
|
||||
|
||||
### for Research
|
||||
# MODEL_FOR_RESEARCHER_SUMMARY: gpt-3.5-turbo
|
||||
# MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
|
||||
|
||||
### choose the engine for mermaid conversion,
|
||||
# default is nodejs, you can change it to playwright,pyppeteer or ink
|
||||
# MERMAID_ENGINE: nodejs
|
||||
|
||||
### browser path for pyppeteer engine, support Chrome, Chromium,MS Edge
|
||||
#PYPPETEER_EXECUTABLE_PATH: "/usr/bin/google-chrome-stable"
|
||||
|
||||
### for repair non-openai LLM's output when parse json-text if PROMPT_FORMAT=json
|
||||
### due to non-openai LLM's output will not always follow the instruction, so here activate a post-process
|
||||
### repair operation on the content extracted from LLM's raw output. Warning, it improves the result but not fix all cases.
|
||||
# REPAIR_LLM_OUTPUT: false
|
||||
|
||||
# PROMPT_FORMAT: json #json or markdown
|
||||
|
||||
### Agent configurations
|
||||
# RAISE_NOT_CONFIG_ERROR: true # "true" if the LLM key is not configured, throw a NotConfiguredException, else "false".
|
||||
# WORKSPACE_PATH_WITH_UID: false # "true" if using `{workspace}/{uid}` as the workspace path; "false" use `{workspace}`.
|
||||
|
||||
### Meta Models
|
||||
#METAGPT_TEXT_TO_IMAGE_MODEL: MODEL_URL
|
||||
|
||||
### S3 config
|
||||
#S3_ACCESS_KEY: "YOUR_S3_ACCESS_KEY"
|
||||
#S3_SECRET_KEY: "YOUR_S3_SECRET_KEY"
|
||||
#S3_ENDPOINT_URL: "YOUR_S3_ENDPOINT_URL"
|
||||
#S3_SECURE: true # true/false
|
||||
#S3_BUCKET: "YOUR_S3_BUCKET"
|
||||
|
||||
### Redis config
|
||||
#REDIS_HOST: "YOUR_REDIS_HOST"
|
||||
#REDIS_PORT: "YOUR_REDIS_PORT"
|
||||
#REDIS_PASSWORD: "YOUR_REDIS_PASSWORD"
|
||||
#REDIS_DB: "YOUR_REDIS_DB_INDEX, str, 0-based"
|
||||
|
||||
# DISABLE_LLM_PROVIDER_CHECK: false
|
||||
3
config/config2.yaml
Normal file
3
config/config2.yaml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
llm:
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-3.5-turbo-1106"
|
||||
42
config/config2.yaml.example
Normal file
42
config/config2.yaml.example
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
llm:
|
||||
api_type: "openai"
|
||||
base_url: "YOUR_BASE_URL"
|
||||
api_key: "YOUR_API_KEY"
|
||||
model: "gpt-3.5-turbo-1106" # or gpt-4-1106-preview
|
||||
|
||||
proxy: "YOUR_PROXY"
|
||||
|
||||
search:
|
||||
api_type: "google"
|
||||
api_key: "YOUR_API_KEY"
|
||||
cse_id: "YOUR_CSE_ID"
|
||||
|
||||
mermaid:
|
||||
engine: "pyppeteer"
|
||||
path: "/Applications/Google Chrome.app"
|
||||
|
||||
redis:
|
||||
host: "YOUR_HOST"
|
||||
port: 32582
|
||||
password: "YOUR_PASSWORD"
|
||||
db: "0"
|
||||
|
||||
s3:
|
||||
access_key: "YOUR_ACCESS_KEY"
|
||||
secret_key: "YOUR_SECRET_KEY"
|
||||
endpoint: "YOUR_ENDPOINT"
|
||||
secure: false
|
||||
bucket: "test"
|
||||
|
||||
|
||||
AZURE_TTS_SUBSCRIPTION_KEY: "YOUR_SUBSCRIPTION_KEY"
|
||||
AZURE_TTS_REGION: "eastus"
|
||||
|
||||
IFLYTEK_APP_ID: "YOUR_APP_ID"
|
||||
IFLYTEK_API_KEY: "YOUR_API_KEY"
|
||||
IFLYTEK_API_SECRET: "YOUR_API_SECRET"
|
||||
|
||||
METAGPT_TEXT_TO_IMAGE_MODEL_URL: "YOUR_MODEL_URL"
|
||||
|
||||
PYPPETEER_EXECUTABLE_PATH: "/Applications/Google Chrome.app"
|
||||
|
||||
|
|
@ -9,24 +9,22 @@ ### Short-term Objective
|
|||
|
||||
1. Become the multi-agent framework with the highest ROI.
|
||||
2. Support fully automatic implementation of medium-sized projects (around 2000 lines of code).
|
||||
3. Implement most identified tasks, reaching version 0.5.
|
||||
3. Implement most identified tasks, reaching version 1.0.
|
||||
|
||||
### Tasks
|
||||
|
||||
To reach version v0.5, approximately 70% of the following tasks need to be completed.
|
||||
|
||||
1. Usability
|
||||
1. ~~Release v0.01 pip package to try to solve issues like npm installation (though not necessarily successfully)~~ (v0.3.0)
|
||||
2. Support for overall save and recovery of software companies
|
||||
2. ~~Support for overall save and recovery of software companies~~ (v0.6.0)
|
||||
3. ~~Support human confirmation and modification during the process~~ (v0.3.0) New: Support human confirmation and modification with fewer constrainsts and a more user-friendly interface
|
||||
4. Support process caching: Consider carefully whether to add server caching mechanism
|
||||
5. ~~Resolve occasional failure to follow instruction under current prompts, causing code parsing errors, through stricter system prompts~~ (v0.4.0, with function call)
|
||||
6. Write documentation, describing the current features and usage at all levels (ongoing, continuously adding contents to [documentation site](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html))
|
||||
7. ~~Support Docker~~
|
||||
2. Features
|
||||
1. Support a more standard and stable parser (need to analyze the format that the current LLM is better at)
|
||||
2. ~~Establish a separate output queue, differentiated from the message queue~~
|
||||
3. Attempt to atomize all role work, but this may significantly increase token overhead
|
||||
1. ~~Support a more standard and stable parser (need to analyze the format that the current LLM is better at)~~ (v0.5.0)
|
||||
2. ~~Establish a separate output queue, differentiated from the message queue~~ (v0.5.0)
|
||||
3. ~~Attempt to atomize all role work, but this may significantly increase token overhead~~ (v0.5.0)
|
||||
4. Complete the design and implementation of module breakdown
|
||||
5. Support various modes of memory: clearly distinguish between long-term and short-term memory
|
||||
6. Perfect the test role, and carry out necessary interactions with humans
|
||||
|
|
@ -43,10 +41,10 @@ ### Tasks
|
|||
4. Actions
|
||||
1. ~~Implementation: Search~~ (v0.2.1)
|
||||
2. Implementation: Knowledge search, supporting 10+ data formats
|
||||
3. Implementation: Data EDA (expected v0.6.0)
|
||||
4. Implementation: Review
|
||||
5. ~~Implementation~~: Add Document (v0.5.0)
|
||||
6. ~~Implementation~~: Delete Document (v0.5.0)
|
||||
3. Implementation: Data EDA (expected v0.7.0)
|
||||
4. Implementation: Review & Revise (expected v0.7.0)
|
||||
5. ~~Implementation: Add Document~~ (v0.5.0)
|
||||
6. ~~Implementation: Delete Document~~ (v0.5.0)
|
||||
7. Implementation: Self-training
|
||||
8. ~~Implementation: DebugError~~ (v0.2.1)
|
||||
9. Implementation: Generate reliable unit tests based on YAPI
|
||||
|
|
@ -64,15 +62,14 @@ ### Tasks
|
|||
3. ~~Support Playwright apis~~
|
||||
7. Roles
|
||||
1. Perfect the action pool/skill pool for each role
|
||||
2. Red Book blogger
|
||||
3. E-commerce seller
|
||||
4. Data analyst (expected v0.6.0)
|
||||
5. News observer
|
||||
6. ~~Institutional researcher~~ (v0.2.1)
|
||||
2. E-commerce seller
|
||||
3. Data analyst (expected v0.7.0)
|
||||
4. News observer
|
||||
5. ~~Institutional researcher~~ (v0.2.1)
|
||||
8. Evaluation
|
||||
1. Support an evaluation on a game dataset (experimentation done with game agents)
|
||||
2. Reproduce papers, implement full skill acquisition for a single game role, achieving SOTA results (experimentation done with game agents)
|
||||
3. Support an evaluation on a math dataset (expected v0.6.0)
|
||||
3. Support an evaluation on a math dataset (expected v0.7.0)
|
||||
4. Reproduce papers, achieving SOTA results for current mathematical problem solving process
|
||||
9. LLM
|
||||
1. Support Claude underlying API
|
||||
|
|
@ -80,7 +77,7 @@ ### Tasks
|
|||
3. Support streaming version of all APIs
|
||||
4. ~~Make gpt-3.5-turbo available (HARD)~~
|
||||
10. Other
|
||||
1. Clean up existing unused code
|
||||
2. Unify all code styles and establish contribution standards
|
||||
3. Multi-language support
|
||||
4. Multi-programming-language support
|
||||
1. ~~Clean up existing unused code~~
|
||||
2. ~~Unify all code styles and establish contribution standards~~
|
||||
3. ~~Multi-language support~~
|
||||
4. ~~Multi-programming-language support~~
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Author: garylin2099
|
|||
import re
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -48,8 +48,8 @@ class CreateAgent(Action):
|
|||
pattern = r"```python(.*)```"
|
||||
match = re.search(pattern, rsp, re.DOTALL)
|
||||
code_text = match.group(1) if match else ""
|
||||
CONFIG.workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
new_file = CONFIG.workspace_path / "agent_created_agent.py"
|
||||
config.workspace.path.mkdir(parents=True, exist_ok=True)
|
||||
new_file = config.workspace.path / "agent_created_agent.py"
|
||||
new_file.write_text(code_text)
|
||||
return code_text
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ class AgentCreator(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([CreateAgent])
|
||||
self.set_actions([CreateAgent])
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class SimpleCoder(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([SimpleWriteCode])
|
||||
self.set_actions([SimpleWriteCode])
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
|
|
@ -76,7 +76,7 @@ class RunnableCoder(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([SimpleWriteCode, SimpleRunCode])
|
||||
self.set_actions([SimpleWriteCode, SimpleRunCode])
|
||||
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class SimpleCoder(Role):
|
|||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._watch([UserRequirement])
|
||||
self._init_actions([SimpleWriteCode])
|
||||
self.set_actions([SimpleWriteCode])
|
||||
|
||||
|
||||
class SimpleWriteTest(Action):
|
||||
|
|
@ -75,7 +75,7 @@ class SimpleTester(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([SimpleWriteTest])
|
||||
self.set_actions([SimpleWriteTest])
|
||||
# self._watch([SimpleWriteCode])
|
||||
self._watch([SimpleWriteCode, SimpleWriteReview]) # feel free to try this too
|
||||
|
||||
|
|
@ -114,7 +114,7 @@ class SimpleReviewer(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([SimpleWriteReview])
|
||||
self.set_actions([SimpleWriteReview])
|
||||
self._watch([SimpleWriteTest])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class Debator(Role):
|
|||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
self._init_actions([SpeakAloud])
|
||||
self.set_actions([SpeakAloud])
|
||||
self._watch([UserRequirement, SpeakAloud])
|
||||
|
||||
async def _observe(self) -> int:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,9 @@ from metagpt.roles import Role
|
|||
from metagpt.team import Team
|
||||
|
||||
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action1.llm.model = "gpt-4-1106-preview"
|
||||
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2.llm.model = "gpt-3.5-turbo-1106"
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
|
||||
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -23,6 +23,10 @@ async def main():
|
|||
# streaming mode, much slower
|
||||
await llm.acompletion_text(hello_msg, stream=True)
|
||||
|
||||
# check completion if exist to test llm complete functions
|
||||
if hasattr(llm, "completion"):
|
||||
logger.info(llm.completion(hello_msg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -16,7 +16,8 @@ from metagpt.roles import Sales
|
|||
|
||||
|
||||
def get_store():
|
||||
embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url)
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,41 +10,67 @@ from __future__ import annotations
|
|||
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.context_mixin import ContextMixin
|
||||
from metagpt.schema import (
|
||||
CodePlanAndChangeContext,
|
||||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
RunCodeContext,
|
||||
SerializationMixin,
|
||||
TestingContext,
|
||||
)
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class Action(SerializationMixin, is_polymorphic_base=True):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
class Action(SerializationMixin, ContextMixin, BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: str = ""
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
|
||||
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
|
||||
i_context: Union[
|
||||
dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, CodePlanAndChangeContext, str, None
|
||||
] = ""
|
||||
prefix: str = "" # aask*时会加上prefix,作为system_message
|
||||
desc: str = "" # for skill manager
|
||||
node: ActionNode = Field(default=None, exclude=True)
|
||||
|
||||
@property
|
||||
def repo(self) -> ProjectRepo:
|
||||
if not self.context.repo:
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
return self.context.repo
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
return self.config.prompt_schema
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return self.config.project_name
|
||||
|
||||
@project_name.setter
|
||||
def project_name(self, value):
|
||||
self.config.project_name = value
|
||||
|
||||
@property
|
||||
def project_path(self):
|
||||
return self.config.project_path
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_name_if_empty(cls, values):
|
||||
if "name" not in values or not values["name"]:
|
||||
values["name"] = cls.__name__
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _init_with_instruction(cls, values):
|
||||
if "instruction" in values:
|
||||
name = values["name"]
|
||||
i = values["instruction"]
|
||||
i = values.pop("instruction")
|
||||
values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw")
|
||||
return values
|
||||
|
||||
|
|
|
|||
|
|
@ -9,16 +9,30 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
|
|||
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, create_model, model_validator
|
||||
from pydantic import BaseModel, Field, create_model, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
from metagpt.utils.human_interaction import HumanInteraction
|
||||
|
||||
|
||||
class ReviewMode(Enum):
|
||||
HUMAN = "human"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class ReviseMode(Enum):
|
||||
HUMAN = "human" # human revise
|
||||
HUMAN_REVIEW = "human_review" # human-review and auto-revise
|
||||
AUTO = "auto" # auto-review and auto-revise
|
||||
|
||||
|
||||
TAG = "CONTENT"
|
||||
|
||||
|
|
@ -45,6 +59,58 @@ SIMPLE_TEMPLATE = """
|
|||
Follow instructions of nodes, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
REVIEW_TEMPLATE = """
|
||||
## context
|
||||
Compare the key's value of nodes_output and the corresponding requirements one by one. If a key's value that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
|
||||
|
||||
### nodes_output
|
||||
{nodes_output}
|
||||
|
||||
-----
|
||||
|
||||
## format example
|
||||
[{tag}]
|
||||
{{
|
||||
"key1": "comment1",
|
||||
"key2": "comment2",
|
||||
"keyn": "commentn"
|
||||
}}
|
||||
[/{tag}]
|
||||
|
||||
## nodes: "<node>: <type> # <instruction>"
|
||||
- key1: <class \'str\'> # the first key name of mismatch key
|
||||
- key2: <class \'str\'> # the second key name of mismatch key
|
||||
- keyn: <class \'str\'> # the last key name of mismatch key
|
||||
|
||||
## constraint
|
||||
{constraint}
|
||||
|
||||
## action
|
||||
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
REVISE_TEMPLATE = """
|
||||
## context
|
||||
change the nodes_output key's value to meet its comment and no need to add extra comment.
|
||||
|
||||
### nodes_output
|
||||
{nodes_output}
|
||||
|
||||
-----
|
||||
|
||||
## format example
|
||||
{example}
|
||||
|
||||
## nodes: "<node>: <type> # <instruction>"
|
||||
{instruction}
|
||||
|
||||
## constraint
|
||||
{constraint}
|
||||
|
||||
## action
|
||||
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
|
||||
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
|
||||
markdown_str = ""
|
||||
|
|
@ -105,6 +171,9 @@ class ActionNode:
|
|||
"""增加子ActionNode"""
|
||||
self.children[node.key] = node
|
||||
|
||||
def get_child(self, key: str) -> Union["ActionNode", None]:
|
||||
return self.children.get(key, None)
|
||||
|
||||
def add_children(self, nodes: List["ActionNode"]):
|
||||
"""批量增加子ActionNode"""
|
||||
for node in nodes:
|
||||
|
|
@ -117,11 +186,27 @@ class ActionNode:
|
|||
obj.add_children(nodes)
|
||||
return obj
|
||||
|
||||
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
def get_children_mapping_old(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引"""
|
||||
exclude = exclude or []
|
||||
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
|
||||
|
||||
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引,支持多级结构"""
|
||||
exclude = exclude or []
|
||||
mapping = {}
|
||||
|
||||
def _get_mapping(node: "ActionNode", prefix: str = ""):
|
||||
for key, child in node.children.items():
|
||||
if key in exclude:
|
||||
continue
|
||||
full_key = f"{prefix}{key}"
|
||||
mapping[full_key] = (child.expected_type, ...)
|
||||
_get_mapping(child, prefix=f"{full_key}.")
|
||||
|
||||
_get_mapping(self)
|
||||
return mapping
|
||||
|
||||
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get self key: type mapping"""
|
||||
return {self.key: (self.expected_type, ...)}
|
||||
|
|
@ -133,6 +218,7 @@ class ActionNode:
|
|||
return {} if exclude and self.key in exclude else self.get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
@register_action_outcls
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
|
||||
|
|
@ -152,6 +238,11 @@ class ActionNode:
|
|||
new_class = create_model(class_name, __validators__=validators, **mapping)
|
||||
return new_class
|
||||
|
||||
def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
|
||||
class_name = class_name if class_name else f"{self.key}_AN"
|
||||
mapping = self.get_mapping(mode=mode, exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def create_children_class(self, exclude=None):
|
||||
"""使用object内有的字段直接生成model_class"""
|
||||
class_name = f"{self.key}_AN"
|
||||
|
|
@ -186,6 +277,25 @@ class ActionNode:
|
|||
|
||||
return node_dict
|
||||
|
||||
def update_instruct_content(self, incre_data: dict[str, Any]):
|
||||
assert self.instruct_content
|
||||
origin_sc_dict = self.instruct_content.model_dump()
|
||||
origin_sc_dict.update(incre_data)
|
||||
output_class = self.create_class()
|
||||
self.instruct_content = output_class(**origin_sc_dict)
|
||||
|
||||
def keys(self, mode: str = "auto") -> list:
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
keys = []
|
||||
else:
|
||||
keys = [self.key]
|
||||
if mode == "root":
|
||||
return keys
|
||||
|
||||
for _, child_node in self.children.items():
|
||||
keys.append(child_node.key)
|
||||
return keys
|
||||
|
||||
def compile_to(self, i: Dict, schema, kv_sep) -> str:
|
||||
if schema == "json":
|
||||
return json.dumps(i, indent=4)
|
||||
|
|
@ -262,7 +372,7 @@ class ActionNode:
|
|||
output_data_mapping: dict,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
schema="markdown", # compatible to original format
|
||||
timeout=CONFIG.timeout,
|
||||
timeout=3,
|
||||
) -> (str, BaseModel):
|
||||
"""Use ActionOutput to wrap the output of aask"""
|
||||
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
|
||||
|
|
@ -294,7 +404,7 @@ class ActionNode:
|
|||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
|
||||
async def simple_fill(self, schema, mode, timeout=3, exclude=None):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
|
|
@ -309,7 +419,7 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
|
|
@ -343,7 +453,241 @@ class ActionNode:
|
|||
if exclude and i.key in exclude:
|
||||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.dict())
|
||||
tmp.update(child.instruct_content.model_dump())
|
||||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
|
||||
async def human_review(self) -> dict[str, str]:
|
||||
review_comments = HumanInteraction().interact_with_instruct_content(
|
||||
instruct_content=self.instruct_content, interact_type="review"
|
||||
)
|
||||
|
||||
return review_comments
|
||||
|
||||
def _makeup_nodes_output_with_req(self) -> dict[str, str]:
|
||||
instruct_content_dict = self.instruct_content.model_dump()
|
||||
nodes_output = {}
|
||||
for key, value in instruct_content_dict.items():
|
||||
child = self.get_child(key)
|
||||
nodes_output[key] = {"value": value, "requirement": child.instruction if child else self.instruction}
|
||||
return nodes_output
|
||||
|
||||
async def auto_review(self, template: str = REVIEW_TEMPLATE) -> dict[str, str]:
|
||||
"""use key's output value and its instruction to review the modification comment"""
|
||||
nodes_output = self._makeup_nodes_output_with_req()
|
||||
"""nodes_output format:
|
||||
{
|
||||
"key": {"value": "output value", "requirement": "key instruction"}
|
||||
}
|
||||
"""
|
||||
if not nodes_output:
|
||||
return dict()
|
||||
|
||||
prompt = template.format(
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
|
||||
tag=TAG,
|
||||
constraint=FORMAT_CONSTRAINT,
|
||||
prompt_schema="json",
|
||||
)
|
||||
|
||||
content = await self.llm.aask(prompt)
|
||||
# Extract the dict of mismatch key and its comment. Due to the mismatch keys are unknown, here use the keys
|
||||
# of ActionNode to judge if exist in `content` and then follow the `data_mapping` method to create model class.
|
||||
keys = self.keys()
|
||||
include_keys = []
|
||||
for key in keys:
|
||||
if f'"{key}":' in content:
|
||||
include_keys.append(key)
|
||||
if not include_keys:
|
||||
return dict()
|
||||
|
||||
exclude_keys = list(set(keys).difference(include_keys))
|
||||
output_class_name = f"{self.key}_AN_REVIEW"
|
||||
output_class = self.create_class(class_name=output_class_name, exclude=exclude_keys)
|
||||
parsed_data = llm_output_postprocess(
|
||||
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
|
||||
)
|
||||
instruct_content = output_class(**parsed_data)
|
||||
return instruct_content.model_dump()
|
||||
|
||||
async def simple_review(self, review_mode: ReviewMode = ReviewMode.AUTO):
|
||||
# generate review comments
|
||||
if review_mode == ReviewMode.HUMAN:
|
||||
review_comments = await self.human_review()
|
||||
else:
|
||||
review_comments = await self.auto_review()
|
||||
|
||||
if not review_comments:
|
||||
logger.warning("There are no review comments")
|
||||
return review_comments
|
||||
|
||||
async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO):
|
||||
"""only give the review comment of each exist and mismatch key
|
||||
|
||||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
"""
|
||||
if not hasattr(self, "llm"):
|
||||
raise RuntimeError("use `review` after `fill`")
|
||||
assert review_mode in ReviewMode
|
||||
assert self.instruct_content, 'review only support with `schema != "raw"`'
|
||||
|
||||
if strgy == "simple":
|
||||
review_comments = await self.simple_review(review_mode)
|
||||
elif strgy == "complex":
|
||||
# review each child node one-by-one
|
||||
review_comments = {}
|
||||
for _, child in self.children.items():
|
||||
child_review_comment = await child.simple_review(review_mode)
|
||||
review_comments.update(child_review_comment)
|
||||
|
||||
return review_comments
|
||||
|
||||
async def human_revise(self) -> dict[str, str]:
|
||||
review_contents = HumanInteraction().interact_with_instruct_content(
|
||||
instruct_content=self.instruct_content, mapping=self.get_mapping(mode="auto"), interact_type="revise"
|
||||
)
|
||||
# re-fill the ActionNode
|
||||
self.update_instruct_content(review_contents)
|
||||
return review_contents
|
||||
|
||||
def _makeup_nodes_output_with_comment(self, review_comments: dict[str, str]) -> dict[str, str]:
|
||||
instruct_content_dict = self.instruct_content.model_dump()
|
||||
nodes_output = {}
|
||||
for key, value in instruct_content_dict.items():
|
||||
if key in review_comments:
|
||||
nodes_output[key] = {"value": value, "comment": review_comments[key]}
|
||||
return nodes_output
|
||||
|
||||
async def auto_revise(
|
||||
self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE
|
||||
) -> dict[str, str]:
|
||||
"""revise the value of incorrect keys"""
|
||||
# generate review comments
|
||||
if revise_mode == ReviseMode.AUTO:
|
||||
review_comments: dict = await self.auto_review()
|
||||
elif revise_mode == ReviseMode.HUMAN_REVIEW:
|
||||
review_comments: dict = await self.human_review()
|
||||
|
||||
include_keys = list(review_comments.keys())
|
||||
|
||||
# generate revise content, two-steps
|
||||
# step1, find the needed revise keys from review comments to makeup prompt template
|
||||
nodes_output = self._makeup_nodes_output_with_comment(review_comments)
|
||||
keys = self.keys()
|
||||
exclude_keys = list(set(keys).difference(include_keys))
|
||||
example = self.compile_example(schema="json", mode="auto", tag=TAG, exclude=exclude_keys)
|
||||
instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
|
||||
|
||||
prompt = template.format(
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
|
||||
example=example,
|
||||
instruction=instruction,
|
||||
constraint=FORMAT_CONSTRAINT,
|
||||
prompt_schema="json",
|
||||
)
|
||||
|
||||
# step2, use `_aask_v1` to get revise structure result
|
||||
output_mapping = self.get_mapping(mode="auto", exclude=exclude_keys)
|
||||
output_class_name = f"{self.key}_AN_REVISE"
|
||||
content, scontent = await self._aask_v1(
|
||||
prompt=prompt, output_class_name=output_class_name, output_data_mapping=output_mapping, schema="json"
|
||||
)
|
||||
|
||||
# re-fill the ActionNode
|
||||
sc_dict = scontent.model_dump()
|
||||
self.update_instruct_content(sc_dict)
|
||||
return sc_dict
|
||||
|
||||
async def simple_revise(self, revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
|
||||
if revise_mode == ReviseMode.HUMAN:
|
||||
revise_contents = await self.human_revise()
|
||||
else:
|
||||
revise_contents = await self.auto_revise(revise_mode)
|
||||
|
||||
return revise_contents
|
||||
|
||||
async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
|
||||
"""revise the content of ActionNode and update the instruct_content
|
||||
|
||||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
"""
|
||||
if not hasattr(self, "llm"):
|
||||
raise RuntimeError("use `revise` after `fill`")
|
||||
assert revise_mode in ReviseMode
|
||||
assert self.instruct_content, 'revise only support with `schema != "raw"`'
|
||||
|
||||
if strgy == "simple":
|
||||
revise_contents = await self.simple_revise(revise_mode)
|
||||
elif strgy == "complex":
|
||||
# revise each child node one-by-one
|
||||
revise_contents = {}
|
||||
for _, child in self.children.items():
|
||||
child_revise_content = await child.simple_revise(revise_mode)
|
||||
revise_contents.update(child_revise_content)
|
||||
self.update_instruct_content(revise_contents)
|
||||
|
||||
return revise_contents
|
||||
|
||||
@classmethod
|
||||
def from_pydantic(cls, model: Type[BaseModel], key: str = None):
|
||||
"""
|
||||
Creates an ActionNode tree from a Pydantic model.
|
||||
|
||||
Args:
|
||||
model (Type[BaseModel]): The Pydantic model to convert.
|
||||
|
||||
Returns:
|
||||
ActionNode: The root node of the created ActionNode tree.
|
||||
"""
|
||||
key = key or model.__name__
|
||||
root_node = cls(key=model.__name__, expected_type=Type[model], instruction="", example="")
|
||||
|
||||
for field_name, field_model in model.model_fields.items():
|
||||
# Extracting field details
|
||||
expected_type = field_model.annotation
|
||||
instruction = field_model.description or ""
|
||||
example = field_model.default
|
||||
|
||||
# Check if the field is a Pydantic model itself.
|
||||
# Use isinstance to avoid typing.List, typing.Dict, etc. (they are instances of type, not subclasses)
|
||||
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
|
||||
# Recursively process the nested model
|
||||
child_node = cls.from_pydantic(expected_type, key=field_name)
|
||||
else:
|
||||
child_node = cls(key=field_name, expected_type=expected_type, instruction=instruction, example=example)
|
||||
|
||||
root_node.add_child(child_node)
|
||||
|
||||
return root_node
|
||||
|
||||
|
||||
class ToolUse(BaseModel):
|
||||
tool_name: str = Field(default="a", description="tool name", examples=[])
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_id: int = Field(default="1", description="task id", examples=[1, 2, 3])
|
||||
name: str = Field(default="Get data from ...", description="task name", examples=[])
|
||||
dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3])
|
||||
tool: ToolUse = Field(default=ToolUse(), description="tool use", examples=[])
|
||||
|
||||
|
||||
class Tasks(BaseModel):
|
||||
tasks: List[Task] = Field(default=[], description="tasks", examples=[])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
node = ActionNode.from_pydantic(Tasks)
|
||||
print("Tasks")
|
||||
print(Tasks.model_json_schema())
|
||||
print("Task")
|
||||
print(Task.model_json_schema())
|
||||
print(node)
|
||||
prompt = node.compile(context="")
|
||||
node.create_children_class()
|
||||
print(prompt)
|
||||
|
|
|
|||
42
metagpt/actions/action_outcls_registry.py
Normal file
42
metagpt/actions/action_outcls_registry.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : registry to store Dynamic Model from ActionNode.create_model_class to keep it as same Class
|
||||
# with same class name and mapping
|
||||
|
||||
from functools import wraps
|
||||
|
||||
action_outcls_registry = dict()
|
||||
|
||||
|
||||
def register_action_outcls(func):
|
||||
"""
|
||||
Due to `create_model` return different Class even they have same class name and mapping.
|
||||
In order to do a comparison, use outcls_id to identify same Class with same class name and field definition
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorater(*args, **kwargs):
|
||||
"""
|
||||
arr example
|
||||
[<class 'metagpt.actions.action_node.ActionNode'>, 'test', {'field': (str, Ellipsis)}]
|
||||
"""
|
||||
arr = list(args) + list(kwargs.values())
|
||||
"""
|
||||
outcls_id example
|
||||
"<class 'metagpt.actions.action_node.ActionNode'>_test_{'field': (str, Ellipsis)}"
|
||||
"""
|
||||
for idx, item in enumerate(arr):
|
||||
if isinstance(item, dict):
|
||||
arr[idx] = dict(sorted(item.items()))
|
||||
outcls_id = "_".join([str(i) for i in arr])
|
||||
# eliminate typing influence
|
||||
outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict")
|
||||
|
||||
if outcls_id in action_outcls_registry:
|
||||
return action_outcls_registry[outcls_id]
|
||||
|
||||
out_cls = func(*args, **kwargs)
|
||||
action_outcls_registry[outcls_id] = out_cls
|
||||
return out_cls
|
||||
|
||||
return decorater
|
||||
|
|
@ -13,12 +13,9 @@ import re
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
NOTICE
|
||||
|
|
@ -49,13 +46,10 @@ Now you should start rewriting the code:
|
|||
|
||||
|
||||
class DebugError(Action):
|
||||
name: str = "DebugError"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await FileRepository.get_file(
|
||||
filename=self.context.output_filename, relative_path=TEST_OUTPUTS_FILE_REPO
|
||||
)
|
||||
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
|
||||
if not output_doc:
|
||||
return ""
|
||||
output_detail = RunCodeResult.loads(output_doc.content)
|
||||
|
|
@ -64,15 +58,13 @@ class DebugError(Action):
|
|||
if matches:
|
||||
return ""
|
||||
|
||||
logger.info(f"Debug and rewrite {self.context.test_filename}")
|
||||
code_doc = await FileRepository.get_file(
|
||||
filename=self.context.code_filename, relative_path=CONFIG.src_workspace
|
||||
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
|
||||
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
|
||||
filename=self.i_context.code_filename
|
||||
)
|
||||
if not code_doc:
|
||||
return ""
|
||||
test_doc = await FileRepository.get_file(
|
||||
filename=self.context.test_filename, relative_path=TEST_CODES_FILE_REPO
|
||||
)
|
||||
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
|
||||
if not test_doc:
|
||||
return ""
|
||||
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)
|
||||
|
|
|
|||
|
|
@ -14,18 +14,17 @@ from pathlib import Path
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.design_api_an import DESIGN_API_NODE
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
DATA_API_DESIGN_FILE_REPO,
|
||||
PRDS_FILE_REPO,
|
||||
SEQ_FLOW_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
SYSTEM_DESIGN_PDF_FILE_REPO,
|
||||
from metagpt.actions.design_api_an import (
|
||||
DATA_STRUCTURES_AND_INTERFACES,
|
||||
DESIGN_API_NODE,
|
||||
PROGRAM_CALL_FLOW,
|
||||
REFINED_DATA_STRUCTURES_AND_INTERFACES,
|
||||
REFINED_DESIGN_NODE,
|
||||
REFINED_PROGRAM_CALL_FLOW,
|
||||
)
|
||||
from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
||||
NEW_REQ_TEMPLATE = """
|
||||
|
|
@ -39,36 +38,30 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
class WriteDesign(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
desc: str = (
|
||||
"Based on the PRD, think about the system design, and design the corresponding APIs, "
|
||||
"data structures, library tables, processes, and paths. Please provide your design, feedback "
|
||||
"clearly and in detail."
|
||||
)
|
||||
|
||||
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
|
||||
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
|
||||
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
changed_prds = prds_file_repo.changed_files
|
||||
async def run(self, with_messages: Message, schema: str = None):
|
||||
# Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory.
|
||||
changed_prds = self.repo.docs.prd.changed_files
|
||||
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
|
||||
# changes.
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
changed_system_designs = system_design_file_repo.changed_files
|
||||
changed_system_designs = self.repo.docs.system_design.changed_files
|
||||
|
||||
# For those PRDs and design documents that have undergone changes, regenerate the design content.
|
||||
changed_files = Documents()
|
||||
for filename in changed_prds.keys():
|
||||
doc = await self._update_system_design(
|
||||
filename=filename, prds_file_repo=prds_file_repo, system_design_file_repo=system_design_file_repo
|
||||
)
|
||||
doc = await self._update_system_design(filename=filename)
|
||||
changed_files.docs[filename] = doc
|
||||
|
||||
for filename in changed_system_designs.keys():
|
||||
if filename in changed_files.docs:
|
||||
continue
|
||||
doc = await self._update_system_design(
|
||||
filename=filename, prds_file_repo=prds_file_repo, system_design_file_repo=system_design_file_repo
|
||||
)
|
||||
doc = await self._update_system_design(filename=filename)
|
||||
changed_files.docs[filename] = doc
|
||||
if not changed_files.docs:
|
||||
logger.info("Nothing has changed.")
|
||||
|
|
@ -76,61 +69,52 @@ class WriteDesign(Action):
|
|||
# leaving room for global optimization in subsequent steps.
|
||||
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
|
||||
|
||||
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
async def _new_system_design(self, context):
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm)
|
||||
return node
|
||||
|
||||
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
|
||||
async def _merge(self, prd_doc, system_design_doc):
|
||||
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm)
|
||||
system_design_doc.content = node.instruct_content.model_dump_json()
|
||||
return system_design_doc
|
||||
|
||||
async def _update_system_design(self, filename, prds_file_repo, system_design_file_repo) -> Document:
|
||||
prd = await prds_file_repo.get(filename)
|
||||
old_system_design_doc = await system_design_file_repo.get(filename)
|
||||
async def _update_system_design(self, filename) -> Document:
|
||||
prd = await self.repo.docs.prd.get(filename)
|
||||
old_system_design_doc = await self.repo.docs.system_design.get(filename)
|
||||
if not old_system_design_doc:
|
||||
system_design = await self._new_system_design(context=prd.content)
|
||||
doc = Document(
|
||||
root_path=SYSTEM_DESIGN_FILE_REPO,
|
||||
doc = await self.repo.docs.system_design.save(
|
||||
filename=filename,
|
||||
content=system_design.instruct_content.model_dump_json(),
|
||||
dependencies={prd.root_relative_path},
|
||||
)
|
||||
else:
|
||||
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
|
||||
await system_design_file_repo.save(
|
||||
filename=filename, content=doc.content, dependencies={prd.root_relative_path}
|
||||
)
|
||||
await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
|
||||
await self._save_data_api_design(doc)
|
||||
await self._save_seq_flow(doc)
|
||||
await self._save_pdf(doc)
|
||||
await self.repo.resources.system_design.save_pdf(doc=doc)
|
||||
return doc
|
||||
|
||||
@staticmethod
|
||||
async def _save_data_api_design(design_doc):
|
||||
async def _save_data_api_design(self, design_doc):
|
||||
m = json.loads(design_doc.content)
|
||||
data_api_design = m.get("Data structures and interfaces")
|
||||
data_api_design = m.get(DATA_STRUCTURES_AND_INTERFACES.key) or m.get(REFINED_DATA_STRUCTURES_AND_INTERFACES.key)
|
||||
if not data_api_design:
|
||||
return
|
||||
pathname = CONFIG.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
|
||||
await WriteDesign._save_mermaid_file(data_api_design, pathname)
|
||||
pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
|
||||
await self._save_mermaid_file(data_api_design, pathname)
|
||||
logger.info(f"Save class view to {str(pathname)}")
|
||||
|
||||
@staticmethod
|
||||
async def _save_seq_flow(design_doc):
|
||||
async def _save_seq_flow(self, design_doc):
|
||||
m = json.loads(design_doc.content)
|
||||
seq_flow = m.get("Program call flow")
|
||||
seq_flow = m.get(PROGRAM_CALL_FLOW.key) or m.get(REFINED_PROGRAM_CALL_FLOW.key)
|
||||
if not seq_flow:
|
||||
return
|
||||
pathname = CONFIG.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
|
||||
await WriteDesign._save_mermaid_file(seq_flow, pathname)
|
||||
pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
|
||||
await self._save_mermaid_file(seq_flow, pathname)
|
||||
logger.info(f"Saving sequence flow to {str(pathname)}")
|
||||
|
||||
@staticmethod
|
||||
async def _save_pdf(design_doc):
|
||||
await FileRepository.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO)
|
||||
|
||||
@staticmethod
|
||||
async def _save_mermaid_file(data: str, pathname: Path):
|
||||
async def _save_mermaid_file(self, data: str, pathname: Path):
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(data, pathname)
|
||||
await mermaid_to_file(self.config.mermaid_engine, data, pathname)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.mermaid import MMC1, MMC2
|
||||
|
||||
IMPLEMENTATION_APPROACH = ActionNode(
|
||||
|
|
@ -17,6 +18,15 @@ IMPLEMENTATION_APPROACH = ActionNode(
|
|||
example="We will ...",
|
||||
)
|
||||
|
||||
REFINED_IMPLEMENTATION_APPROACH = ActionNode(
|
||||
key="Refined Implementation Approach",
|
||||
expected_type=str,
|
||||
instruction="Update and extend the original implementation approach to reflect the evolving challenges and "
|
||||
"requirements due to incremental development. Outline the steps involved in the implementation process with the "
|
||||
"detailed strategies.",
|
||||
example="We will refine ...",
|
||||
)
|
||||
|
||||
PROJECT_NAME = ActionNode(
|
||||
key="Project name", expected_type=str, instruction="The project name with underline", example="game_2048"
|
||||
)
|
||||
|
|
@ -28,6 +38,14 @@ FILE_LIST = ActionNode(
|
|||
example=["main.py", "game.py"],
|
||||
)
|
||||
|
||||
REFINED_FILE_LIST = ActionNode(
|
||||
key="Refined File list",
|
||||
expected_type=List[str],
|
||||
instruction="Update and expand the original file list including only relative paths. Up to 2 files can be added."
|
||||
"Ensure that the refined file list reflects the evolving structure of the project.",
|
||||
example=["main.py", "game.py", "new_feature.py"],
|
||||
)
|
||||
|
||||
DATA_STRUCTURES_AND_INTERFACES = ActionNode(
|
||||
key="Data structures and interfaces",
|
||||
expected_type=str,
|
||||
|
|
@ -37,6 +55,16 @@ DATA_STRUCTURES_AND_INTERFACES = ActionNode(
|
|||
example=MMC1,
|
||||
)
|
||||
|
||||
REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(
|
||||
key="Refined Data structures and interfaces",
|
||||
expected_type=str,
|
||||
instruction="Update and extend the existing mermaid classDiagram code syntax to incorporate new classes, "
|
||||
"methods (including __init__), and functions with precise type annotations. Delineate additional "
|
||||
"relationships between classes, ensuring clarity and adherence to PEP8 standards."
|
||||
"Retain content that is not related to incremental development but important for consistency and clarity.",
|
||||
example=MMC1,
|
||||
)
|
||||
|
||||
PROGRAM_CALL_FLOW = ActionNode(
|
||||
key="Program call flow",
|
||||
expected_type=str,
|
||||
|
|
@ -45,6 +73,16 @@ PROGRAM_CALL_FLOW = ActionNode(
|
|||
example=MMC2,
|
||||
)
|
||||
|
||||
REFINED_PROGRAM_CALL_FLOW = ActionNode(
|
||||
key="Refined Program call flow",
|
||||
expected_type=str,
|
||||
instruction="Extend the existing sequenceDiagram code syntax with detailed information, accurately covering the"
|
||||
"CRUD and initialization of each object. Ensure correct syntax usage and reflect the incremental changes introduced"
|
||||
"in the classes and API defined above. "
|
||||
"Retain content that is not related to incremental development but important for consistency and clarity.",
|
||||
example=MMC2,
|
||||
)
|
||||
|
||||
ANYTHING_UNCLEAR = ActionNode(
|
||||
key="Anything UNCLEAR",
|
||||
expected_type=str,
|
||||
|
|
@ -61,4 +99,24 @@ NODES = [
|
|||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
REFINED_NODES = [
|
||||
REFINED_IMPLEMENTATION_APPROACH,
|
||||
REFINED_FILE_LIST,
|
||||
REFINED_DATA_STRUCTURES_AND_INTERFACES,
|
||||
REFINED_PROGRAM_CALL_FLOW,
|
||||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
|
||||
REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = DESIGN_API_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
prompt = REFINED_DESIGN_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.actions.action import Action
|
|||
|
||||
class DesignReview(Action):
|
||||
name: str = "DesignReview"
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
|
||||
async def run(self, prd, api_design):
|
||||
prompt = (
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.schema import Message
|
|||
|
||||
class ExecuteTask(Action):
|
||||
name: str = "ExecuteTask"
|
||||
context: list[Message] = []
|
||||
i_context: list[Message] = []
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/12 17:45
|
||||
@Author : fisherdeng
|
||||
@File : generate_questions.py
|
||||
"""
|
||||
from metagpt.actions import Action
|
||||
|
|
@ -23,5 +21,5 @@ class GenerateQuestions(Action):
|
|||
|
||||
name: str = "GenerateQuestions"
|
||||
|
||||
async def run(self, context):
|
||||
async def run(self, context) -> ActionNode:
|
||||
return await QUESTIONS.fill(context=context, llm=self.llm)
|
||||
|
|
|
|||
|
|
@ -16,17 +16,14 @@ from typing import Optional
|
|||
|
||||
import pandas as pd
|
||||
from paddleocr import PaddleOCR
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.const import INVOICE_OCR_TABLE_PATH
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.invoice_ocr import (
|
||||
EXTRACT_OCR_MAIN_INFO_PROMPT,
|
||||
REPLY_OCR_QUESTION_PROMPT,
|
||||
)
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.file import File
|
||||
|
||||
|
|
@ -41,7 +38,7 @@ class InvoiceOCR(Action):
|
|||
"""
|
||||
|
||||
name: str = "InvoiceOCR"
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
async def _check_file_type(file_path: Path) -> str:
|
||||
|
|
@ -132,8 +129,7 @@ class GenerateTable(Action):
|
|||
"""
|
||||
|
||||
name: str = "GenerateTable"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
i_context: Optional[str] = None
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
|
||||
|
|
@ -176,9 +172,6 @@ class ReplyQuestion(Action):
|
|||
|
||||
"""
|
||||
|
||||
name: str = "ReplyQuestion"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
|
||||
|
|
|
|||
|
|
@ -12,39 +12,41 @@ from pathlib import Path
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.schema import Document
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class PrepareDocuments(Action):
|
||||
"""PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt."""
|
||||
|
||||
name: str = "PrepareDocuments"
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.context.config
|
||||
|
||||
def _init_repo(self):
|
||||
"""Initialize the Git environment."""
|
||||
if not CONFIG.project_path:
|
||||
name = CONFIG.project_name or FileRepository.new_filename()
|
||||
path = Path(CONFIG.workspace_path) / name
|
||||
if not self.config.project_path:
|
||||
name = self.config.project_name or FileRepository.new_filename()
|
||||
path = Path(self.config.workspace.path) / name
|
||||
else:
|
||||
path = Path(CONFIG.project_path)
|
||||
if path.exists() and not CONFIG.inc:
|
||||
path = Path(self.config.project_path)
|
||||
if path.exists() and not self.config.inc:
|
||||
shutil.rmtree(path)
|
||||
CONFIG.project_path = path
|
||||
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
self.config.project_path = path
|
||||
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
"""Create and initialize the workspace folder, initialize the Git environment."""
|
||||
self._init_repo()
|
||||
|
||||
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
|
||||
doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
|
||||
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO)
|
||||
|
||||
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
|
||||
# Send a Message notification to the WritePRD action, instructing it to process requirements using
|
||||
# `docs/requirement.txt` and `docs/prds/`.
|
||||
# `docs/requirement.txt` and `docs/prd/`.
|
||||
return ActionOutput(content=doc.content, instruct_content=doc)
|
||||
|
|
|
|||
|
|
@ -13,23 +13,16 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management_an import PM_NODE
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
PACKAGE_REQUIREMENTS_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
TASK_PDF_FILE_REPO,
|
||||
)
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
|
||||
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document, Documents
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
NEW_REQ_TEMPLATE = """
|
||||
### Legacy Content
|
||||
{old_tasks}
|
||||
{old_task}
|
||||
|
||||
### New Requirements
|
||||
{context}
|
||||
|
|
@ -38,30 +31,23 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
class WriteTasks(Action):
|
||||
name: str = "CreateTasks"
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
changed_system_designs = system_design_file_repo.changed_files
|
||||
|
||||
tasks_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
|
||||
changed_tasks = tasks_file_repo.changed_files
|
||||
async def run(self, with_messages):
|
||||
changed_system_designs = self.repo.docs.system_design.changed_files
|
||||
changed_tasks = self.repo.docs.task.changed_files
|
||||
change_files = Documents()
|
||||
# Rewrite the system designs that have undergone changes based on the git head diff under
|
||||
# `docs/system_designs/`.
|
||||
for filename in changed_system_designs:
|
||||
task_doc = await self._update_tasks(
|
||||
filename=filename, system_design_file_repo=system_design_file_repo, tasks_file_repo=tasks_file_repo
|
||||
)
|
||||
task_doc = await self._update_tasks(filename=filename)
|
||||
change_files.docs[filename] = task_doc
|
||||
|
||||
# Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
|
||||
for filename in changed_tasks:
|
||||
if filename in change_files.docs:
|
||||
continue
|
||||
task_doc = await self._update_tasks(
|
||||
filename=filename, system_design_file_repo=system_design_file_repo, tasks_file_repo=tasks_file_repo
|
||||
)
|
||||
task_doc = await self._update_tasks(filename=filename)
|
||||
change_files.docs[filename] = task_doc
|
||||
|
||||
if not change_files.docs:
|
||||
|
|
@ -70,39 +56,36 @@ class WriteTasks(Action):
|
|||
# global optimization in subsequent steps.
|
||||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
|
||||
async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo):
|
||||
system_design_doc = await system_design_file_repo.get(filename)
|
||||
task_doc = await tasks_file_repo.get(filename)
|
||||
async def _update_tasks(self, filename):
|
||||
system_design_doc = await self.repo.docs.system_design.get(filename)
|
||||
task_doc = await self.repo.docs.task.get(filename)
|
||||
if task_doc:
|
||||
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
|
||||
await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
|
||||
else:
|
||||
rsp = await self._run_new_tasks(context=system_design_doc.content)
|
||||
task_doc = Document(
|
||||
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.model_dump_json()
|
||||
task_doc = await self.repo.docs.task.save(
|
||||
filename=filename,
|
||||
content=rsp.instruct_content.model_dump_json(),
|
||||
dependencies={system_design_doc.root_relative_path},
|
||||
)
|
||||
await tasks_file_repo.save(
|
||||
filename=filename, content=task_doc.content, dependencies={system_design_doc.root_relative_path}
|
||||
)
|
||||
await self._update_requirements(task_doc)
|
||||
await self._save_pdf(task_doc=task_doc)
|
||||
return task_doc
|
||||
|
||||
async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
|
||||
node = await PM_NODE.fill(context, self.llm, schema)
|
||||
async def _run_new_tasks(self, context):
|
||||
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
|
||||
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
|
||||
node = await PM_NODE.fill(context, self.llm, schema)
|
||||
async def _merge(self, system_design_doc, task_doc) -> Document:
|
||||
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
|
||||
node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
task_doc.content = node.instruct_content.model_dump_json()
|
||||
return task_doc
|
||||
|
||||
@staticmethod
|
||||
async def _update_requirements(doc):
|
||||
async def _update_requirements(self, doc):
|
||||
m = json.loads(doc.content)
|
||||
packages = set(m.get("Required Python third-party packages", set()))
|
||||
file_repo = CONFIG.git_repo.new_file_repository()
|
||||
requirement_doc = await file_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
packages = set(m.get("Required Python packages", set()))
|
||||
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
if not requirement_doc:
|
||||
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
|
||||
lines = requirement_doc.content.splitlines()
|
||||
|
|
@ -110,8 +93,4 @@ class WriteTasks(Action):
|
|||
if pkg == "":
|
||||
continue
|
||||
packages.add(pkg)
|
||||
await file_repo.save(PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
|
||||
|
||||
@staticmethod
|
||||
async def _save_pdf(task_doc):
|
||||
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
|
||||
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
|
||||
|
|
|
|||
|
|
@ -35,6 +35,20 @@ LOGIC_ANALYSIS = ActionNode(
|
|||
],
|
||||
)
|
||||
|
||||
REFINED_LOGIC_ANALYSIS = ActionNode(
|
||||
key="Refined Logic Analysis",
|
||||
expected_type=List[List[str]],
|
||||
instruction="Review and refine the logic analysis by merging the Legacy Content and Incremental Content. "
|
||||
"Provide a comprehensive list of files with classes/methods/functions to be implemented or modified incrementally. "
|
||||
"Include dependency analysis, consider potential impacts on existing code, and document necessary imports.",
|
||||
example=[
|
||||
["game.py", "Contains Game class and ... functions"],
|
||||
["main.py", "Contains main function, from game import Game"],
|
||||
["new_feature.py", "Introduces NewFeature class and related functions"],
|
||||
["utils.py", "Modifies existing utility functions to support incremental changes"],
|
||||
],
|
||||
)
|
||||
|
||||
TASK_LIST = ActionNode(
|
||||
key="Task list",
|
||||
expected_type=List[str],
|
||||
|
|
@ -42,6 +56,15 @@ TASK_LIST = ActionNode(
|
|||
example=["game.py", "main.py"],
|
||||
)
|
||||
|
||||
REFINED_TASK_LIST = ActionNode(
|
||||
key="Refined Task list",
|
||||
expected_type=List[str],
|
||||
instruction="Review and refine the combined task list after the merger of Legacy Content and Incremental Content, "
|
||||
"and consistent with Refined File List. Ensure that tasks are organized in a logical and prioritized order, "
|
||||
"considering dependencies for a streamlined and efficient development process. ",
|
||||
example=["new_feature.py", "utils", "game.py", "main.py"],
|
||||
)
|
||||
|
||||
FULL_API_SPEC = ActionNode(
|
||||
key="Full API spec",
|
||||
expected_type=str,
|
||||
|
|
@ -54,9 +77,19 @@ SHARED_KNOWLEDGE = ActionNode(
|
|||
key="Shared Knowledge",
|
||||
expected_type=str,
|
||||
instruction="Detail any shared knowledge, like common utility functions or configuration variables.",
|
||||
example="'game.py' contains functions shared across the project.",
|
||||
example="`game.py` contains functions shared across the project.",
|
||||
)
|
||||
|
||||
REFINED_SHARED_KNOWLEDGE = ActionNode(
|
||||
key="Refined Shared Knowledge",
|
||||
expected_type=str,
|
||||
instruction="Update and expand shared knowledge to reflect any new elements introduced. This includes common "
|
||||
"utility functions, configuration variables for team collaboration. Retain content that is not related to "
|
||||
"incremental development but important for consistency and clarity.",
|
||||
example="`new_module.py` enhances shared utility functions for improved code reusability and collaboration.",
|
||||
)
|
||||
|
||||
|
||||
ANYTHING_UNCLEAR_PM = ActionNode(
|
||||
key="Anything UNCLEAR",
|
||||
expected_type=str,
|
||||
|
|
@ -74,13 +107,25 @@ NODES = [
|
|||
ANYTHING_UNCLEAR_PM,
|
||||
]
|
||||
|
||||
REFINED_NODES = [
|
||||
REQUIRED_PYTHON_PACKAGES,
|
||||
REQUIRED_OTHER_LANGUAGE_PACKAGES,
|
||||
REFINED_LOGIC_ANALYSIS,
|
||||
REFINED_TASK_LIST,
|
||||
FULL_API_SPEC,
|
||||
REFINED_SHARED_KNOWLEDGE,
|
||||
ANYTHING_UNCLEAR_PM,
|
||||
]
|
||||
|
||||
PM_NODE = ActionNode.from_children("PM_NODE", NODES)
|
||||
REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = PM_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
prompt = REFINED_PM_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from pathlib import Path
|
|||
import aiofiles
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import (
|
||||
AGGREGATION,
|
||||
COMPOSITION,
|
||||
|
|
@ -29,16 +29,16 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
|||
|
||||
|
||||
class RebuildClassView(Action):
|
||||
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
|
||||
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
repo_parser = RepoParser(base_directory=Path(self.context))
|
||||
repo_parser = RepoParser(base_directory=Path(self.i_context))
|
||||
# use pylint
|
||||
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.context))
|
||||
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
|
||||
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
|
||||
# use ast
|
||||
direction, diff_path = self._diff_path(path_root=Path(self.context).resolve(), package_root=package_root)
|
||||
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for file_info in symbols:
|
||||
# Align to the same root directory in accordance with `class_views`.
|
||||
|
|
@ -48,9 +48,9 @@ class RebuildClassView(Action):
|
|||
await graph_db.save()
|
||||
|
||||
async def _create_mermaid_class_views(self, graph_db):
|
||||
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
|
||||
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pathname = path / CONFIG.git_repo.workdir.name
|
||||
pathname = path / self.context.git_repo.workdir.name
|
||||
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
|
||||
content = "classDiagram\n"
|
||||
logger.debug(content)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from pathlib import Path
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread, list_files
|
||||
|
|
@ -21,8 +21,8 @@ from metagpt.utils.graph_repository import GraphKeyword
|
|||
|
||||
|
||||
class RebuildSequenceView(Action):
|
||||
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
|
||||
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
entries = await RebuildSequenceView._search_main_entry(graph_db)
|
||||
for entry in entries:
|
||||
|
|
@ -41,7 +41,9 @@ class RebuildSequenceView(Action):
|
|||
|
||||
async def _rebuild_sequence_view(self, entry, graph_db):
|
||||
filename = entry.subject.split(":", 1)[0]
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.context, pathname=filename)
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
return
|
||||
content = await aread(filename=src_filename, encoding="utf-8")
|
||||
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
|
||||
data = await self.llm.aask(
|
||||
|
|
|
|||
|
|
@ -8,10 +8,8 @@ from typing import Callable, Optional, Union
|
|||
from pydantic import Field, parse_obj_as
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
|
||||
from metagpt.utils.common import OutputParser
|
||||
|
|
@ -81,7 +79,7 @@ class CollectLinks(Action):
|
|||
"""Action class to collect links from a search engine."""
|
||||
|
||||
name: str = "CollectLinks"
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
desc: str = "Collect links from a search engine."
|
||||
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
|
|
@ -129,8 +127,8 @@ class CollectLinks(Action):
|
|||
if len(remove) == 0:
|
||||
break
|
||||
|
||||
model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
|
||||
prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
|
||||
model_name = config.get_openai_llm().model
|
||||
prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096)
|
||||
logger.debug(prompt)
|
||||
queries = await self._aask(prompt, [system_text])
|
||||
try:
|
||||
|
|
@ -177,19 +175,16 @@ class WebBrowseAndSummarize(Action):
|
|||
"""Action class to explore the web and provide summaries of articles and webpages."""
|
||||
|
||||
name: str = "WebBrowseAndSummarize"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
i_context: Optional[str] = None
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: Optional[WebBrowserEngine] = None
|
||||
web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if CONFIG.model_for_researcher_summary:
|
||||
self.llm.model = CONFIG.model_for_researcher_summary
|
||||
|
||||
self.web_browser_engine = WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
|
||||
run_func=self.browse_func,
|
||||
)
|
||||
|
||||
|
|
@ -220,9 +215,7 @@ class WebBrowseAndSummarize(Action):
|
|||
for u, content in zip([url, *urls], contents):
|
||||
content = content.inner_text
|
||||
chunk_summaries = []
|
||||
for prompt in generate_prompt_chunk(
|
||||
content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp
|
||||
):
|
||||
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096):
|
||||
logger.debug(prompt)
|
||||
summary = await self._aask(prompt, [system_text])
|
||||
if summary == "Not relevant.":
|
||||
|
|
@ -247,14 +240,8 @@ class WebBrowseAndSummarize(Action):
|
|||
class ConductResearch(Action):
|
||||
"""Action class to conduct research and generate a research report."""
|
||||
|
||||
name: str = "ConductResearch"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if CONFIG.model_for_researcher_report:
|
||||
self.llm.model = CONFIG.model_for_researcher_report
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@
|
|||
class.
|
||||
"""
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -48,7 +48,7 @@ WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.
|
|||
You should fill in necessary instruction, status, send to, and finally return all content between the --- segment line.
|
||||
"""
|
||||
|
||||
CONTEXT = """
|
||||
TEMPLATE_CONTEXT = """
|
||||
## Development Code File Name
|
||||
{code_file_name}
|
||||
## Development Code
|
||||
|
|
@ -77,7 +77,7 @@ standard errors:
|
|||
|
||||
class RunCode(Action):
|
||||
name: str = "RunCode"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
|
||||
@classmethod
|
||||
async def run_text(cls, code) -> Tuple[str, str]:
|
||||
|
|
@ -89,13 +89,12 @@ class RunCode(Action):
|
|||
return "", str(e)
|
||||
return namespace.get("result", ""), ""
|
||||
|
||||
@classmethod
|
||||
async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
|
||||
async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
|
||||
working_directory = str(working_directory)
|
||||
additional_python_paths = [str(path) for path in additional_python_paths]
|
||||
|
||||
# Copy the current environment variables
|
||||
env = CONFIG.new_environ()
|
||||
env = self.context.new_environ()
|
||||
|
||||
# Modify the PYTHONPATH environment variable
|
||||
additional_python_paths = [working_directory] + additional_python_paths
|
||||
|
|
@ -119,25 +118,25 @@ class RunCode(Action):
|
|||
return stdout.decode("utf-8"), stderr.decode("utf-8")
|
||||
|
||||
async def run(self, *args, **kwargs) -> RunCodeResult:
|
||||
logger.info(f"Running {' '.join(self.context.command)}")
|
||||
if self.context.mode == "script":
|
||||
logger.info(f"Running {' '.join(self.i_context.command)}")
|
||||
if self.i_context.mode == "script":
|
||||
outs, errs = await self.run_script(
|
||||
command=self.context.command,
|
||||
working_directory=self.context.working_directory,
|
||||
additional_python_paths=self.context.additional_python_paths,
|
||||
command=self.i_context.command,
|
||||
working_directory=self.i_context.working_directory,
|
||||
additional_python_paths=self.i_context.additional_python_paths,
|
||||
)
|
||||
elif self.context.mode == "text":
|
||||
outs, errs = await self.run_text(code=self.context.code)
|
||||
elif self.i_context.mode == "text":
|
||||
outs, errs = await self.run_text(code=self.i_context.code)
|
||||
|
||||
logger.info(f"{outs=}")
|
||||
logger.info(f"{errs=}")
|
||||
|
||||
context = CONTEXT.format(
|
||||
code=self.context.code,
|
||||
code_file_name=self.context.code_filename,
|
||||
test_code=self.context.test_code,
|
||||
test_file_name=self.context.test_filename,
|
||||
command=" ".join(self.context.command),
|
||||
context = TEMPLATE_CONTEXT.format(
|
||||
code=self.i_context.code,
|
||||
code_file_name=self.i_context.code_filename,
|
||||
test_code=self.i_context.test_code,
|
||||
test_file_name=self.i_context.test_filename,
|
||||
command=" ".join(self.i_context.command),
|
||||
outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
|
||||
errs=errs[:10000], # truncate errors to avoid token overflow
|
||||
)
|
||||
|
|
@ -152,11 +151,23 @@ class RunCode(Action):
|
|||
return subprocess.run(cmd, check=check, cwd=cwd, env=env)
|
||||
|
||||
@staticmethod
|
||||
def _install_dependencies(working_directory, env):
|
||||
def _install_requirements(working_directory, env):
|
||||
file_path = Path(working_directory) / "requirements.txt"
|
||||
if not file_path.exists():
|
||||
return
|
||||
if file_path.stat().st_size == 0:
|
||||
return
|
||||
install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"]
|
||||
logger.info(" ".join(install_command))
|
||||
RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env)
|
||||
|
||||
@staticmethod
|
||||
def _install_pytest(working_directory, env):
|
||||
install_pytest_command = ["python", "-m", "pip", "install", "pytest"]
|
||||
logger.info(" ".join(install_pytest_command))
|
||||
RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env)
|
||||
|
||||
@staticmethod
|
||||
def _install_dependencies(working_directory, env):
|
||||
RunCode._install_requirements(working_directory, env)
|
||||
RunCode._install_pytest(working_directory, env)
|
||||
|
|
|
|||
|
|
@ -8,10 +8,9 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -103,32 +102,25 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
"""
|
||||
|
||||
|
||||
# TOTEST
|
||||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
engine: Optional[SearchEngineType] = None
|
||||
search_func: Optional[Any] = None
|
||||
search_engine: SearchEngine = None
|
||||
result: str = ""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_engine_and_run_func(cls, values):
|
||||
engine = values.get("engine")
|
||||
search_func = values.get("search_func")
|
||||
config = Config()
|
||||
|
||||
if engine is None:
|
||||
engine = config.search_engine
|
||||
@model_validator(mode="after")
|
||||
def validate_engine_and_run_func(self):
|
||||
if self.engine is None:
|
||||
self.engine = self.config.search_engine
|
||||
try:
|
||||
search_engine = SearchEngine(engine=engine, run_func=search_func)
|
||||
search_engine = SearchEngine(engine=self.engine, run_func=self.search_func)
|
||||
except pydantic.ValidationError:
|
||||
search_engine = None
|
||||
|
||||
values["search_engine"] = search_engine
|
||||
return values
|
||||
self.search_engine = search_engine
|
||||
return self
|
||||
|
||||
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
|
||||
if self.search_engine is None:
|
||||
|
|
|
|||
|
|
@ -29,9 +29,7 @@ class ArgumentsParingAction(Action):
|
|||
|
||||
@property
|
||||
def prompt(self):
|
||||
prompt = "You are a function parser. You can convert spoken words into function parameters.\n"
|
||||
prompt += "\n---\n"
|
||||
prompt += f"{self.skill.name} function parameters description:\n"
|
||||
prompt = f"{self.skill.name} function parameters description:\n"
|
||||
for k, v in self.skill.arguments.items():
|
||||
prompt += f"parameter `{k}`: {v}\n"
|
||||
prompt += "\n---\n"
|
||||
|
|
@ -49,7 +47,10 @@ class ArgumentsParingAction(Action):
|
|||
|
||||
async def run(self, with_message=None, **kwargs) -> Message:
|
||||
prompt = self.prompt
|
||||
rsp = await self.llm.aask(msg=prompt, system_msgs=[])
|
||||
rsp = await self.llm.aask(
|
||||
msg=prompt,
|
||||
system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
|
||||
)
|
||||
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
|
||||
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
|
||||
self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)
|
||||
|
|
|
|||
|
|
@ -11,11 +11,8 @@ from pydantic import Field
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
NOTICE
|
||||
|
|
@ -29,9 +26,9 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
{system_design}
|
||||
```
|
||||
-----
|
||||
# Tasks
|
||||
# Task
|
||||
```text
|
||||
{tasks}
|
||||
{task}
|
||||
```
|
||||
-----
|
||||
{code_blocks}
|
||||
|
|
@ -90,10 +87,9 @@ flowchart TB
|
|||
"""
|
||||
|
||||
|
||||
# TOTEST
|
||||
class SummarizeCode(Action):
|
||||
name: str = "SummarizeCode"
|
||||
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
|
||||
async def summarize_code(self, prompt):
|
||||
|
|
@ -101,20 +97,20 @@ class SummarizeCode(Action):
|
|||
return code_rsp
|
||||
|
||||
async def run(self):
|
||||
design_pathname = Path(self.context.design_filename)
|
||||
design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
task_pathname = Path(self.context.task_filename)
|
||||
task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
design_pathname = Path(self.i_context.design_filename)
|
||||
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
|
||||
task_pathname = Path(self.i_context.task_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
|
||||
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
|
||||
code_blocks = []
|
||||
for filename in self.context.codes_filenames:
|
||||
for filename in self.i_context.codes_filenames:
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
code_block = f"```python\n{code_doc.content}\n```\n-----"
|
||||
code_blocks.append(code_block)
|
||||
format_example = FORMAT_EXAMPLE
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
system_design=design_doc.content,
|
||||
tasks=task_doc.content,
|
||||
task=task_doc.content,
|
||||
code_blocks="\n".join(code_blocks),
|
||||
format_example=format_example,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,25 +9,31 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
# TOTEST
|
||||
class TalkAction(Action):
|
||||
context: str
|
||||
i_context: str
|
||||
history_summary: str = ""
|
||||
knowledge: str = ""
|
||||
rsp: Optional[Message] = None
|
||||
|
||||
@property
|
||||
def agent_description(self):
|
||||
return self.context.kwargs.agent_description
|
||||
|
||||
@property
|
||||
def language(self):
|
||||
return self.context.kwargs.language or config.language
|
||||
|
||||
@property
|
||||
def prompt(self):
|
||||
prompt = ""
|
||||
if CONFIG.agent_description:
|
||||
if self.agent_description:
|
||||
prompt = (
|
||||
f"You are {CONFIG.agent_description}. Your responses should align with the role-play agreement, "
|
||||
f"You are {self.agent_description}. Your responses should align with the role-play agreement, "
|
||||
f"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
|
||||
f"decline to answer without revealing your AI nature to preserve the character's image.\n\n"
|
||||
)
|
||||
|
|
@ -36,10 +42,10 @@ class TalkAction(Action):
|
|||
prompt += (
|
||||
"If the information is insufficient, you can search in the historical conversation or knowledge above.\n"
|
||||
)
|
||||
language = CONFIG.language or DEFAULT_LANGUAGE
|
||||
language = self.language
|
||||
prompt += (
|
||||
f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
|
||||
f"{self.context}"
|
||||
f"{self.i_context}"
|
||||
)
|
||||
logger.debug(f"PROMPT: {prompt}")
|
||||
return prompt
|
||||
|
|
@ -47,11 +53,11 @@ class TalkAction(Action):
|
|||
@property
|
||||
def prompt_gpt4(self):
|
||||
kvs = {
|
||||
"{role}": CONFIG.agent_description or "",
|
||||
"{role}": self.agent_description or "",
|
||||
"{history}": self.history_summary or "",
|
||||
"{knowledge}": self.knowledge or "",
|
||||
"{language}": CONFIG.language or DEFAULT_LANGUAGE,
|
||||
"{ask}": self.context,
|
||||
"{language}": self.language,
|
||||
"{ask}": self.i_context,
|
||||
}
|
||||
prompt = TalkActionPrompt.FORMATION_LOOSE
|
||||
for k, v in kvs.items():
|
||||
|
|
@ -68,9 +74,9 @@ class TalkAction(Action):
|
|||
|
||||
@property
|
||||
def aask_args(self):
|
||||
language = CONFIG.language or DEFAULT_LANGUAGE
|
||||
language = self.language
|
||||
system_msgs = [
|
||||
f"You are {CONFIG.agent_description}.",
|
||||
f"You are {self.agent_description}.",
|
||||
"Your responses should align with the role-play agreement, "
|
||||
"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
|
||||
"decline to answer without revealing your AI nature to preserve the character's image.",
|
||||
|
|
@ -82,7 +88,7 @@ class TalkAction(Action):
|
|||
format_msgs.append({"role": "assistant", "content": self.knowledge})
|
||||
if self.history_summary:
|
||||
format_msgs.append({"role": "assistant", "content": self.history_summary})
|
||||
return self.context, format_msgs, system_msgs
|
||||
return self.i_context, format_msgs, system_msgs
|
||||
|
||||
async def run(self, with_message=None, **kwargs) -> Message:
|
||||
msg, format_msgs, system_msgs = self.aask_args
|
||||
|
|
|
|||
|
|
@ -21,18 +21,17 @@ from pydantic import Field
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
|
||||
from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
DOCS_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
CODE_PLAN_AND_CHANGE_FILENAME,
|
||||
REQUIREMENT_FILENAME,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
NOTICE
|
||||
|
|
@ -44,8 +43,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
## Design
|
||||
{design}
|
||||
|
||||
## Tasks
|
||||
{tasks}
|
||||
## Task
|
||||
{task}
|
||||
|
||||
## Legacy Code
|
||||
```Code
|
||||
|
|
@ -87,7 +86,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
|
||||
class WriteCode(Action):
|
||||
name: str = "WriteCode"
|
||||
context: Document = Field(default_factory=Document)
|
||||
i_context: Document = Field(default_factory=Document)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code(self, prompt) -> str:
|
||||
|
|
@ -96,16 +95,15 @@ class WriteCode(Action):
|
|||
return code
|
||||
|
||||
async def run(self, *args, **kwargs) -> CodingContext:
|
||||
bug_feedback = await FileRepository.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO)
|
||||
coding_context = CodingContext.loads(self.context.content)
|
||||
test_doc = await FileRepository.get_file(
|
||||
filename="test_" + coding_context.filename + ".json", relative_path=TEST_OUTPUTS_FILE_REPO
|
||||
)
|
||||
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
|
||||
coding_context = CodingContext.loads(self.i_context.content)
|
||||
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
|
||||
code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
|
||||
code_plan_and_change = code_plan_and_change_doc.content if code_plan_and_change_doc else ""
|
||||
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
summary_doc = None
|
||||
if coding_context.design_doc and coding_context.design_doc.filename:
|
||||
summary_doc = await FileRepository.get_file(
|
||||
filename=coding_context.design_doc.filename, relative_path=CODE_SUMMARIES_FILE_REPO
|
||||
)
|
||||
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
|
||||
logs = ""
|
||||
if test_doc:
|
||||
test_detail = RunCodeResult.loads(test_doc.content)
|
||||
|
|
@ -113,42 +111,109 @@ class WriteCode(Action):
|
|||
|
||||
if bug_feedback:
|
||||
code_context = coding_context.code_doc.content
|
||||
elif code_plan_and_change:
|
||||
code_context = await self.get_codes(
|
||||
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
|
||||
)
|
||||
else:
|
||||
code_context = await self.get_codes(coding_context.task_doc, exclude=self.context.filename)
|
||||
code_context = await self.get_codes(
|
||||
coding_context.task_doc,
|
||||
exclude=self.i_context.filename,
|
||||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
)
|
||||
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
design=coding_context.design_doc.content if coding_context.design_doc else "",
|
||||
tasks=coding_context.task_doc.content if coding_context.task_doc else "",
|
||||
code=code_context,
|
||||
logs=logs,
|
||||
feedback=bug_feedback.content if bug_feedback else "",
|
||||
filename=self.context.filename,
|
||||
summary_log=summary_doc.content if summary_doc else "",
|
||||
)
|
||||
if code_plan_and_change:
|
||||
prompt = REFINED_TEMPLATE.format(
|
||||
user_requirement=requirement_doc.content if requirement_doc else "",
|
||||
code_plan_and_change=code_plan_and_change,
|
||||
design=coding_context.design_doc.content if coding_context.design_doc else "",
|
||||
task=coding_context.task_doc.content if coding_context.task_doc else "",
|
||||
code=code_context,
|
||||
logs=logs,
|
||||
feedback=bug_feedback.content if bug_feedback else "",
|
||||
filename=self.i_context.filename,
|
||||
summary_log=summary_doc.content if summary_doc else "",
|
||||
)
|
||||
else:
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
design=coding_context.design_doc.content if coding_context.design_doc else "",
|
||||
task=coding_context.task_doc.content if coding_context.task_doc else "",
|
||||
code=code_context,
|
||||
logs=logs,
|
||||
feedback=bug_feedback.content if bug_feedback else "",
|
||||
filename=self.i_context.filename,
|
||||
summary_log=summary_doc.content if summary_doc else "",
|
||||
)
|
||||
logger.info(f"Writing {coding_context.filename}..")
|
||||
code = await self.write_code(prompt)
|
||||
if not coding_context.code_doc:
|
||||
# avoid root_path pydantic ValidationError if use WriteCode alone
|
||||
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
|
||||
root_path = self.context.src_workspace if self.context.src_workspace else ""
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
|
||||
coding_context.code_doc.content = code
|
||||
return coding_context
|
||||
|
||||
@staticmethod
|
||||
async def get_codes(task_doc, exclude) -> str:
|
||||
async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, use_inc: bool = False) -> str:
|
||||
"""
|
||||
Get codes for generating the exclude file in various scenarios.
|
||||
|
||||
Attributes:
|
||||
task_doc (Document): Document object of the task file.
|
||||
exclude (str): The file to be generated. Specifies the filename to be excluded from the code snippets.
|
||||
project_repo (ProjectRepo): ProjectRepo object of the project.
|
||||
use_inc (bool): Indicates whether the scenario involves incremental development. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: Codes for generating the exclude file.
|
||||
"""
|
||||
if not task_doc:
|
||||
return ""
|
||||
if not task_doc.content:
|
||||
task_doc.content = FileRepository.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO)
|
||||
task_doc = project_repo.docs.task.get(filename=task_doc.filename)
|
||||
m = json.loads(task_doc.content)
|
||||
code_filenames = m.get("Task list", [])
|
||||
code_filenames = m.get(TASK_LIST.key, []) if use_inc else m.get(REFINED_TASK_LIST.key, [])
|
||||
codes = []
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
for filename in code_filenames:
|
||||
if filename == exclude:
|
||||
continue
|
||||
doc = await src_file_repo.get(filename=filename)
|
||||
if not doc:
|
||||
continue
|
||||
codes.append(f"----- {filename}\n" + doc.content)
|
||||
src_file_repo = project_repo.srcs
|
||||
|
||||
# Incremental development scenario
|
||||
if use_inc:
|
||||
src_files = src_file_repo.all_files
|
||||
# Get the old workspace contained the old codes and old workspace are created in previous CodePlanAndChange
|
||||
old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace)
|
||||
old_files = old_file_repo.all_files
|
||||
# Get the union of the files in the src and old workspaces
|
||||
union_files_list = list(set(src_files) | set(old_files))
|
||||
for filename in union_files_list:
|
||||
# Exclude the current file from the all code snippets
|
||||
if filename == exclude:
|
||||
# If the file is in the old workspace, use the old code
|
||||
# Exclude unnecessary code to maintain a clean and focused main.py file, ensuring only relevant and
|
||||
# essential functionality is included for the project’s requirements
|
||||
if filename in old_files and filename != "main.py":
|
||||
# Use old code
|
||||
doc = await old_file_repo.get(filename=filename)
|
||||
# If the file is in the src workspace, skip it
|
||||
else:
|
||||
continue
|
||||
codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====")
|
||||
# The code snippets are generated from the src workspace
|
||||
else:
|
||||
doc = await src_file_repo.get(filename=filename)
|
||||
# If the file does not exist in the src workspace, skip it
|
||||
if not doc:
|
||||
continue
|
||||
codes.append(f"----- {filename}\n```{doc.content}```")
|
||||
|
||||
# Normal scenario
|
||||
else:
|
||||
for filename in code_filenames:
|
||||
# Exclude the current file to get the code snippets for generating the current file
|
||||
if filename == exclude:
|
||||
continue
|
||||
doc = await src_file_repo.get(filename=filename)
|
||||
if not doc:
|
||||
continue
|
||||
codes.append(f"----- {filename}\n```{doc.content}```")
|
||||
|
||||
return "\n".join(codes)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
@File : write_review.py
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List
|
||||
from typing import List, Literal
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -21,16 +21,15 @@ REVIEW = ActionNode(
|
|||
],
|
||||
)
|
||||
|
||||
LGTM = ActionNode(
|
||||
key="LGTM",
|
||||
expected_type=str,
|
||||
instruction="LGTM/LBTM. If the code is fully implemented, "
|
||||
"give a LGTM (Looks Good To Me), otherwise provide a LBTM (Looks Bad To Me).",
|
||||
REVIEW_RESULT = ActionNode(
|
||||
key="ReviewResult",
|
||||
expected_type=Literal["LGTM", "LBTM"],
|
||||
instruction="LGTM/LBTM. If the code is fully implemented, " "give a LGTM, otherwise provide a LBTM.",
|
||||
example="LBTM",
|
||||
)
|
||||
|
||||
ACTIONS = ActionNode(
|
||||
key="Actions",
|
||||
NEXT_STEPS = ActionNode(
|
||||
key="NextSteps",
|
||||
expected_type=str,
|
||||
instruction="Based on the code review outcome, suggest actionable steps. This can include code changes, "
|
||||
"refactoring suggestions, or any follow-up tasks.",
|
||||
|
|
@ -69,7 +68,7 @@ WRITE_DRAFT = ActionNode(
|
|||
)
|
||||
|
||||
|
||||
WRITE_MOVE_FUNCTION = ActionNode(
|
||||
WRITE_FUNCTION = ActionNode(
|
||||
key="WriteFunction",
|
||||
expected_type=str,
|
||||
instruction="write code for the function not implemented.",
|
||||
|
|
@ -555,8 +554,8 @@ LBTM
|
|||
"""
|
||||
|
||||
|
||||
WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM, ACTIONS])
|
||||
WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_MOVE_FUNCTION])
|
||||
WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, REVIEW_RESULT, NEXT_STEPS])
|
||||
WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_FUNCTION])
|
||||
|
||||
|
||||
CR_FOR_MOVE_FUNCTION_BY_3 = """
|
||||
|
|
@ -579,8 +578,7 @@ class WriteCodeAN(Action):
|
|||
|
||||
async def run(self, context):
|
||||
self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
|
||||
return await WRITE_MOVE_FUNCTION.fill(context=context, llm=self.llm, schema="json")
|
||||
# return await WRITE_CODE_NODE.fill(context=context, llm=self.llm, schema="markdown")
|
||||
return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
|
|||
210
metagpt/actions/write_code_plan_and_change_an.py
Normal file
210
metagpt/actions/write_code_plan_and_change_an.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mannaandpoem
|
||||
@File : write_code_plan_and_change_an.py
|
||||
"""
|
||||
import os
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import CodePlanAndChangeContext
|
||||
|
||||
CODE_PLAN_AND_CHANGE = ActionNode(
|
||||
key="Code Plan And Change",
|
||||
expected_type=str,
|
||||
instruction="Developing comprehensive and step-by-step incremental development plan, and write Incremental "
|
||||
"Change by making a code draft that how to implement incremental development including detailed steps based on the "
|
||||
"context. Note: Track incremental changes using mark of '+' or '-' for add/modify/delete code, and conforms to the "
|
||||
"output format of git diff",
|
||||
example="""
|
||||
1. Plan for calculator.py: Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, multiplication, and division. Additionally, implement robust error handling for the division operation to mitigate potential issues related to division by zero.
|
||||
```python
|
||||
class Calculator:
|
||||
self.result = number1 + number2
|
||||
return self.result
|
||||
|
||||
- def sub(self, number1, number2) -> float:
|
||||
+ def subtract(self, number1: float, number2: float) -> float:
|
||||
+ '''
|
||||
+ Subtracts the second number from the first and returns the result.
|
||||
+
|
||||
+ Args:
|
||||
+ number1 (float): The number to be subtracted from.
|
||||
+ number2 (float): The number to subtract.
|
||||
+
|
||||
+ Returns:
|
||||
+ float: The difference of number1 and number2.
|
||||
+ '''
|
||||
+ self.result = number1 - number2
|
||||
+ return self.result
|
||||
+
|
||||
def multiply(self, number1: float, number2: float) -> float:
|
||||
- pass
|
||||
+ '''
|
||||
+ Multiplies two numbers and returns the result.
|
||||
+
|
||||
+ Args:
|
||||
+ number1 (float): The first number to multiply.
|
||||
+ number2 (float): The second number to multiply.
|
||||
+
|
||||
+ Returns:
|
||||
+ float: The product of number1 and number2.
|
||||
+ '''
|
||||
+ self.result = number1 * number2
|
||||
+ return self.result
|
||||
+
|
||||
def divide(self, number1: float, number2: float) -> float:
|
||||
- pass
|
||||
+ '''
|
||||
+ ValueError: If the second number is zero.
|
||||
+ '''
|
||||
+ if number2 == 0:
|
||||
+ raise ValueError('Cannot divide by zero')
|
||||
+ self.result = number1 / number2
|
||||
+ return self.result
|
||||
+
|
||||
- def reset_result(self):
|
||||
+ def clear(self):
|
||||
+ if self.result != 0.0:
|
||||
+ print("Result is not zero, clearing...")
|
||||
+ else:
|
||||
+ print("Result is already zero, no need to clear.")
|
||||
+
|
||||
self.result = 0.0
|
||||
```
|
||||
|
||||
2. Plan for main.py: Integrate new API endpoints for subtraction, multiplication, and division into the existing codebase of `main.py`. Then, ensure seamless integration with the overall application architecture and maintain consistency with coding standards.
|
||||
```python
|
||||
def add_numbers():
|
||||
result = calculator.add_numbers(num1, num2)
|
||||
return jsonify({'result': result}), 200
|
||||
|
||||
-# TODO: Implement subtraction, multiplication, and division operations
|
||||
+@app.route('/subtract_numbers', methods=['POST'])
|
||||
+def subtract_numbers():
|
||||
+ data = request.get_json()
|
||||
+ num1 = data.get('num1', 0)
|
||||
+ num2 = data.get('num2', 0)
|
||||
+ result = calculator.subtract_numbers(num1, num2)
|
||||
+ return jsonify({'result': result}), 200
|
||||
+
|
||||
+@app.route('/multiply_numbers', methods=['POST'])
|
||||
+def multiply_numbers():
|
||||
+ data = request.get_json()
|
||||
+ num1 = data.get('num1', 0)
|
||||
+ num2 = data.get('num2', 0)
|
||||
+ try:
|
||||
+ result = calculator.divide_numbers(num1, num2)
|
||||
+ except ValueError as e:
|
||||
+ return jsonify({'error': str(e)}), 400
|
||||
+ return jsonify({'result': result}), 200
|
||||
+
|
||||
if __name__ == '__main__':
|
||||
app.run()
|
||||
```""",
|
||||
)
|
||||
|
||||
CODE_PLAN_AND_CHANGE_CONTEXT = """
|
||||
## User New Requirements
|
||||
{requirement}
|
||||
|
||||
## PRD
|
||||
{prd}
|
||||
|
||||
## Design
|
||||
{design}
|
||||
|
||||
## Task
|
||||
{task}
|
||||
|
||||
## Legacy Code
|
||||
{code}
|
||||
"""
|
||||
|
||||
REFINED_TEMPLATE = """
|
||||
NOTICE
|
||||
Role: You are a professional engineer; The main goal is to complete incremental development by combining legacy code and plan and Incremental Change, ensuring the integration of new features.
|
||||
|
||||
# Context
|
||||
## User New Requirements
|
||||
{user_requirement}
|
||||
|
||||
## Code Plan And Change
|
||||
{code_plan_and_change}
|
||||
|
||||
## Design
|
||||
{design}
|
||||
|
||||
## Task
|
||||
{task}
|
||||
|
||||
## Legacy Code
|
||||
```Code
|
||||
{code}
|
||||
```
|
||||
|
||||
## Debug logs
|
||||
```text
|
||||
{logs}
|
||||
|
||||
{summary_log}
|
||||
```
|
||||
|
||||
## Bug Feedback logs
|
||||
```text
|
||||
{feedback}
|
||||
```
|
||||
|
||||
# Format example
|
||||
## Code: {filename}
|
||||
```python
|
||||
## {filename}
|
||||
...
|
||||
```
|
||||
|
||||
# Instruction: Based on the context, follow "Format example", write or rewrite code.
|
||||
## Write/Rewrite Code: Only write one file {filename}, write or rewrite complete code using triple quotes based on the following attentions and context.
|
||||
1. Only One file: do your best to implement THIS ONLY ONE FILE.
|
||||
2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
|
||||
3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
|
||||
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
|
||||
5. Follow Code Plan And Change: If there is any Incremental Change that is marked by the git diff format using '+' and '-' for add/modify/delete code, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the plan.
|
||||
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
|
||||
7. Before using a external variable/module, make sure you import it first.
|
||||
8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
|
||||
9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code.
|
||||
"""
|
||||
|
||||
WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", [CODE_PLAN_AND_CHANGE])
|
||||
|
||||
|
||||
class WriteCodePlanAndChange(Action):
|
||||
name: str = "WriteCodePlanAndChange"
|
||||
i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to "
|
||||
"meticulously craft comprehensive incremental development plan and deliver detailed incremental change"
|
||||
prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
|
||||
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
|
||||
code_text = await self.get_old_codes()
|
||||
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
|
||||
requirement=self.i_context.requirement,
|
||||
prd=prd_doc.content,
|
||||
design=design_doc.content,
|
||||
task=task_doc.content,
|
||||
code=code_text,
|
||||
)
|
||||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
||||
async def get_old_codes(self) -> str:
|
||||
self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path)
|
||||
old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace)
|
||||
old_codes = await old_file_repo.get_all()
|
||||
codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes]
|
||||
return "\n".join(codes)
|
||||
|
|
@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import CODE_PLAN_AND_CHANGE_FILENAME, REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -120,7 +120,7 @@ REWRITE_CODE_TEMPLATE = """
|
|||
|
||||
class WriteCodeReview(Action):
|
||||
name: str = "WriteCodeReview"
|
||||
context: CodingContext = Field(default_factory=CodingContext)
|
||||
i_context: CodingContext = Field(default_factory=CodingContext)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
|
||||
|
|
@ -136,41 +136,64 @@ class WriteCodeReview(Action):
|
|||
return result, code
|
||||
|
||||
async def run(self, *args, **kwargs) -> CodingContext:
|
||||
iterative_code = self.context.code_doc.content
|
||||
k = CONFIG.code_review_k_times or 1
|
||||
iterative_code = self.i_context.code_doc.content
|
||||
k = self.context.config.code_review_k_times or 1
|
||||
|
||||
for i in range(k):
|
||||
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
|
||||
task_content = self.context.task_doc.content if self.context.task_doc else ""
|
||||
code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename)
|
||||
context = "\n".join(
|
||||
[
|
||||
"## System Design\n" + str(self.context.design_doc) + "\n",
|
||||
"## Tasks\n" + task_content + "\n",
|
||||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
format_example = FORMAT_EXAMPLE.format(filename=self.i_context.code_doc.filename)
|
||||
task_content = self.i_context.task_doc.content if self.i_context.task_doc else ""
|
||||
code_context = await WriteCode.get_codes(
|
||||
self.i_context.task_doc,
|
||||
exclude=self.i_context.filename,
|
||||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
use_inc=self.config.inc,
|
||||
)
|
||||
|
||||
if not self.config.inc:
|
||||
context = "\n".join(
|
||||
[
|
||||
"## System Design\n" + str(self.i_context.design_doc) + "\n",
|
||||
"## Task\n" + task_content + "\n",
|
||||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
)
|
||||
else:
|
||||
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
|
||||
context = "\n".join(
|
||||
[
|
||||
"## User New Requirements\n" + str(requirement_doc) + "\n",
|
||||
"## Code Plan And Change\n" + str(code_plan_and_change_doc) + "\n",
|
||||
"## System Design\n" + str(self.i_context.design_doc) + "\n",
|
||||
"## Task\n" + task_content + "\n",
|
||||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
)
|
||||
|
||||
context_prompt = PROMPT_TEMPLATE.format(
|
||||
context=context,
|
||||
code=iterative_code,
|
||||
filename=self.context.code_doc.filename,
|
||||
filename=self.i_context.code_doc.filename,
|
||||
)
|
||||
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
|
||||
format_example=format_example,
|
||||
)
|
||||
len1 = len(iterative_code) if iterative_code else 0
|
||||
len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
|
||||
logger.info(
|
||||
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
|
||||
f"{len(self.context.code_doc.content)=}"
|
||||
f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
|
||||
f"len(self.i_context.code_doc.content)={len2}"
|
||||
)
|
||||
result, rewrited_code = await self.write_code_review_and_rewrite(
|
||||
context_prompt, cr_prompt, self.context.code_doc.filename
|
||||
context_prompt, cr_prompt, self.i_context.code_doc.filename
|
||||
)
|
||||
if "LBTM" in result:
|
||||
iterative_code = rewrited_code
|
||||
elif "LGTM" in result:
|
||||
self.context.code_doc.content = iterative_code
|
||||
return self.context
|
||||
self.i_context.code_doc.content = iterative_code
|
||||
return self.i_context
|
||||
# code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING)
|
||||
# self._save(context, filename, code)
|
||||
# 如果rewrited_code是None(原code perfect),那么直接返回code
|
||||
self.context.code_doc.content = iterative_code
|
||||
return self.context
|
||||
self.i_context.code_doc.content = iterative_code
|
||||
return self.i_context
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ class WriteDocstring(Action):
|
|||
"""
|
||||
|
||||
desc: str = "Write docstring for code."
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -14,26 +14,22 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
from metagpt.actions.write_prd_an import (
|
||||
COMPETITIVE_QUADRANT_CHART,
|
||||
PROJECT_NAME,
|
||||
REFINED_PRD_NODE,
|
||||
WP_IS_RELATIVE_NODE,
|
||||
WP_ISSUE_TYPE_NODE,
|
||||
WRITE_PRD_NODE,
|
||||
)
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
COMPETITIVE_ANALYSIS_FILE_REPO,
|
||||
DOCS_FILE_REPO,
|
||||
PRD_PDF_FILE_REPO,
|
||||
PRDS_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -63,135 +59,114 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
|
||||
class WritePRD(Action):
|
||||
name: str = "WritePRD"
|
||||
content: Optional[str] = None
|
||||
"""WritePRD deal with the following situations:
|
||||
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
|
||||
2. New requirement: If the requirement is a new requirement, the PRD document will be generated.
|
||||
3. Requirement update: If the requirement is an update, the PRD document will be updated.
|
||||
"""
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
# related to the PRD. If they are related, rewrite the PRD.
|
||||
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
|
||||
requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME)
|
||||
if requirement_doc and await self._is_bugfix(requirement_doc.content):
|
||||
await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
|
||||
await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
|
||||
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
|
||||
return Message(
|
||||
content=bug_fix.model_dump_json(),
|
||||
instruct_content=bug_fix,
|
||||
role="",
|
||||
cause_by=FixBug,
|
||||
sent_from=self,
|
||||
send_to="Alex", # the name of Engineer
|
||||
)
|
||||
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
|
||||
"""Run the action."""
|
||||
req: Document = await self.repo.requirement
|
||||
docs: list[Document] = await self.repo.docs.prd.get_all()
|
||||
if not req:
|
||||
raise FileNotFoundError("No requirement document found.")
|
||||
|
||||
if await self._is_bugfix(req.content):
|
||||
logger.info(f"Bugfix detected: {req.content}")
|
||||
return await self._handle_bugfix(req)
|
||||
# remove bugfix file from last round in case of conflict
|
||||
await self.repo.docs.delete(filename=BUGFIX_FILENAME)
|
||||
|
||||
# if requirement is related to other documents, update them, otherwise create a new one
|
||||
if related_docs := await self.get_related_docs(req, docs):
|
||||
logger.info(f"Requirement update detected: {req.content}")
|
||||
return await self._handle_requirement_update(req, related_docs)
|
||||
else:
|
||||
await docs_file_repo.delete(filename=BUGFIX_FILENAME)
|
||||
logger.info(f"New requirement detected: {req.content}")
|
||||
return await self._handle_new_requirement(req)
|
||||
|
||||
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
prd_docs = await prds_file_repo.get_all()
|
||||
change_files = Documents()
|
||||
for prd_doc in prd_docs:
|
||||
prd_doc = await self._update_prd(
|
||||
requirement_doc=requirement_doc, prd_doc=prd_doc, prds_file_repo=prds_file_repo, *args, **kwargs
|
||||
)
|
||||
if not prd_doc:
|
||||
continue
|
||||
change_files.docs[prd_doc.filename] = prd_doc
|
||||
logger.info(f"rewrite prd: {prd_doc.filename}")
|
||||
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
|
||||
if not change_files.docs:
|
||||
prd_doc = await self._update_prd(
|
||||
requirement_doc=requirement_doc, prd_doc=None, prds_file_repo=prds_file_repo, *args, **kwargs
|
||||
)
|
||||
if prd_doc:
|
||||
change_files.docs[prd_doc.filename] = prd_doc
|
||||
logger.debug(f"new prd: {prd_doc.filename}")
|
||||
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
|
||||
# 'publish' message to transition the workflow to the next stage. This design allows room for global
|
||||
# optimization in subsequent steps.
|
||||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
async def _handle_bugfix(self, req: Document) -> Message:
|
||||
# ... bugfix logic ...
|
||||
await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
|
||||
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
|
||||
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
|
||||
return Message(
|
||||
content=bug_fix.model_dump_json(),
|
||||
instruct_content=bug_fix,
|
||||
role="",
|
||||
cause_by=FixBug,
|
||||
sent_from=self,
|
||||
send_to="Alex", # the name of Engineer
|
||||
)
|
||||
|
||||
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
|
||||
# sas = SearchAndSummarize()
|
||||
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
|
||||
# rsp = ""
|
||||
# info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}"
|
||||
# if sas.result:
|
||||
# logger.info(sas.result)
|
||||
# logger.info(rsp)
|
||||
project_name = CONFIG.project_name or ""
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
|
||||
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
|
||||
"""handle new requirement"""
|
||||
project_name = self.project_name
|
||||
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
|
||||
await self._rename_workspace(node)
|
||||
return node
|
||||
new_prd_doc = await self.repo.docs.prd.save(
|
||||
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
|
||||
)
|
||||
await self._save_competitive_analysis(new_prd_doc)
|
||||
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
|
||||
return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
|
||||
|
||||
async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
|
||||
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd_doc.content, requirements=new_requirement_doc.content)
|
||||
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
|
||||
# ... requirement update logic ...
|
||||
for doc in related_docs:
|
||||
await self._update_prd(req, doc)
|
||||
return Documents.from_iterable(documents=related_docs).to_action_output()
|
||||
|
||||
async def _is_bugfix(self, context: str) -> bool:
|
||||
if not self.repo.code_files_exists():
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
return node.get("issue_type") == "BUG"
|
||||
|
||||
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
|
||||
"""get the related documents"""
|
||||
# refine: use gather to speed up
|
||||
return [i for i in docs if await self._is_related(req, i)]
|
||||
|
||||
async def _is_related(self, req: Document, old_prd: Document) -> bool:
|
||||
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
|
||||
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
|
||||
return node.get("is_relative") == "YES"
|
||||
|
||||
async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
|
||||
if not CONFIG.project_name:
|
||||
CONFIG.project_name = Path(CONFIG.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
|
||||
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
|
||||
prd_doc.content = node.instruct_content.model_dump_json()
|
||||
async def _merge(self, req: Document, related_doc: Document) -> Document:
|
||||
if not self.project_name:
|
||||
self.project_name = Path(self.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
|
||||
node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
related_doc.content = node.instruct_content.model_dump_json()
|
||||
await self._rename_workspace(node)
|
||||
return prd_doc
|
||||
return related_doc
|
||||
|
||||
async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None:
|
||||
if not prd_doc:
|
||||
prd = await self._run_new_requirement(
|
||||
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
|
||||
)
|
||||
new_prd_doc = Document(
|
||||
root_path=PRDS_FILE_REPO,
|
||||
filename=FileRepository.new_filename() + ".json",
|
||||
content=prd.instruct_content.model_dump_json(),
|
||||
)
|
||||
elif await self._is_relative(requirement_doc, prd_doc):
|
||||
new_prd_doc = await self._merge(requirement_doc, prd_doc)
|
||||
else:
|
||||
return None
|
||||
await prds_file_repo.save(filename=new_prd_doc.filename, content=new_prd_doc.content)
|
||||
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
|
||||
new_prd_doc: Document = await self._merge(req, prd_doc)
|
||||
await self.repo.docs.prd.save_doc(doc=new_prd_doc)
|
||||
await self._save_competitive_analysis(new_prd_doc)
|
||||
await self._save_pdf(new_prd_doc)
|
||||
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
|
||||
return new_prd_doc
|
||||
|
||||
@staticmethod
|
||||
async def _save_competitive_analysis(prd_doc):
|
||||
async def _save_competitive_analysis(self, prd_doc: Document):
|
||||
m = json.loads(prd_doc.content)
|
||||
quadrant_chart = m.get("Competitive Quadrant Chart")
|
||||
quadrant_chart = m.get(COMPETITIVE_QUADRANT_CHART.key)
|
||||
if not quadrant_chart:
|
||||
return
|
||||
pathname = (
|
||||
CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
|
||||
)
|
||||
if not pathname.parent.exists():
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(quadrant_chart, pathname)
|
||||
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
|
||||
|
||||
@staticmethod
|
||||
async def _save_pdf(prd_doc):
|
||||
await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
|
||||
|
||||
@staticmethod
|
||||
async def _rename_workspace(prd):
|
||||
if not CONFIG.project_name:
|
||||
async def _rename_workspace(self, prd):
|
||||
if not self.project_name:
|
||||
if isinstance(prd, (ActionOutput, ActionNode)):
|
||||
ws_name = prd.instruct_content.model_dump()["Project Name"]
|
||||
else:
|
||||
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
|
||||
if ws_name:
|
||||
CONFIG.project_name = ws_name
|
||||
if not CONFIG.project_name: # The LLM failed to provide a project name, and the user didn't provide one either.
|
||||
CONFIG.project_name = "app" + uuid.uuid4().hex[:16]
|
||||
CONFIG.git_repo.rename_root(CONFIG.project_name)
|
||||
|
||||
async def _is_bugfix(self, context) -> bool:
|
||||
src_workspace_path = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
|
||||
code_files = CONFIG.git_repo.get_files(relative_path=src_workspace_path)
|
||||
if not code_files:
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
return node.get("issue_type") == "BUG"
|
||||
self.project_name = ws_name
|
||||
self.repo.git_repo.rename_root(self.project_name)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
|
||||
LANGUAGE = ActionNode(
|
||||
key="Language",
|
||||
|
|
@ -31,10 +30,18 @@ ORIGINAL_REQUIREMENTS = ActionNode(
|
|||
example="Create a 2048 game",
|
||||
)
|
||||
|
||||
REFINED_REQUIREMENTS = ActionNode(
|
||||
key="Refined Requirements",
|
||||
expected_type=str,
|
||||
instruction="Place the New user's original requirements here.",
|
||||
example="Create a 2048 game with a new feature that ...",
|
||||
)
|
||||
|
||||
PROJECT_NAME = ActionNode(
|
||||
key="Project Name",
|
||||
expected_type=str,
|
||||
instruction="According to the content of \"Original Requirements,\" name the project using snake case style , like 'game_2048' or 'simple_crm.",
|
||||
instruction='According to the content of "Original Requirements," name the project using snake case style , '
|
||||
"like 'game_2048' or 'simple_crm.",
|
||||
example="game_2048",
|
||||
)
|
||||
|
||||
|
|
@ -45,6 +52,18 @@ PRODUCT_GOALS = ActionNode(
|
|||
example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"],
|
||||
)
|
||||
|
||||
REFINED_PRODUCT_GOALS = ActionNode(
|
||||
key="Refined Product Goals",
|
||||
expected_type=List[str],
|
||||
instruction="Update and expand the original product goals to reflect the evolving needs due to incremental "
|
||||
"development.Ensure that the refined goals align with the current project direction and contribute to its success.",
|
||||
example=[
|
||||
"Enhance user engagement through new features",
|
||||
"Optimize performance for scalability",
|
||||
"Integrate innovative UI enhancements",
|
||||
],
|
||||
)
|
||||
|
||||
USER_STORIES = ActionNode(
|
||||
key="User Stories",
|
||||
expected_type=List[str],
|
||||
|
|
@ -58,6 +77,20 @@ USER_STORIES = ActionNode(
|
|||
],
|
||||
)
|
||||
|
||||
REFINED_USER_STORIES = ActionNode(
|
||||
key="Refined User Stories",
|
||||
expected_type=List[str],
|
||||
instruction="Update and expand the original scenario-based user stories to reflect the evolving needs due to "
|
||||
"incremental development. Ensure that the refined user stories capture incremental features and improvements. ",
|
||||
example=[
|
||||
"As a player, I want to choose difficulty levels to challenge my skills",
|
||||
"As a player, I want a visually appealing score display after each game for a better gaming experience",
|
||||
"As a player, I want a convenient restart button displayed when I lose to quickly start a new game",
|
||||
"As a player, I want an enhanced and aesthetically pleasing UI to elevate the overall gaming experience",
|
||||
"As a player, I want the ability to play the game seamlessly on my mobile phone for on-the-go entertainment",
|
||||
],
|
||||
)
|
||||
|
||||
COMPETITIVE_ANALYSIS = ActionNode(
|
||||
key="Competitive Analysis",
|
||||
expected_type=List[str],
|
||||
|
|
@ -97,6 +130,15 @@ REQUIREMENT_ANALYSIS = ActionNode(
|
|||
example="",
|
||||
)
|
||||
|
||||
REFINED_REQUIREMENT_ANALYSIS = ActionNode(
|
||||
key="Refined Requirement Analysis",
|
||||
expected_type=List[str],
|
||||
instruction="Review and refine the existing requirement analysis to align with the evolving needs of the project "
|
||||
"due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements "
|
||||
"required for the refined project scope.",
|
||||
example=["Require add/update/modify ..."],
|
||||
)
|
||||
|
||||
REQUIREMENT_POOL = ActionNode(
|
||||
key="Requirement Pool",
|
||||
expected_type=List[List[str]],
|
||||
|
|
@ -104,6 +146,14 @@ REQUIREMENT_POOL = ActionNode(
|
|||
example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
|
||||
)
|
||||
|
||||
REFINED_REQUIREMENT_POOL = ActionNode(
|
||||
key="Refined Requirement Pool",
|
||||
expected_type=List[List[str]],
|
||||
instruction="List down the top 5 to 7 requirements with their priority (P0, P1, P2). "
|
||||
"Cover both legacy content and incremental content. Retain content unrelated to incremental development",
|
||||
example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
|
||||
)
|
||||
|
||||
UI_DESIGN_DRAFT = ActionNode(
|
||||
key="UI Design draft",
|
||||
expected_type=str,
|
||||
|
|
@ -152,15 +202,22 @@ NODES = [
|
|||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
REFINED_NODES = [
|
||||
LANGUAGE,
|
||||
PROGRAMMING_LANGUAGE,
|
||||
REFINED_REQUIREMENTS,
|
||||
PROJECT_NAME,
|
||||
REFINED_PRODUCT_GOALS,
|
||||
REFINED_USER_STORIES,
|
||||
COMPETITIVE_ANALYSIS,
|
||||
COMPETITIVE_QUADRANT_CHART,
|
||||
REFINED_REQUIREMENT_ANALYSIS,
|
||||
REFINED_REQUIREMENT_POOL,
|
||||
UI_DESIGN_DRAFT,
|
||||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES)
|
||||
REFINED_PRD_NODE = ActionNode.from_children("RefinedPRD", REFINED_NODES)
|
||||
WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON])
|
||||
WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON])
|
||||
|
||||
|
||||
def main():
|
||||
prompt = WRITE_PRD_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.actions.action import Action
|
|||
|
||||
class WritePRDReview(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
|
||||
prd: Optional[str] = None
|
||||
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class WriteTeachingPlanPart(Action):
|
||||
"""Write Teaching Plan Part"""
|
||||
|
||||
context: Optional[str] = None
|
||||
i_context: Optional[str] = None
|
||||
topic: str = ""
|
||||
language: str = "Chinese"
|
||||
rsp: Optional[str] = None
|
||||
|
|
@ -24,7 +24,7 @@ class WriteTeachingPlanPart(Action):
|
|||
statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
|
||||
statements = []
|
||||
for p in statement_patterns:
|
||||
s = self.format_value(p)
|
||||
s = self.format_value(p, context=self.context)
|
||||
statements.append(s)
|
||||
formatter = (
|
||||
TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
|
||||
|
|
@ -35,7 +35,7 @@ class WriteTeachingPlanPart(Action):
|
|||
formation=TeachingPlanBlock.FORMATION,
|
||||
role=self.prefix,
|
||||
statements="\n".join(statements),
|
||||
lesson=self.context,
|
||||
lesson=self.i_context,
|
||||
topic=self.topic,
|
||||
language=self.language,
|
||||
)
|
||||
|
|
@ -68,20 +68,23 @@ class WriteTeachingPlanPart(Action):
|
|||
return self.topic
|
||||
|
||||
@staticmethod
|
||||
def format_value(value):
|
||||
def format_value(value, context: Context):
|
||||
"""Fill parameters inside `value` with `options`."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if "{" not in value:
|
||||
return value
|
||||
|
||||
merged_opts = CONFIG.options or {}
|
||||
options = context.config.model_dump()
|
||||
for k, v in context.kwargs:
|
||||
options[k] = v # None value is allowed to override and disable the value from config.
|
||||
opts = {k: v for k, v in options.items() if v is not None}
|
||||
try:
|
||||
return value.format(**merged_opts)
|
||||
return value.format(**opts)
|
||||
except KeyError as e:
|
||||
logger.warning(f"Parameter is missing:{e}")
|
||||
|
||||
for k, v in merged_opts.items():
|
||||
for k, v in opts.items():
|
||||
value = value.replace("{" + f"{k}" + "}", str(v))
|
||||
return value
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ you should correctly import the necessary classes based on these file locations!
|
|||
|
||||
class WriteTest(Action):
|
||||
name: str = "WriteTest"
|
||||
context: Optional[TestingContext] = None
|
||||
i_context: Optional[TestingContext] = None
|
||||
|
||||
async def write_code(self, prompt):
|
||||
code_rsp = await self._aask(prompt)
|
||||
|
|
@ -55,16 +55,16 @@ class WriteTest(Action):
|
|||
return code
|
||||
|
||||
async def run(self, *args, **kwargs) -> TestingContext:
|
||||
if not self.context.test_doc:
|
||||
self.context.test_doc = Document(
|
||||
filename="test_" + self.context.code_doc.filename, root_path=TEST_CODES_FILE_REPO
|
||||
if not self.i_context.test_doc:
|
||||
self.i_context.test_doc = Document(
|
||||
filename="test_" + self.i_context.code_doc.filename, root_path=TEST_CODES_FILE_REPO
|
||||
)
|
||||
fake_root = "/data"
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
code_to_test=self.context.code_doc.content,
|
||||
test_file_name=self.context.test_doc.filename,
|
||||
source_file_path=fake_root + "/" + self.context.code_doc.root_relative_path,
|
||||
code_to_test=self.i_context.code_doc.content,
|
||||
test_file_name=self.i_context.test_doc.filename,
|
||||
source_file_path=fake_root + "/" + self.i_context.code_doc.root_relative_path,
|
||||
workspace=fake_root,
|
||||
)
|
||||
self.context.test_doc.content = await self.write_code(prompt)
|
||||
return self.context
|
||||
self.i_context.test_doc.content = await self.write_code(prompt)
|
||||
return self.i_context
|
||||
|
|
|
|||
|
|
@ -1,290 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Provide configuration, singleton
|
||||
@Modified By: mashenquan, 2023/11/27.
|
||||
1. According to Section 2.2.3.11 of RFC 135, add git repository support.
|
||||
2. Add the parameter `src_workspace` for the old version project path.
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.common import require_python_version
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
class NotConfiguredException(Exception):
|
||||
"""Exception raised for errors in the configuration.
|
||||
|
||||
Attributes:
|
||||
message -- explanation of the error
|
||||
"""
|
||||
|
||||
def __init__(self, message="The required configuration is not set"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class LLMProviderEnum(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
METAGPT = "metagpt"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""
|
||||
Regular usage method:
|
||||
config = Config("config.yaml")
|
||||
secret_key = config.get_key("MY_SECRET_KEY")
|
||||
print("Secret key:", secret_key)
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
home_yaml_file = Path.home() / ".metagpt/config.yaml"
|
||||
key_yaml_file = METAGPT_ROOT / "config/key.yaml"
|
||||
default_yaml_file = METAGPT_ROOT / "config/config.yaml"
|
||||
|
||||
def __init__(self, yaml_file=default_yaml_file, cost_data=""):
|
||||
global_options = OPTIONS.get()
|
||||
# cli paras
|
||||
self.project_path = ""
|
||||
self.project_name = ""
|
||||
self.inc = False
|
||||
self.reqa_file = ""
|
||||
self.max_auto_summarize_code = 0
|
||||
self.git_reinit = False
|
||||
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
|
||||
self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager()
|
||||
self._update()
|
||||
global_options.update(OPTIONS.get())
|
||||
logger.debug("Config loading done.")
|
||||
|
||||
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
|
||||
"""Get first valid LLM provider enum"""
|
||||
mappings = {
|
||||
LLMProviderEnum.OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
|
||||
),
|
||||
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
|
||||
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
|
||||
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
|
||||
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
|
||||
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
|
||||
LLMProviderEnum.METAGPT: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
|
||||
),
|
||||
LLMProviderEnum.AZURE_OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY)
|
||||
and self.OPENAI_API_TYPE == "azure"
|
||||
and self.DEPLOYMENT_NAME
|
||||
and self.OPENAI_API_VERSION
|
||||
),
|
||||
LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
|
||||
}
|
||||
provider = None
|
||||
for k, v in mappings.items():
|
||||
if v:
|
||||
provider = k
|
||||
break
|
||||
if provider is None:
|
||||
if self.DEFAULT_PROVIDER:
|
||||
provider = LLMProviderEnum(self.DEFAULT_PROVIDER)
|
||||
else:
|
||||
raise NotConfiguredException("You should config a LLM configuration first")
|
||||
|
||||
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
model_name = self.get_model_name(provider=provider)
|
||||
if model_name:
|
||||
logger.info(f"{provider} Model: {model_name}")
|
||||
if provider:
|
||||
logger.info(f"API: {provider}")
|
||||
return provider
|
||||
|
||||
def get_model_name(self, provider=None) -> str:
|
||||
provider = provider or self.get_default_llm_provider_enum()
|
||||
model_mappings = {
|
||||
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
|
||||
}
|
||||
return model_mappings.get(provider, "")
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_llm_key(k: str) -> bool:
|
||||
return bool(k and k != "YOUR_API_KEY")
|
||||
|
||||
def _update(self):
|
||||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
|
||||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
|
||||
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
|
||||
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
|
||||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
self.gemini_api_key = self._get("GEMINI_API_KEY")
|
||||
self.ollama_api_base = self._get("OLLAMA_API_BASE")
|
||||
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
|
||||
|
||||
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
self.openai_api_type = self._get("OPENAI_API_TYPE")
|
||||
self.openai_api_version = self._get("OPENAI_API_VERSION")
|
||||
self.openai_api_rpm = self._get("RPM", 3)
|
||||
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4-1106-preview")
|
||||
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
|
||||
self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4")
|
||||
|
||||
self.spark_appid = self._get("SPARK_APPID")
|
||||
self.spark_api_secret = self._get("SPARK_API_SECRET")
|
||||
self.spark_api_key = self._get("SPARK_API_KEY")
|
||||
self.domain = self._get("DOMAIN")
|
||||
self.spark_url = self._get("SPARK_URL")
|
||||
|
||||
self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
|
||||
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
|
||||
|
||||
self.claude_api_key = self._get("ANTHROPIC_API_KEY")
|
||||
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
|
||||
self.serper_api_key = self._get("SERPER_API_KEY")
|
||||
self.google_api_key = self._get("GOOGLE_API_KEY")
|
||||
self.google_cse_id = self._get("GOOGLE_CSE_ID")
|
||||
self.search_engine = SearchEngineType(self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE))
|
||||
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", WebBrowserEngineType.PLAYWRIGHT))
|
||||
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
|
||||
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
|
||||
|
||||
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
|
||||
if self.long_term_memory:
|
||||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.code_review_k_times = 2
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
self.mmdc = self._get("MMDC", "mmdc")
|
||||
self.calc_usage = self._get("CALC_USAGE", True)
|
||||
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
|
||||
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
|
||||
self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
|
||||
self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")
|
||||
|
||||
workspace_uid = (
|
||||
self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
|
||||
)
|
||||
self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
|
||||
self.prompt_schema = self._get("PROMPT_FORMAT", "json")
|
||||
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
|
||||
val = self._get("WORKSPACE_PATH_WITH_UID")
|
||||
if val and val.lower() == "true": # for agent
|
||||
self.workspace_path = self.workspace_path / workspace_uid
|
||||
self._ensure_workspace_exists()
|
||||
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
|
||||
self.timeout = int(self._get("TIMEOUT", 3))
|
||||
|
||||
self.kaggle_username = self._get("KAGGLE_USERNAME", "")
|
||||
self.kaggle_key = self._get("KAGGLE_KEY", "")
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
||||
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
|
||||
if project_path:
|
||||
inc = True
|
||||
project_name = project_name or Path(project_path).name
|
||||
self.project_path = project_path
|
||||
self.project_name = project_name
|
||||
self.inc = inc
|
||||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
def _ensure_workspace_exists(self):
|
||||
self.workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}")
|
||||
|
||||
def _init_with_config_files_and_env(self, yaml_file):
|
||||
"""Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority"""
|
||||
configs = dict(os.environ)
|
||||
|
||||
for _yaml_file in [yaml_file, self.key_yaml_file, self.home_yaml_file]:
|
||||
if not _yaml_file.exists():
|
||||
continue
|
||||
|
||||
# Load local YAML file
|
||||
with open(_yaml_file, "r", encoding="utf-8") as file:
|
||||
yaml_data = yaml.safe_load(file)
|
||||
if not yaml_data:
|
||||
continue
|
||||
configs.update(yaml_data)
|
||||
OPTIONS.set(configs)
|
||||
|
||||
@staticmethod
|
||||
def _get(*args, **kwargs):
|
||||
i = OPTIONS.get()
|
||||
return i.get(*args, **kwargs)
|
||||
|
||||
def get(self, key, *args, **kwargs):
|
||||
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
|
||||
Throw an error if not found."""
|
||||
value = self._get(key, *args, **kwargs)
|
||||
if value is None:
|
||||
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
|
||||
return value
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
OPTIONS.get()[name] = value
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
i = OPTIONS.get()
|
||||
return i.get(name)
|
||||
|
||||
def set_context(self, options: dict):
|
||||
"""Update current config"""
|
||||
if not options:
|
||||
return
|
||||
opts = deepcopy(OPTIONS.get())
|
||||
opts.update(options)
|
||||
OPTIONS.set(opts)
|
||||
self._update()
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
"""Return all key-values"""
|
||||
return OPTIONS.get()
|
||||
|
||||
def new_environ(self):
|
||||
"""Return a new os.environ object"""
|
||||
env = os.environ.copy()
|
||||
i = self.options
|
||||
env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
|
||||
CONFIG = Config()
|
||||
143
metagpt/config2.py
Normal file
143
metagpt/config2.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 01:25
|
||||
@Author : alexanderwu
|
||||
@File : config2.py
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
from metagpt.configs.s3_config import S3Config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.configs.workspace_config import WorkspaceConfig
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class CLIParams(BaseModel):
|
||||
"""CLI parameters"""
|
||||
|
||||
project_path: str = ""
|
||||
project_name: str = ""
|
||||
inc: bool = False
|
||||
reqa_file: str = ""
|
||||
max_auto_summarize_code: int = 0
|
||||
git_reinit: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_project_path(self):
|
||||
"""Check project_path and project_name"""
|
||||
if self.project_path:
|
||||
self.inc = True
|
||||
self.project_name = self.project_name or Path(self.project_path).name
|
||||
return self
|
||||
|
||||
|
||||
class Config(CLIParams, YamlModel):
|
||||
"""Configurations for MetaGPT"""
|
||||
|
||||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
# Tool Parameters
|
||||
search: Optional[SearchConfig] = None
|
||||
browser: BrowserConfig = BrowserConfig()
|
||||
mermaid: MermaidConfig = MermaidConfig()
|
||||
|
||||
# Storage Parameters
|
||||
s3: Optional[S3Config] = None
|
||||
redis: Optional[RedisConfig] = None
|
||||
|
||||
# Misc Parameters
|
||||
repair_llm_output: bool = False
|
||||
prompt_schema: Literal["json", "markdown", "raw"] = "json"
|
||||
workspace: WorkspaceConfig = WorkspaceConfig()
|
||||
enable_longterm_memory: bool = False
|
||||
code_review_k_times: int = 2
|
||||
|
||||
# Will be removed in the future
|
||||
llm_for_researcher_summary: str = "gpt3"
|
||||
llm_for_researcher_report: str = "gpt3"
|
||||
METAGPT_TEXT_TO_IMAGE_MODEL_URL: str = ""
|
||||
language: str = "English"
|
||||
redis_key: str = "placeholder"
|
||||
mmdc: str = "mmdc"
|
||||
puppeteer_config: str = ""
|
||||
pyppeteer_executable_path: str = ""
|
||||
IFLYTEK_APP_ID: str = ""
|
||||
IFLYTEK_API_SECRET: str = ""
|
||||
IFLYTEK_API_KEY: str = ""
|
||||
AZURE_TTS_SUBSCRIPTION_KEY: str = ""
|
||||
AZURE_TTS_REGION: str = ""
|
||||
mermaid_engine: str = "nodejs"
|
||||
|
||||
@classmethod
|
||||
def from_home(cls, path):
|
||||
"""Load config from ~/.metagpt/config.yaml"""
|
||||
pathname = CONFIG_ROOT / path
|
||||
if not pathname.exists():
|
||||
return None
|
||||
return Config.from_yaml_file(pathname)
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
"""Load default config
|
||||
- Priority: env < default_config_paths
|
||||
- Inside default_config_paths, the latter one overwrites the former one
|
||||
"""
|
||||
default_config_paths: List[Path] = [
|
||||
METAGPT_ROOT / "config/config2.yaml",
|
||||
Path.home() / ".metagpt/config2.yaml",
|
||||
]
|
||||
|
||||
dicts = [dict(os.environ)]
|
||||
dicts += [Config.read_yaml(path) for path in default_config_paths]
|
||||
final = merge_dict(dicts)
|
||||
return Config(**final)
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
||||
# Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
|
||||
if project_path:
|
||||
inc = True
|
||||
project_name = project_name or Path(project_path).name
|
||||
self.project_path = project_path
|
||||
self.project_name = project_name
|
||||
self.inc = inc
|
||||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
if self.llm.api_type == LLMType.OPENAI:
|
||||
return self.llm
|
||||
return None
|
||||
|
||||
def get_azure_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get Azure LLMConfig by name. If no Azure, raise Exception"""
|
||||
if self.llm.api_type == LLMType.AZURE:
|
||||
return self.llm
|
||||
return None
|
||||
|
||||
|
||||
def merge_dict(dicts: Iterable[Dict]) -> Dict:
|
||||
"""Merge multiple dicts into one, with the latter dict overwriting the former"""
|
||||
result = {}
|
||||
for dictionary in dicts:
|
||||
result.update(dictionary)
|
||||
return result
|
||||
|
||||
|
||||
config = Config.default()
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/11 14:44
|
||||
@Time : 2024/1/4 16:33
|
||||
@Author : alexanderwu
|
||||
@File : test_action.py
|
||||
@File : __init__.py
|
||||
"""
|
||||
20
metagpt/configs/browser_config.py
Normal file
20
metagpt/configs/browser_config.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:06
|
||||
@Author : alexanderwu
|
||||
@File : browser_config.py
|
||||
"""
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class BrowserConfig(YamlModel):
|
||||
"""Config for Browser"""
|
||||
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
|
||||
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
|
||||
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
|
||||
path: str = ""
|
||||
78
metagpt/configs/llm_config.py
Normal file
78
metagpt/configs/llm_config.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 16:33
|
||||
@Author : alexanderwu
|
||||
@File : llm_config.py
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class LLMType(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
METAGPT = "metagpt"
|
||||
AZURE = "azure"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
||||
|
||||
class LLMConfig(YamlModel):
|
||||
"""Config for LLM
|
||||
|
||||
OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions.py#L681
|
||||
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
api_type: LLMType = LLMType.OPENAI
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
app_id: Optional[str] = None
|
||||
api_secret: Optional[str] = None
|
||||
domain: Optional[str] = None
|
||||
|
||||
# For Chat Completion
|
||||
max_token: int = 4096
|
||||
temperature: float = 0.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = 0
|
||||
repetition_penalty: float = 1.0
|
||||
stop: Optional[str] = None
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
best_of: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
stream: bool = False
|
||||
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
|
||||
top_logprobs: Optional[int] = None
|
||||
timeout: int = 60
|
||||
|
||||
# For Network
|
||||
proxy: Optional[str] = None
|
||||
|
||||
# Cost Control
|
||||
calc_usage: bool = True
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
if v in ["", None, "YOUR_API_KEY"]:
|
||||
raise ValueError("Please set your API key in config.yaml")
|
||||
return v
|
||||
18
metagpt/configs/mermaid_config.py
Normal file
18
metagpt/configs/mermaid_config.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:07
|
||||
@Author : alexanderwu
|
||||
@File : mermaid_config.py
|
||||
"""
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class MermaidConfig(YamlModel):
|
||||
"""Config for Mermaid"""
|
||||
|
||||
engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs"
|
||||
path: str = ""
|
||||
puppeteer_config: str = "" # Only for nodejs engine
|
||||
26
metagpt/configs/redis_config.py
Normal file
26
metagpt/configs/redis_config.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:06
|
||||
@Author : alexanderwu
|
||||
@File : redis_config.py
|
||||
"""
|
||||
from metagpt.utils.yaml_model import YamlModelWithoutDefault
|
||||
|
||||
|
||||
class RedisConfig(YamlModelWithoutDefault):
|
||||
host: str
|
||||
port: int
|
||||
username: str = ""
|
||||
password: str
|
||||
db: str
|
||||
|
||||
def to_url(self):
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
def to_kwargs(self):
|
||||
return {
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"db": self.db,
|
||||
}
|
||||
15
metagpt/configs/s3_config.py
Normal file
15
metagpt/configs/s3_config.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:07
|
||||
@Author : alexanderwu
|
||||
@File : s3_config.py
|
||||
"""
|
||||
from metagpt.utils.yaml_model import YamlModelWithoutDefault
|
||||
|
||||
|
||||
class S3Config(YamlModelWithoutDefault):
|
||||
access_key: str
|
||||
secret_key: str
|
||||
endpoint: str
|
||||
bucket: str
|
||||
17
metagpt/configs/search_config.py
Normal file
17
metagpt/configs/search_config.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:06
|
||||
@Author : alexanderwu
|
||||
@File : search_config.py
|
||||
"""
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class SearchConfig(YamlModel):
|
||||
"""Config for Search"""
|
||||
|
||||
api_key: str
|
||||
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
|
||||
cse_id: str = "" # for google
|
||||
38
metagpt/configs/workspace_config.py
Normal file
38
metagpt/configs/workspace_config.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 19:09
|
||||
@Author : alexanderwu
|
||||
@File : workspace_config.py
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import field_validator, model_validator
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class WorkspaceConfig(YamlModel):
|
||||
path: Path = DEFAULT_WORKSPACE_ROOT
|
||||
use_uid: bool = False
|
||||
uid: str = ""
|
||||
|
||||
@field_validator("path")
|
||||
@classmethod
|
||||
def check_workspace_path(cls, v):
|
||||
if isinstance(v, str):
|
||||
v = Path(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_uid_and_update_path(self):
|
||||
if self.use_uid and not self.uid:
|
||||
self.uid = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
|
||||
self.path = self.path / self.uid
|
||||
|
||||
# Create workspace path if not exists
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
return self
|
||||
|
|
@ -9,7 +9,6 @@
|
|||
@Modified By: mashenquan, 2023-11-27. Defines file repository paths according to Section 2.2.3.4 of RFC 135.
|
||||
@Modified By: mashenquan, 2023/12/5. Add directories for code summarization..
|
||||
"""
|
||||
import contextvars
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -17,8 +16,6 @@ from loguru import logger
|
|||
|
||||
import metagpt
|
||||
|
||||
OPTIONS = contextvars.ContextVar("OPTIONS", default={})
|
||||
|
||||
|
||||
def get_metagpt_package_root():
|
||||
"""Get the root directory of the installed package."""
|
||||
|
|
@ -47,7 +44,7 @@ def get_metagpt_root():
|
|||
|
||||
|
||||
# METAGPT PROJECT ROOT AND VARS
|
||||
|
||||
CONFIG_ROOT = Path.home() / ".metagpt"
|
||||
METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
|
||||
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
||||
|
||||
|
|
@ -73,12 +70,10 @@ SKILL_DIRECTORY = SOURCE_ROOT / "skills"
|
|||
TOOL_SCHEMA_PATH = METAGPT_ROOT / "metagpt/tools/schemas"
|
||||
TOOL_LIBS_PATH = METAGPT_ROOT / "metagpt/tools/libs"
|
||||
|
||||
|
||||
# REAL CONSTS
|
||||
|
||||
MEM_TTL = 24 * 30 * 3600
|
||||
|
||||
|
||||
MESSAGE_ROUTE_FROM = "sent_from"
|
||||
MESSAGE_ROUTE_TO = "send_to"
|
||||
MESSAGE_ROUTE_CAUSE_BY = "cause_by"
|
||||
|
|
@ -89,25 +84,28 @@ MESSAGE_ROUTE_TO_NONE = "<none>"
|
|||
REQUIREMENT_FILENAME = "requirement.txt"
|
||||
BUGFIX_FILENAME = "bugfix.txt"
|
||||
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"
|
||||
CODE_PLAN_AND_CHANGE_FILENAME = "code_plan_and_change.json"
|
||||
|
||||
DOCS_FILE_REPO = "docs"
|
||||
PRDS_FILE_REPO = "docs/prds"
|
||||
PRDS_FILE_REPO = "docs/prd"
|
||||
SYSTEM_DESIGN_FILE_REPO = "docs/system_design"
|
||||
TASK_FILE_REPO = "docs/tasks"
|
||||
TASK_FILE_REPO = "docs/task"
|
||||
CODE_PLAN_AND_CHANGE_FILE_REPO = "docs/code_plan_and_change"
|
||||
COMPETITIVE_ANALYSIS_FILE_REPO = "resources/competitive_analysis"
|
||||
DATA_API_DESIGN_FILE_REPO = "resources/data_api_design"
|
||||
SEQ_FLOW_FILE_REPO = "resources/seq_flow"
|
||||
SYSTEM_DESIGN_PDF_FILE_REPO = "resources/system_design"
|
||||
PRD_PDF_FILE_REPO = "resources/prd"
|
||||
TASK_PDF_FILE_REPO = "resources/api_spec_and_tasks"
|
||||
TASK_PDF_FILE_REPO = "resources/api_spec_and_task"
|
||||
CODE_PLAN_AND_CHANGE_PDF_FILE_REPO = "resources/code_plan_and_change"
|
||||
TEST_CODES_FILE_REPO = "tests"
|
||||
TEST_OUTPUTS_FILE_REPO = "test_outputs"
|
||||
CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
|
||||
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
|
||||
CODE_SUMMARIES_FILE_REPO = "docs/code_summary"
|
||||
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
|
||||
RESOURCES_FILE_REPO = "resources"
|
||||
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
|
||||
SD_OUTPUT_FILE_REPO = "resources/sd_output"
|
||||
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
|
||||
CLASS_VIEW_FILE_REPO = "docs/class_views"
|
||||
CLASS_VIEW_FILE_REPO = "docs/class_view"
|
||||
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
|
||||
|
|
|
|||
97
metagpt/context.py
Normal file
97
metagpt/context.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4 16:32
|
||||
@Author : alexanderwu
|
||||
@File : context.py
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class AttrDict(BaseModel):
|
||||
"""A dict-like object that allows access to keys as attributes, compatible with Pydantic."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __getattr__(self, key):
|
||||
return self.__dict__.get(key, None)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self.__dict__[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key in self.__dict__:
|
||||
del self.__dict__[key]
|
||||
else:
|
||||
raise AttributeError(f"No such attribute: {key}")
|
||||
|
||||
def set(self, key, val: Any):
|
||||
self.__dict__[key] = val
|
||||
|
||||
def get(self, key, default: Any = None):
|
||||
return self.__dict__.get(key, default)
|
||||
|
||||
def remove(self, key):
|
||||
if key in self.__dict__:
|
||||
self.__delattr__(key)
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
"""Env context for MetaGPT"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
kwargs: AttrDict = AttrDict()
|
||||
config: Config = Config.default()
|
||||
|
||||
repo: Optional[ProjectRepo] = None
|
||||
git_repo: Optional[GitRepository] = None
|
||||
src_workspace: Optional[Path] = None
|
||||
cost_manager: CostManager = CostManager()
|
||||
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
||||
def new_environ(self):
|
||||
"""Return a new os.environ object"""
|
||||
env = os.environ.copy()
|
||||
# i = self.options
|
||||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
# """Use a LLM instance"""
|
||||
# self._llm_config = self.config.get_llm_config(name, provider)
|
||||
# self._llm = None
|
||||
# return self._llm
|
||||
|
||||
def llm(self) -> BaseLLM:
|
||||
"""Return a LLM instance, fixme: support cache"""
|
||||
# if self._llm is None:
|
||||
self._llm = create_llm_instance(self.config.llm)
|
||||
if self._llm.cost_manager is None:
|
||||
self._llm.cost_manager = self.cost_manager
|
||||
return self._llm
|
||||
|
||||
def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
|
||||
"""Return a LLM instance, fixme: support cache"""
|
||||
# if self._llm is None:
|
||||
llm = create_llm_instance(llm_config)
|
||||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self.cost_manager
|
||||
return llm
|
||||
102
metagpt/context_mixin.py
Normal file
102
metagpt/context_mixin.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/11 17:25
|
||||
@Author : alexanderwu
|
||||
@File : context_mixin.py
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.context import Context
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class ContextMixin(BaseModel):
|
||||
"""Mixin class for context and config"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
# Pydantic has bug on _private_attr when using inheritance, so we use private_* instead
|
||||
# - https://github.com/pydantic/pydantic/issues/7142
|
||||
# - https://github.com/pydantic/pydantic/issues/7083
|
||||
# - https://github.com/pydantic/pydantic/issues/7091
|
||||
|
||||
# Env/Role/Action will use this context as private context, or use self.context as public context
|
||||
private_context: Optional[Context] = Field(default=None, exclude=True)
|
||||
# Env/Role/Action will use this config as private config, or use self.context.config as public config
|
||||
private_config: Optional[Config] = Field(default=None, exclude=True)
|
||||
|
||||
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
|
||||
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Optional[Context] = None,
|
||||
config: Optional[Config] = None,
|
||||
llm: Optional[BaseLLM] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize with config"""
|
||||
super().__init__(**kwargs)
|
||||
self.set_context(context)
|
||||
self.set_config(config)
|
||||
self.set_llm(llm)
|
||||
|
||||
def set(self, k, v, override=False):
|
||||
"""Set attribute"""
|
||||
if override or not self.__dict__.get(k):
|
||||
self.__dict__[k] = v
|
||||
|
||||
def set_context(self, context: Context, override=True):
|
||||
"""Set context"""
|
||||
self.set("private_context", context, override)
|
||||
|
||||
def set_config(self, config: Config, override=False):
|
||||
"""Set config"""
|
||||
self.set("private_config", config, override)
|
||||
if config is not None:
|
||||
_ = self.llm # init llm
|
||||
|
||||
def set_llm(self, llm: BaseLLM, override=False):
|
||||
"""Set llm"""
|
||||
self.set("private_llm", llm, override)
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
"""Role config: role config > context config"""
|
||||
if self.private_config:
|
||||
return self.private_config
|
||||
return self.context.config
|
||||
|
||||
@config.setter
|
||||
def config(self, config: Config) -> None:
|
||||
"""Set config"""
|
||||
self.set_config(config)
|
||||
|
||||
@property
|
||||
def context(self) -> Context:
|
||||
"""Role context: role context > context"""
|
||||
if self.private_context:
|
||||
return self.private_context
|
||||
return Context()
|
||||
|
||||
@context.setter
|
||||
def context(self, context: Context) -> None:
|
||||
"""Set context"""
|
||||
self.set_context(context)
|
||||
|
||||
@property
|
||||
def llm(self) -> BaseLLM:
|
||||
"""Role llm: if not existed, init from role.config"""
|
||||
# print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}")
|
||||
if not self.private_llm:
|
||||
self.private_llm = self.context.llm_with_cost_manager_from_llm_config(self.config.llm)
|
||||
return self.private_llm
|
||||
|
||||
@llm.setter
|
||||
def llm(self, llm: BaseLLM) -> None:
|
||||
"""Set llm"""
|
||||
self.private_llm = llm
|
||||
|
|
@ -8,8 +8,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config import Config
|
||||
|
||||
|
||||
class BaseStore(ABC):
|
||||
"""FIXME: consider add_index, set_index and think about granularity."""
|
||||
|
|
@ -31,7 +29,6 @@ class LocalStore(BaseStore, ABC):
|
|||
def __init__(self, raw_data_path: Path, cache_dir: Path = None):
|
||||
if not raw_data_path:
|
||||
raise FileNotFoundError
|
||||
self.config = Config()
|
||||
self.raw_data_path = raw_data_path
|
||||
self.fname = self.raw_data_path.stem
|
||||
if not cache_dir:
|
||||
|
|
|
|||
|
|
@ -9,14 +9,13 @@ import asyncio
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.document import IndexableDocument
|
||||
from metagpt.document_store.base_store import LocalStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
class FaissStore(LocalStore):
|
||||
|
|
@ -25,9 +24,7 @@ class FaissStore(LocalStore):
|
|||
):
|
||||
self.meta_col = meta_col
|
||||
self.content_col = content_col
|
||||
self.embedding = embedding or OpenAIEmbeddings(
|
||||
openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
|
||||
)
|
||||
self.embedding = embedding or get_embedding()
|
||||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
|
|
|
|||
|
|
@ -12,16 +12,15 @@
|
|||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
|
||||
from metagpt.utils.common import is_send_to
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
|
|
@ -33,58 +32,22 @@ class Environment(BaseModel):
|
|||
|
||||
desc: str = Field(default="") # 环境描述
|
||||
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
|
||||
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
|
||||
member_addrs: dict[Role, Set] = Field(default_factory=dict, exclude=True)
|
||||
history: str = "" # For debug
|
||||
context: Context = Field(default_factory=Context, exclude=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_roles(self):
|
||||
self.add_roles(self.roles.values())
|
||||
return self
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = []
|
||||
for role_key, role in self.roles.items():
|
||||
roles_info.append(
|
||||
{
|
||||
"role_class": role.__class__.__name__,
|
||||
"module_name": role.__module__,
|
||||
"role_name": role.name,
|
||||
"role_sub_tags": list(self.members.get(role)),
|
||||
}
|
||||
)
|
||||
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
|
||||
write_json_file(roles_path, roles_info)
|
||||
|
||||
history_path = stg_path.joinpath("history.json")
|
||||
write_json_file(history_path, {"content": self.history})
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Environment":
|
||||
"""stg_path: ./storage/team/environment/"""
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = read_json_file(roles_path)
|
||||
roles = []
|
||||
for role_info in roles_info:
|
||||
# role stored in ./environment/roles/{role_class}_{role_name}
|
||||
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
|
||||
role = Role.deserialize(role_path)
|
||||
roles.append(role)
|
||||
|
||||
history = read_json_file(stg_path.joinpath("history.json"))
|
||||
history = history.get("content")
|
||||
|
||||
environment = Environment(**{"history": history})
|
||||
environment.add_roles(roles)
|
||||
|
||||
return environment
|
||||
|
||||
def add_role(self, role: Role):
|
||||
"""增加一个在当前环境的角色
|
||||
Add a role in the current environment
|
||||
"""
|
||||
self.roles[role.profile] = role
|
||||
role.set_env(self)
|
||||
role.context = self.context
|
||||
|
||||
def add_roles(self, roles: Iterable[Role]):
|
||||
"""增加一批在当前环境的角色
|
||||
|
|
@ -95,6 +58,7 @@ class Environment(BaseModel):
|
|||
|
||||
for role in roles: # setup system message with roles
|
||||
role.set_env(self)
|
||||
role.context = self.context
|
||||
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
"""
|
||||
|
|
@ -108,8 +72,8 @@ class Environment(BaseModel):
|
|||
logger.debug(f"publish_message: {message.dump()}")
|
||||
found = False
|
||||
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
|
||||
for role, subscription in self.members.items():
|
||||
if is_subscribed(message, subscription):
|
||||
for role, addrs in self.member_addrs.items():
|
||||
if is_send_to(message, addrs):
|
||||
role.put_message(message)
|
||||
found = True
|
||||
if not found:
|
||||
|
|
@ -154,15 +118,14 @@ class Environment(BaseModel):
|
|||
return False
|
||||
return True
|
||||
|
||||
def get_subscription(self, obj):
|
||||
"""Get the labels for messages to be consumed by the object."""
|
||||
return self.members.get(obj, {})
|
||||
def get_addresses(self, obj):
|
||||
"""Get the addresses of the object."""
|
||||
return self.member_addrs.get(obj, {})
|
||||
|
||||
def set_subscription(self, obj, tags):
|
||||
"""Set the labels for message to be consumed by the object"""
|
||||
self.members[obj] = tags
|
||||
def set_addresses(self, obj, addresses):
|
||||
"""Set the addresses of the object"""
|
||||
self.member_addrs[obj] = addresses
|
||||
|
||||
@staticmethod
|
||||
def archive(auto_archive=True):
|
||||
if auto_archive and CONFIG.git_repo:
|
||||
CONFIG.git_repo.archive()
|
||||
def archive(self, auto_archive=True):
|
||||
if auto_archive and self.context.git_repo:
|
||||
self.context.git_repo.archive()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import aiofiles
|
|||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.context import Context
|
||||
|
||||
|
||||
class Example(BaseModel):
|
||||
|
|
@ -73,14 +73,15 @@ class SkillsDeclaration(BaseModel):
|
|||
skill_data = yaml.safe_load(data)
|
||||
return SkillsDeclaration(**skill_data)
|
||||
|
||||
def get_skill_list(self, entity_name: str = "Assistant") -> Dict:
|
||||
def get_skill_list(self, entity_name: str = "Assistant", context: Context = None) -> Dict:
|
||||
"""Return the skill name based on the skill description."""
|
||||
entity = self.entities.get(entity_name)
|
||||
if not entity:
|
||||
return {}
|
||||
|
||||
# List of skills that the agent chooses to activate.
|
||||
agent_skills = CONFIG.agent_skills
|
||||
ctx = context or Context()
|
||||
agent_skills = ctx.kwargs.agent_skills
|
||||
if not agent_skills:
|
||||
return {}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,19 +6,19 @@
|
|||
@File : text_to_embedding.py
|
||||
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
|
||||
"""
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
import metagpt.config2
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
|
||||
|
||||
|
||||
async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
|
||||
async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config):
|
||||
"""Text to embedding
|
||||
|
||||
:param text: The text used for embedding.
|
||||
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
:param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
|
||||
"""
|
||||
if CONFIG.OPENAI_API_KEY or openai_api_key:
|
||||
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
|
||||
raise EnvironmentError
|
||||
openai_api_key = config.get_openai_llm().api_key
|
||||
proxy = config.get_openai_llm().proxy
|
||||
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy)
|
||||
|
|
|
|||
|
|
@ -8,33 +8,37 @@
|
|||
"""
|
||||
import base64
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
import metagpt.config2
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
|
||||
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
|
||||
async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
:param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
|
||||
:param model_url: MetaGPT model url
|
||||
:param config: Config
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
image_declaration = "data:image/png;base64,"
|
||||
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
|
||||
|
||||
model_url = config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
if model_url:
|
||||
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
elif CONFIG.OPENAI_API_KEY or openai_api_key:
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type)
|
||||
elif config.get_openai_llm():
|
||||
llm = LLM(llm_config=config.get_openai_llm())
|
||||
binary_data = await oas3_openai_text_to_image(text, size_type, llm=llm)
|
||||
else:
|
||||
raise ValueError("Missing necessary parameters.")
|
||||
base64_data = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
s3 = S3(config.s3)
|
||||
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT)
|
||||
if url:
|
||||
return f""
|
||||
return image_declaration + base64_data if base64_data else ""
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@
|
|||
@File : text_to_speech.py
|
||||
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
|
||||
"""
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
import metagpt.config2
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.tools.azure_tts import oas3_azsure_tts
|
||||
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
|
||||
|
|
@ -20,12 +20,7 @@ async def text_to_speech(
|
|||
voice="zh-CN-XiaomoNeural",
|
||||
style="affectionate",
|
||||
role="Girl",
|
||||
subscription_key="",
|
||||
region="",
|
||||
iflytek_app_id="",
|
||||
iflytek_api_key="",
|
||||
iflytek_api_secret="",
|
||||
**kwargs,
|
||||
config: Config = metagpt.config2.config,
|
||||
):
|
||||
"""Text to speech
|
||||
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
|
|
@ -44,23 +39,27 @@ async def text_to_speech(
|
|||
|
||||
"""
|
||||
|
||||
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
|
||||
subscription_key = config.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
region = config.AZURE_TTS_REGION
|
||||
if subscription_key and region:
|
||||
audio_declaration = "data:audio/wav;base64,"
|
||||
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
s3 = S3(config.s3)
|
||||
url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT)
|
||||
if url:
|
||||
return f"[{text}]({url})"
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
if (CONFIG.IFLYTEK_APP_ID and CONFIG.IFLYTEK_API_KEY and CONFIG.IFLYTEK_API_SECRET) or (
|
||||
iflytek_app_id and iflytek_api_key and iflytek_api_secret
|
||||
):
|
||||
|
||||
iflytek_app_id = config.IFLYTEK_APP_ID
|
||||
iflytek_api_key = config.IFLYTEK_API_KEY
|
||||
iflytek_api_secret = config.IFLYTEK_API_SECRET
|
||||
if iflytek_app_id and iflytek_api_key and iflytek_api_secret:
|
||||
audio_declaration = "data:audio/mp3;base64,"
|
||||
base64_data = await oas3_iflytek_tts(
|
||||
text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
|
||||
)
|
||||
s3 = S3()
|
||||
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else ""
|
||||
s3 = S3(config.s3)
|
||||
url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT)
|
||||
if url:
|
||||
return f"[{text}]({url})"
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
|
|
|||
|
|
@ -5,20 +5,16 @@
|
|||
@Author : alexanderwu
|
||||
@File : llm.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.context import Context
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
if provider is None:
|
||||
provider = CONFIG.get_default_llm_provider_enum()
|
||||
|
||||
return LLM_REGISTRY.get_provider(provider)
|
||||
def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> BaseLLM:
|
||||
"""get the default llm provider if name is None"""
|
||||
ctx = context or Context()
|
||||
if llm_config is not None:
|
||||
ctx.llm_with_cost_manager_from_llm_config(llm_config)
|
||||
return ctx.llm()
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from typing import Dict, List, Optional
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import MetaGPTLLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
|
@ -29,9 +29,9 @@ class BrainMemory(BaseModel):
|
|||
historical_summary: str = ""
|
||||
last_history_id: str = ""
|
||||
is_dirty: bool = False
|
||||
last_talk: str = None
|
||||
last_talk: Optional[str] = None
|
||||
cacheable: bool = True
|
||||
llm: Optional[BaseLLM] = None
|
||||
llm: Optional[BaseLLM] = Field(default=None, exclude=True)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
@ -56,8 +56,8 @@ class BrainMemory(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
async def loads(redis_key: str) -> "BrainMemory":
|
||||
redis = Redis()
|
||||
if not redis.is_valid or not redis_key:
|
||||
redis = Redis(config.redis)
|
||||
if not redis_key:
|
||||
return BrainMemory()
|
||||
v = await redis.get(key=redis_key)
|
||||
logger.debug(f"REDIS GET {redis_key} {v}")
|
||||
|
|
@ -70,8 +70,8 @@ class BrainMemory(BaseModel):
|
|||
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
|
||||
if not self.is_dirty:
|
||||
return
|
||||
redis = Redis()
|
||||
if not redis.is_valid or not redis_key:
|
||||
redis = Redis(config.redis)
|
||||
if not redis_key:
|
||||
return False
|
||||
v = self.model_dump_json()
|
||||
if self.cacheable:
|
||||
|
|
@ -83,7 +83,7 @@ class BrainMemory(BaseModel):
|
|||
def to_redis_key(prefix: str, user_id: str, chat_id: str):
|
||||
return f"{prefix}:{user_id}:{chat_id}"
|
||||
|
||||
async def set_history_summary(self, history_summary, redis_key, redis_conf):
|
||||
async def set_history_summary(self, history_summary, redis_key):
|
||||
if self.historical_summary == history_summary:
|
||||
if self.is_dirty:
|
||||
await self.dumps(redis_key=redis_key)
|
||||
|
|
@ -140,7 +140,7 @@ class BrainMemory(BaseModel):
|
|||
return text
|
||||
summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
|
||||
if summary:
|
||||
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
|
||||
await self.set_history_summary(history_summary=summary, redis_key=config.redis_key)
|
||||
return summary
|
||||
raise ValueError(f"text too long:{text_length}")
|
||||
|
||||
|
|
@ -164,7 +164,7 @@ class BrainMemory(BaseModel):
|
|||
msgs.reverse()
|
||||
self.history = msgs
|
||||
self.is_dirty = True
|
||||
await self.dumps(redis_key=CONFIG.REDIS_KEY)
|
||||
await self.dumps(redis_key=config.redis.key)
|
||||
self.is_dirty = False
|
||||
|
||||
return BrainMemory.to_metagpt_history_format(self.history)
|
||||
|
|
@ -181,7 +181,7 @@ class BrainMemory(BaseModel):
|
|||
|
||||
summary = await self.summarize(llm=llm, max_words=500)
|
||||
|
||||
language = CONFIG.language or DEFAULT_LANGUAGE
|
||||
language = config.language
|
||||
command = f"Translate the above summary into a {language} title of less than {max_words} words."
|
||||
summaries = [summary, command]
|
||||
msg = "\n".join(summaries)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Desc : the implement of Long-term memory
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
|
|
|||
|
|
@ -7,19 +7,13 @@
|
|||
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import DefaultDict, Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field, SerializeAsAny
|
||||
|
||||
from metagpt.const import IGNORED_MESSAGE_ID
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import (
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
read_json_file,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
|
|
@ -29,22 +23,6 @@ class Memory(BaseModel):
|
|||
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
|
||||
ignore_id: bool = False
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
storage = self.model_dump()
|
||||
write_json_file(memory_path, storage)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Memory":
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
|
||||
memory_dict = read_json_file(memory_path)
|
||||
memory = Memory(**memory_dict)
|
||||
|
||||
return memory
|
||||
|
||||
def add(self, message: Message):
|
||||
"""Add a new message to storage, while updating the index"""
|
||||
if self.ignore_id:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Desc : the implement of memory storage
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from metagpt.provider.openai_api import OpenAILLM
|
|||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
|
|
@ -24,4 +26,6 @@ __all__ = [
|
|||
"AzureOpenAILLM",
|
||||
"MetaGPTLLM",
|
||||
"OllamaLLM",
|
||||
"HumanProvider",
|
||||
"SparkLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,12 +9,15 @@
|
|||
import anthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
|
||||
|
||||
class Claude2:
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
|
||||
def ask(self, prompt: str) -> str:
|
||||
client = Anthropic(api_key=CONFIG.anthropic_api_key)
|
||||
client = Anthropic(api_key=self.config.api_key)
|
||||
|
||||
res = client.completions.create(
|
||||
model="claude-2",
|
||||
|
|
@ -24,7 +27,7 @@ class Claude2:
|
|||
return res.completion
|
||||
|
||||
async def aask(self, prompt: str) -> str:
|
||||
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
|
||||
aclient = AsyncAnthropic(api_key=self.config.api_key)
|
||||
|
||||
res = await aclient.completions.create(
|
||||
model="claude-2",
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@
|
|||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : openai.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
|
@ -13,12 +11,12 @@
|
|||
from openai import AsyncAzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.AZURE_OPENAI)
|
||||
@register_provider(LLMType.AZURE)
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
|
|
@ -28,13 +26,13 @@ class AzureOpenAILLM(OpenAILLM):
|
|||
kwargs = self._make_client_kwargs()
|
||||
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
|
||||
self.aclient = AsyncAzureOpenAI(**kwargs)
|
||||
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(
|
||||
api_key=self.config.OPENAI_API_KEY,
|
||||
api_version=self.config.OPENAI_API_VERSION,
|
||||
azure_endpoint=self.config.OPENAI_BASE_URL,
|
||||
api_key=self.config.api_key,
|
||||
api_version=self.config.api_version,
|
||||
azure_endpoint=self.config.base_url,
|
||||
)
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
|
|
|
|||
|
|
@ -8,15 +8,32 @@
|
|||
"""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
config: LLMConfig
|
||||
use_system_prompt: bool = True
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
# OpenAI / Azure / Others
|
||||
aclient: Optional[Union[AsyncOpenAI]] = None
|
||||
cost_manager: Optional[CostManager] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "user", "content": msg}
|
||||
|
||||
|
|
@ -43,10 +60,13 @@ class BaseLLM(ABC):
|
|||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs)
|
||||
else:
|
||||
message = [self._default_system_msg()] if self.use_system_prompt else []
|
||||
message = [self._default_system_msg()]
|
||||
if not self.use_system_prompt:
|
||||
message = []
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
||||
|
|
@ -63,10 +83,9 @@ class BaseLLM(ABC):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
|
|
@ -87,6 +106,10 @@ class BaseLLM(ABC):
|
|||
"""Required to provide the first text of choice"""
|
||||
return rsp.get("choices")[0]["message"]["content"]
|
||||
|
||||
def get_choice_delta_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of stream choice"""
|
||||
return rsp.get("choices")[0]["delta"]["content"]
|
||||
|
||||
def get_choice_function(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function of choice
|
||||
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
|
|
@ -64,44 +64,35 @@ class FireworksCostManager(CostManager):
|
|||
token_costs = self.model_grade_token_costs(model)
|
||||
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
|
||||
self.total_cost += cost
|
||||
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${max_budget:.3f} | "
|
||||
f"Total running cost: ${self.total_cost:.4f}"
|
||||
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
@register_provider(LLMType.FIREWORKS)
|
||||
class FireworksLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_fireworks()
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config=config)
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = FireworksCostManager()
|
||||
|
||||
def __init_fireworks(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._init_client()
|
||||
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
self.cost_manager = FireworksCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
|
||||
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
# use FireworksCostManager not context.cost_manager
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -79,10 +79,8 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
|
||||
if stream and (
|
||||
"text/event-stream" in result.headers.get("Content-Type", "")
|
||||
or "application/x-ndjson" in result.headers.get("Content-Type", "")
|
||||
):
|
||||
content_type = result.headers.get("Content-Type", "")
|
||||
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
|
||||
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
|
||||
return (
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -41,21 +41,22 @@ class GeminiGenerativeModel(GenerativeModel):
|
|||
return await self._async_client.count_tokens(model=self.model_name, contents=contents)
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.GEMINI)
|
||||
@register_provider(LLMType.GEMINI)
|
||||
class GeminiLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.use_system_prompt = False # google gemini has no system prompt when use api
|
||||
|
||||
self.__init_gemini(CONFIG)
|
||||
self.__init_gemini(config)
|
||||
self.config = config
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
def __init_gemini(self, config: CONFIG):
|
||||
genai.configure(api_key=config.gemini_api_key)
|
||||
def __init_gemini(self, config: LLMConfig):
|
||||
genai.configure(api_key=config.api_key)
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
# Not to change BaseLLM default functions but update with Gemini's conversation format.
|
||||
|
|
@ -71,11 +72,11 @@ class GeminiLLM(BaseLLM):
|
|||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
|
|
@ -108,7 +109,7 @@ class GeminiLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Author: garylin2099
|
|||
"""
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
|
@ -14,6 +15,9 @@ class HumanProvider(BaseLLM):
|
|||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
rsp = input(msg)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
@Author : alexanderwu
|
||||
@File : llm_provider_registry.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class LLMProviderRegistry:
|
||||
|
|
@ -15,13 +16,9 @@ class LLMProviderRegistry:
|
|||
def register(self, key, provider_cls):
|
||||
self.providers[key] = provider_cls
|
||||
|
||||
def get_provider(self, enum: LLMProviderEnum):
|
||||
def get_provider(self, enum: LLMType):
|
||||
"""get provider instance according to the enum"""
|
||||
return self.providers[enum]()
|
||||
|
||||
|
||||
# Registry instance
|
||||
LLM_REGISTRY = LLMProviderRegistry()
|
||||
return self.providers[enum]
|
||||
|
||||
|
||||
def register_provider(key):
|
||||
|
|
@ -32,3 +29,12 @@ def register_provider(key):
|
|||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def create_llm_instance(config: LLMConfig) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
return LLM_REGISTRY.get_provider(config.api_type)(config)
|
||||
|
||||
|
||||
# Registry instance
|
||||
LLM_REGISTRY = LLMProviderRegistry()
|
||||
|
|
|
|||
|
|
@ -5,12 +5,11 @@
|
|||
@File : metagpt_api.py
|
||||
@Desc : MetaGPT LLM provider.
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.METAGPT)
|
||||
@register_provider(LLMType.METAGPT)
|
||||
class MetaGPTLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,48 +13,34 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
|
||||
|
||||
class OllamaCostManager(CostManager):
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
|
||||
logger.info(
|
||||
f"Max budget: ${max_budget:.3f} | "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OLLAMA)
|
||||
@register_provider(LLMType.OLLAMA)
|
||||
class OllamaLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_ollama(CONFIG)
|
||||
self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.__init_ollama(config)
|
||||
self.client = GeneralAPIRequestor(base_url=config.base_url)
|
||||
self.config = config
|
||||
self.suffix_url = "/chat"
|
||||
self.http_method = "post"
|
||||
self.use_system_prompt = False
|
||||
self._cost_manager = OllamaCostManager()
|
||||
self._cost_manager = TokenCostManager()
|
||||
|
||||
def __init_ollama(self, config: CONFIG):
|
||||
assert config.ollama_api_base
|
||||
self.model = config.ollama_api_model
|
||||
def __init_ollama(self, config: LLMConfig):
|
||||
assert config.base_url, "ollama base url is required!"
|
||||
self.model = config.model
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
|
|
@ -62,7 +48,7 @@ class OllamaLLM(BaseLLM):
|
|||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
|
|
|
|||
|
|
@ -4,56 +4,27 @@
|
|||
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.cost_manager import Costs, TokenCostManager
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
class OpenLLMCostManager(CostManager):
|
||||
"""open llm model is self-host, it's free and without cost"""
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
|
||||
logger.info(
|
||||
f"Max budget: ${max_budget:.3f} | reference "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
@register_provider(LLMType.OPEN_LLM)
|
||||
class OpenLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = OpenLLMCostManager()
|
||||
|
||||
def __init_openllm(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._init_client()
|
||||
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config)
|
||||
self._cost_manager = TokenCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if not CONFIG.calc_usage:
|
||||
if not self.config.calc_usage:
|
||||
return usage
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -3,15 +3,13 @@
|
|||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : openai.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import AsyncIterator, Union
|
||||
from typing import AsyncIterator, Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
|
@ -25,14 +23,14 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -52,18 +50,19 @@ See FAQ 5.8
|
|||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPENAI)
|
||||
@register_provider(LLMType.OPENAI)
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self._init_openai()
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def _init_openai(self):
|
||||
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
|
|
@ -71,7 +70,7 @@ class OpenAILLM(BaseLLM):
|
|||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
|
||||
kwargs = {"api_key": self.config.api_key, "base_url": self.config.base_url}
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
if proxy_params := self._get_proxy_params():
|
||||
|
|
@ -81,10 +80,10 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def _get_proxy_params(self) -> dict:
|
||||
params = {}
|
||||
if self.config.openai_proxy:
|
||||
params = {"proxies": self.config.openai_proxy}
|
||||
if self.config.openai_base_url:
|
||||
params["base_url"] = self.config.openai_base_url
|
||||
if self.config.proxy:
|
||||
params = {"proxies": self.config.proxy}
|
||||
if self.config.base_url:
|
||||
params["base_url"] = self.config.base_url
|
||||
|
||||
return params
|
||||
|
||||
|
|
@ -105,7 +104,7 @@ class OpenAILLM(BaseLLM):
|
|||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
"timeout": max(CONFIG.timeout, timeout),
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
}
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
|
|
@ -266,7 +265,7 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if not CONFIG.calc_usage:
|
||||
if not self.config.calc_usage:
|
||||
return usage
|
||||
|
||||
try:
|
||||
|
|
@ -279,18 +278,28 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage and usage:
|
||||
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
return self.config.max_token
|
||||
return get_max_completion_tokens(messages, self.model, self.config.max_tokens)
|
||||
|
||||
@handle_exception
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
"""Moderate content."""
|
||||
return await self.aclient.moderations.create(input=content)
|
||||
|
||||
async def atext_to_speech(self, **kwargs):
|
||||
"""text to speech"""
|
||||
return await self.aclient.audio.speech.create(**kwargs)
|
||||
|
||||
async def aspeech_to_text(self, **kwargs):
|
||||
"""speech to text"""
|
||||
return await self.aclient.audio.transcriptions.create(**kwargs)
|
||||
|
|
|
|||
|
|
@ -16,29 +16,30 @@ from wsgiref.handlers import format_date_time
|
|||
|
||||
import websocket # 使用websocket_client
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
@register_provider(LLMType.SPARK)
|
||||
class SparkLLM(BaseLLM):
|
||||
def __init__(self):
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
logger.warning("SparkLLM:当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
# 不支持
|
||||
logger.error("该功能禁用。")
|
||||
w = GetMessageFromWeb(messages)
|
||||
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages)
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
|
||||
|
|
@ -89,14 +90,14 @@ class GetMessageFromWeb:
|
|||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def __init__(self, text):
|
||||
def __init__(self, text, config: LLMConfig):
|
||||
self.text = text
|
||||
self.ret = ""
|
||||
self.spark_appid = CONFIG.spark_appid
|
||||
self.spark_api_secret = CONFIG.spark_api_secret
|
||||
self.spark_api_key = CONFIG.spark_api_key
|
||||
self.domain = CONFIG.domain
|
||||
self.spark_url = CONFIG.spark_url
|
||||
self.spark_appid = config.app_id
|
||||
self.spark_api_secret = config.api_secret
|
||||
self.spark_api_key = config.api_key
|
||||
self.domain = config.domain
|
||||
self.spark_url = config.base_url
|
||||
|
||||
def on_message(self, ws, message):
|
||||
data = json.loads(message)
|
||||
|
|
|
|||
|
|
@ -1,75 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : async_sse_client to make keep the use of Event to access response
|
||||
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
|
||||
# refs to `zhipuai/core/_sse_client.py`
|
||||
|
||||
from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
|
||||
import json
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
class AsyncSSEClient(SSEClient):
|
||||
async def _aread(self):
|
||||
data = b""
|
||||
class AsyncSSEClient(object):
|
||||
def __init__(self, event_source: Iterator[Any]):
|
||||
self._event_source = event_source
|
||||
|
||||
async def stream(self) -> dict:
|
||||
if isinstance(self._event_source, bytes):
|
||||
raise RuntimeError(
|
||||
f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
|
||||
)
|
||||
async for chunk in self._event_source:
|
||||
for line in chunk.splitlines(True):
|
||||
data += line
|
||||
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
||||
yield data
|
||||
data = b""
|
||||
if data:
|
||||
yield data
|
||||
line = chunk.decode("utf-8")
|
||||
if line.startswith(":") or not line:
|
||||
return
|
||||
|
||||
async def async_events(self):
|
||||
async for chunk in self._aread():
|
||||
event = Event()
|
||||
# Split before decoding so splitlines() only uses \r and \n
|
||||
for line in chunk.splitlines():
|
||||
# Decode the line.
|
||||
line = line.decode(self._char_enc)
|
||||
|
||||
# Lines starting with a separator are comments and are to be
|
||||
# ignored.
|
||||
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
|
||||
continue
|
||||
|
||||
data = line.split(_FIELD_SEPARATOR, 1)
|
||||
field = data[0]
|
||||
|
||||
# Ignore unknown fields.
|
||||
if field not in event.__dict__:
|
||||
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
|
||||
continue
|
||||
|
||||
if len(data) > 1:
|
||||
# From the spec:
|
||||
# "If value starts with a single U+0020 SPACE character,
|
||||
# remove it from value."
|
||||
if data[1].startswith(" "):
|
||||
value = data[1][1:]
|
||||
else:
|
||||
value = data[1]
|
||||
else:
|
||||
# If no value is present after the separator,
|
||||
# assume an empty value.
|
||||
value = ""
|
||||
|
||||
# The data field may come over multiple lines and their values
|
||||
# are concatenated with each other.
|
||||
if field == "data":
|
||||
event.__dict__[field] += value + "\n"
|
||||
else:
|
||||
event.__dict__[field] = value
|
||||
|
||||
# Events with no data are not dispatched.
|
||||
if not event.data:
|
||||
continue
|
||||
|
||||
# If the data field ends with a newline, remove it.
|
||||
if event.data.endswith("\n"):
|
||||
event.data = event.data[0:-1]
|
||||
|
||||
# Empty event names default to 'message'
|
||||
event.event = event.event or "message"
|
||||
|
||||
# Dispatch the event
|
||||
self._logger.debug("Dispatching %s...", event)
|
||||
yield event
|
||||
field, _p, value = line.partition(":")
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if field == "data":
|
||||
if value.startswith("[DONE]"):
|
||||
break
|
||||
data = json.loads(value)
|
||||
yield data
|
||||
|
|
|
|||
|
|
@ -4,46 +4,27 @@
|
|||
|
||||
import json
|
||||
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType, ModelAPI
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
|
||||
|
||||
class ZhiPuModelAPI(ModelAPI):
|
||||
@classmethod
|
||||
def get_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
zhipuai_default_headers.update({"Authorization": token})
|
||||
return zhipuai_default_headers
|
||||
|
||||
@classmethod
|
||||
def get_sse_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
headers = {"Authorization": token}
|
||||
return headers
|
||||
|
||||
@classmethod
|
||||
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
|
||||
class ZhiPuModelAPI(ZhipuAI):
|
||||
def split_zhipu_api_url(self):
|
||||
# use this method to prevent zhipu api upgrading to different version.
|
||||
# and follow the GeneralAPIRequestor implemented based on openai sdk
|
||||
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
|
||||
"""
|
||||
example:
|
||||
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
|
||||
"""
|
||||
zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
|
||||
arr = zhipu_api_url.split("/api/")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions")
|
||||
return f"{arr[0]}/api", f"/{arr[1]}"
|
||||
|
||||
@classmethod
|
||||
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
|
||||
async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
|
||||
# TODO to make the async request to be more generic for models in http mode.
|
||||
assert method in ["post", "get"]
|
||||
|
||||
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
|
||||
base_url, url = self.split_zhipu_api_url()
|
||||
requester = GeneralAPIRequestor(base_url=base_url)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
|
|
@ -51,25 +32,23 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
headers=headers,
|
||||
stream=stream,
|
||||
params=kwargs,
|
||||
request_timeout=zhipuai.api_timeout_seconds,
|
||||
request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def ainvoke(cls, **kwargs) -> dict:
|
||||
async def acreate(self, **kwargs) -> dict:
|
||||
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
|
||||
headers = cls.get_header()
|
||||
resp = await cls.arequest(
|
||||
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
|
||||
)
|
||||
headers = self._default_headers
|
||||
resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
|
||||
resp = resp.decode("utf-8")
|
||||
resp = json.loads(resp)
|
||||
if "error" in resp:
|
||||
raise RuntimeError(
|
||||
f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
|
||||
)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
|
||||
async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
|
||||
"""async sse_invoke"""
|
||||
headers = cls.get_sse_header()
|
||||
return AsyncSSEClient(
|
||||
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
|
||||
)
|
||||
headers = self._default_headers
|
||||
return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
|
|
@ -16,7 +15,7 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -31,57 +30,52 @@ class ZhiPuEvent(Enum):
|
|||
FINISH = "finish"
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.ZHIPUAI)
|
||||
@register_provider(LLMType.ZHIPUAI)
|
||||
class ZhiPuAILLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
From now, support glm-3-turbo、glm-4, and also system_prompt.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_zhipuai(CONFIG)
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.__init_zhipuai(config)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
|
||||
self.config = config
|
||||
|
||||
def __init_zhipuai(self, config: CONFIG):
|
||||
assert config.zhipuai_api_key
|
||||
zhipuai.api_key = config.zhipuai_api_key
|
||||
def __init_zhipuai(self, config: LLMConfig):
|
||||
assert config.api_key
|
||||
zhipuai.api_key = config.api_key
|
||||
# due to use openai sdk, set the api_key but it will't be used.
|
||||
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
|
||||
if config.openai_proxy:
|
||||
if config.proxy:
|
||||
# FIXME: openai v1.x sdk has no proxy support
|
||||
openai.proxy = config.openai_proxy
|
||||
openai.proxy = config.proxy
|
||||
|
||||
def _const_kwargs(self, messages: list[dict]) -> dict:
|
||||
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
assert assist_msg["role"] == "assistant"
|
||||
return assist_msg.get("content")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = self.llm.invoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
return resp.model_dump()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
resp = await self.llm.acreate(**self._const_kwargs(messages))
|
||||
usage = resp.get("usage", {})
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
|
|
@ -89,35 +83,18 @@ class ZhiPuAILLM(BaseLLM):
|
|||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
|
||||
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for event in response.async_events():
|
||||
if event.event == ZhiPuEvent.ADD.value:
|
||||
content = event.data
|
||||
async for chunk in response.stream():
|
||||
finish_reason = chunk.get("choices")[0].get("finish_reason")
|
||||
if finish_reason == "stop":
|
||||
usage = chunk.get("usage", {})
|
||||
else:
|
||||
content = self.get_choice_delta_text(chunk)
|
||||
collected_content.append(content)
|
||||
log_llm_stream(content)
|
||||
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
|
||||
content = event.data
|
||||
logger.error(f"event error: {content}", end="")
|
||||
elif event.event == ZhiPuEvent.FINISH.value:
|
||||
"""
|
||||
event.meta
|
||||
{
|
||||
"task_status":"SUCCESS",
|
||||
"usage":{
|
||||
"completion_tokens":351,
|
||||
"prompt_tokens":595,
|
||||
"total_tokens":946
|
||||
},
|
||||
"task_id":"xx",
|
||||
"request_id":"xxx"
|
||||
}
|
||||
"""
|
||||
meta = json.loads(event.meta)
|
||||
usage = meta.get("usage")
|
||||
else:
|
||||
print(f"zhipuapi else event: {event.data}", end="")
|
||||
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class Architect(Role):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
# Initialize actions specific to the Architect role
|
||||
self._init_actions([WriteDesign])
|
||||
self.set_actions([WriteDesign])
|
||||
|
||||
# Set events or actions the Architect should watch or be aware of
|
||||
self._watch({WritePRD})
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.skill_loader import SkillsDeclaration
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
|
|
@ -48,7 +47,8 @@ class Assistant(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.constraints = self.constraints.format(language=kwargs.get("language") or CONFIG.language or "Chinese")
|
||||
language = kwargs.get("language") or self.context.kwargs.language
|
||||
self.constraints = self.constraints.format(language=language)
|
||||
|
||||
async def think(self) -> bool:
|
||||
"""Everything will be done part by part."""
|
||||
|
|
@ -56,16 +56,16 @@ class Assistant(Role):
|
|||
if not last_talk:
|
||||
return False
|
||||
if not self.skills:
|
||||
skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None
|
||||
skill_path = Path(self.context.kwargs.SKILL_PATH) if self.context.kwargs.SKILL_PATH else None
|
||||
self.skills = await SkillsDeclaration.load(skill_yaml_file_name=skill_path)
|
||||
|
||||
prompt = ""
|
||||
skills = self.skills.get_skill_list()
|
||||
skills = self.skills.get_skill_list(context=self.context)
|
||||
for desc, name in skills.items():
|
||||
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
|
||||
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
|
||||
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
|
||||
rsp = await self.llm.aask(prompt, [])
|
||||
rsp = await self.llm.aask(prompt, ["You are an action classifier"])
|
||||
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
|
||||
return await self._plan(rsp, last_talk=last_talk)
|
||||
|
||||
|
|
@ -97,8 +97,8 @@ class Assistant(Role):
|
|||
async def talk_handler(self, text, **kwargs) -> bool:
|
||||
history = self.memory.history_text
|
||||
text = kwargs.get("last_talk") or text
|
||||
self.rc.todo = TalkAction(
|
||||
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
|
||||
self.set_todo(
|
||||
TalkAction(i_context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm)
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
@ -108,11 +108,11 @@ class Assistant(Role):
|
|||
if not skill:
|
||||
logger.info(f"skill not found: {text}")
|
||||
return await self.talk_handler(text=last_talk, **kwargs)
|
||||
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs)
|
||||
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk)
|
||||
await action.run(**kwargs)
|
||||
if action.args is None:
|
||||
return await self.talk_handler(text=last_talk, **kwargs)
|
||||
self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)
|
||||
self.set_todo(SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description))
|
||||
return True
|
||||
|
||||
async def refine_memory(self) -> str:
|
||||
|
|
|
|||
|
|
@ -20,23 +20,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
|
||||
from metagpt.const import (
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
CODE_SUMMARIES_PDF_FILE_REPO,
|
||||
CODE_PLAN_AND_CHANGE_FILE_REPO,
|
||||
CODE_PLAN_AND_CHANGE_FILENAME,
|
||||
REQUIREMENT_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import (
|
||||
CodePlanAndChangeContext,
|
||||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
Document,
|
||||
|
|
@ -80,12 +84,13 @@ class Engineer(Role):
|
|||
code_todos: list = []
|
||||
summarize_todos: list = []
|
||||
next_todo_action: str = ""
|
||||
n_summarize: int = 0
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._init_actions([WriteCode])
|
||||
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug])
|
||||
self.set_actions([WriteCode])
|
||||
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug, WriteCodePlanAndChange])
|
||||
self.code_todos = []
|
||||
self.summarize_todos = []
|
||||
self.next_todo_action = any_to_name(WriteCode)
|
||||
|
|
@ -93,11 +98,10 @@ class Engineer(Role):
|
|||
@staticmethod
|
||||
def _parse_tasks(task_msg: Document) -> list[str]:
|
||||
m = json.loads(task_msg.content)
|
||||
return m.get("Task list")
|
||||
return m.get(TASK_LIST.key) or m.get(REFINED_TASK_LIST.key)
|
||||
|
||||
async def _act_sp_with_cr(self, review=False) -> Set[str]:
|
||||
changed_files = set()
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
for todo in self.code_todos:
|
||||
"""
|
||||
# Select essential information from the historical data to reduce the length of the prompt (summarized from human experience):
|
||||
|
|
@ -109,12 +113,16 @@ class Engineer(Role):
|
|||
coding_context = await todo.run()
|
||||
# Code review
|
||||
if review:
|
||||
action = WriteCodeReview(context=coding_context, llm=self.llm)
|
||||
self._init_action_system_message(action)
|
||||
action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm)
|
||||
self._init_action(action)
|
||||
coding_context = await action.run()
|
||||
await src_file_repo.save(
|
||||
coding_context.filename,
|
||||
dependencies={coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path},
|
||||
|
||||
dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
|
||||
if self.config.inc:
|
||||
dependencies.add(os.path.join(CODE_PLAN_AND_CHANGE_FILE_REPO, CODE_PLAN_AND_CHANGE_FILENAME))
|
||||
await self.project_repo.srcs.save(
|
||||
filename=coding_context.filename,
|
||||
dependencies=dependencies,
|
||||
content=coding_context.code_doc.content,
|
||||
)
|
||||
msg = Message(
|
||||
|
|
@ -134,6 +142,9 @@ class Engineer(Role):
|
|||
"""Determines the mode of action based on whether code review is used."""
|
||||
if self.rc.todo is None:
|
||||
return None
|
||||
if isinstance(self.rc.todo, WriteCodePlanAndChange):
|
||||
self.next_todo_action = any_to_name(WriteCode)
|
||||
return await self._act_code_plan_and_change()
|
||||
if isinstance(self.rc.todo, WriteCode):
|
||||
self.next_todo_action = any_to_name(SummarizeCode)
|
||||
return await self._act_write_code()
|
||||
|
|
@ -153,34 +164,32 @@ class Engineer(Role):
|
|||
)
|
||||
|
||||
async def _act_summarize(self):
|
||||
code_summaries_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
|
||||
code_summaries_pdf_file_repo = CONFIG.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
|
||||
tasks = []
|
||||
src_relative_path = CONFIG.src_workspace.relative_to(CONFIG.git_repo.workdir)
|
||||
for todo in self.summarize_todos:
|
||||
summary = await todo.run()
|
||||
summary_filename = Path(todo.context.design_filename).with_suffix(".md").name
|
||||
dependencies = {todo.context.design_filename, todo.context.task_filename}
|
||||
for filename in todo.context.codes_filenames:
|
||||
rpath = src_relative_path / filename
|
||||
summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name
|
||||
dependencies = {todo.i_context.design_filename, todo.i_context.task_filename}
|
||||
for filename in todo.i_context.codes_filenames:
|
||||
rpath = self.project_repo.src_relative_path / filename
|
||||
dependencies.add(str(rpath))
|
||||
await code_summaries_pdf_file_repo.save(
|
||||
await self.project_repo.resources.code_summary.save(
|
||||
filename=summary_filename, content=summary, dependencies=dependencies
|
||||
)
|
||||
is_pass, reason = await self._is_pass(summary)
|
||||
if not is_pass:
|
||||
todo.context.reason = reason
|
||||
tasks.append(todo.context.dict())
|
||||
await code_summaries_file_repo.save(
|
||||
filename=Path(todo.context.design_filename).name,
|
||||
content=todo.context.model_dump_json(),
|
||||
todo.i_context.reason = reason
|
||||
tasks.append(todo.i_context.model_dump())
|
||||
|
||||
await self.project_repo.docs.code_summary.save(
|
||||
filename=Path(todo.i_context.design_filename).name,
|
||||
content=todo.i_context.model_dump_json(),
|
||||
dependencies=dependencies,
|
||||
)
|
||||
else:
|
||||
await code_summaries_file_repo.delete(filename=Path(todo.context.design_filename).name)
|
||||
await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name)
|
||||
|
||||
logger.info(f"--max-auto-summarize-code={CONFIG.max_auto_summarize_code}")
|
||||
if not tasks or CONFIG.max_auto_summarize_code == 0:
|
||||
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
|
||||
if not tasks or self.config.max_auto_summarize_code == 0:
|
||||
return Message(
|
||||
content="",
|
||||
role=self.profile,
|
||||
|
|
@ -190,11 +199,39 @@ class Engineer(Role):
|
|||
)
|
||||
# The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited.
|
||||
# This parameter is used for debugging the workflow.
|
||||
CONFIG.max_auto_summarize_code -= 1 if CONFIG.max_auto_summarize_code > 0 else 0
|
||||
self.n_summarize += 1 if self.config.max_auto_summarize_code > self.n_summarize else 0
|
||||
return Message(
|
||||
content=json.dumps(tasks), role=self.profile, cause_by=SummarizeCode, send_to=self, sent_from=self
|
||||
)
|
||||
|
||||
async def _act_code_plan_and_change(self):
|
||||
"""Write code plan and change that guides subsequent WriteCode and WriteCodeReview"""
|
||||
logger.info("Writing code plan and change..")
|
||||
node = await self.rc.todo.run()
|
||||
code_plan_and_change = node.instruct_content.model_dump_json()
|
||||
dependencies = {
|
||||
REQUIREMENT_FILENAME,
|
||||
self.rc.todo.i_context.prd_filename,
|
||||
self.rc.todo.i_context.design_filename,
|
||||
self.rc.todo.i_context.task_filename,
|
||||
}
|
||||
await self.project_repo.docs.code_plan_and_change.save(
|
||||
filename=self.rc.todo.i_context.filename, content=code_plan_and_change, dependencies=dependencies
|
||||
)
|
||||
await self.project_repo.resources.code_plan_and_change.save(
|
||||
filename=Path(self.rc.todo.i_context.filename).with_suffix(".md").name,
|
||||
content=node.content,
|
||||
dependencies=dependencies,
|
||||
)
|
||||
|
||||
return Message(
|
||||
content=code_plan_and_change,
|
||||
role=self.profile,
|
||||
cause_by=WriteCodePlanAndChange,
|
||||
send_to=self,
|
||||
sent_from=self,
|
||||
)
|
||||
|
||||
async def _is_pass(self, summary) -> (str, str):
|
||||
rsp = await self.llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
|
||||
logger.info(rsp)
|
||||
|
|
@ -203,13 +240,18 @@ class Engineer(Role):
|
|||
return False, rsp
|
||||
|
||||
async def _think(self) -> Action | None:
|
||||
if not CONFIG.src_workspace:
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
|
||||
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
|
||||
if not self.src_workspace:
|
||||
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
write_plan_and_change_filters = any_to_str_set([WriteTasks])
|
||||
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode, FixBug])
|
||||
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
|
||||
if not self.rc.news:
|
||||
return None
|
||||
msg = self.rc.news[0]
|
||||
if self.config.inc and msg.cause_by in write_plan_and_change_filters:
|
||||
logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}")
|
||||
await self._new_code_plan_and_change_action()
|
||||
return self.rc.todo
|
||||
if msg.cause_by in write_code_filters:
|
||||
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
|
||||
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
|
||||
|
|
@ -220,60 +262,54 @@ class Engineer(Role):
|
|||
return self.rc.todo
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _new_coding_context(
|
||||
filename, src_file_repo, task_file_repo, design_file_repo, dependency
|
||||
) -> CodingContext:
|
||||
old_code_doc = await src_file_repo.get(filename)
|
||||
async def _new_coding_context(self, filename, dependency) -> CodingContext:
|
||||
old_code_doc = await self.project_repo.srcs.get(filename)
|
||||
if not old_code_doc:
|
||||
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content="")
|
||||
old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="")
|
||||
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
|
||||
task_doc = None
|
||||
design_doc = None
|
||||
for i in dependencies:
|
||||
if str(i.parent) == TASK_FILE_REPO:
|
||||
task_doc = await task_file_repo.get(i.name)
|
||||
task_doc = await self.project_repo.docs.task.get(i.name)
|
||||
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
|
||||
design_doc = await design_file_repo.get(i.name)
|
||||
design_doc = await self.project_repo.docs.system_design.get(i.name)
|
||||
if not task_doc or not design_doc:
|
||||
logger.error(f'Detected source code "{filename}" from an unknown origin.')
|
||||
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
|
||||
context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
async def _new_coding_doc(filename, src_file_repo, task_file_repo, design_file_repo, dependency):
|
||||
context = await Engineer._new_coding_context(
|
||||
filename, src_file_repo, task_file_repo, design_file_repo, dependency
|
||||
)
|
||||
async def _new_coding_doc(self, filename, dependency):
|
||||
context = await self._new_coding_context(filename, dependency)
|
||||
coding_doc = Document(
|
||||
root_path=str(src_file_repo.root_path), filename=filename, content=context.model_dump_json()
|
||||
root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json()
|
||||
)
|
||||
return coding_doc
|
||||
|
||||
async def _new_code_actions(self, bug_fix=False):
|
||||
# Prepare file repos
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
changed_src_files = src_file_repo.all_files if bug_fix else src_file_repo.changed_files
|
||||
task_file_repo = CONFIG.git_repo.new_file_repository(TASK_FILE_REPO)
|
||||
changed_task_files = task_file_repo.changed_files
|
||||
design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
|
||||
changed_src_files = self.project_repo.srcs.all_files if bug_fix else self.project_repo.srcs.changed_files
|
||||
changed_task_files = self.project_repo.docs.task.changed_files
|
||||
changed_files = Documents()
|
||||
# Recode caused by upstream changes.
|
||||
for filename in changed_task_files:
|
||||
design_doc = await design_file_repo.get(filename)
|
||||
task_doc = await task_file_repo.get(filename)
|
||||
design_doc = await self.project_repo.docs.system_design.get(filename)
|
||||
task_doc = await self.project_repo.docs.task.get(filename)
|
||||
task_list = self._parse_tasks(task_doc)
|
||||
for task_filename in task_list:
|
||||
old_code_doc = await src_file_repo.get(task_filename)
|
||||
old_code_doc = await self.project_repo.srcs.get(task_filename)
|
||||
if not old_code_doc:
|
||||
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=task_filename, content="")
|
||||
old_code_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
|
||||
)
|
||||
context = CodingContext(
|
||||
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
|
||||
)
|
||||
coding_doc = Document(
|
||||
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.model_dump_json()
|
||||
root_path=str(self.project_repo.src_relative_path),
|
||||
filename=task_filename,
|
||||
content=context.model_dump_json(),
|
||||
)
|
||||
if task_filename in changed_files.docs:
|
||||
logger.warning(
|
||||
|
|
@ -281,41 +317,44 @@ class Engineer(Role):
|
|||
f"{changed_files.docs[task_filename].model_dump_json()}"
|
||||
)
|
||||
changed_files.docs[task_filename] = coding_doc
|
||||
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
|
||||
self.code_todos = [
|
||||
WriteCode(i_context=i, context=self.context, llm=self.llm) for i in changed_files.docs.values()
|
||||
]
|
||||
# Code directly modified by the user.
|
||||
dependency = await CONFIG.git_repo.get_dependency()
|
||||
dependency = await self.git_repo.get_dependency()
|
||||
for filename in changed_src_files:
|
||||
if filename in changed_files.docs:
|
||||
continue
|
||||
coding_doc = await self._new_coding_doc(
|
||||
filename=filename,
|
||||
src_file_repo=src_file_repo,
|
||||
task_file_repo=task_file_repo,
|
||||
design_file_repo=design_file_repo,
|
||||
dependency=dependency,
|
||||
)
|
||||
coding_doc = await self._new_coding_doc(filename=filename, dependency=dependency)
|
||||
changed_files.docs[filename] = coding_doc
|
||||
self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm))
|
||||
self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm))
|
||||
|
||||
if self.code_todos:
|
||||
self.rc.todo = self.code_todos[0]
|
||||
self.set_todo(self.code_todos[0])
|
||||
|
||||
async def _new_summarize_actions(self):
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
src_files = src_file_repo.all_files
|
||||
src_files = self.project_repo.srcs.all_files
|
||||
# Generate a SummarizeCode action for each pair of (system_design_doc, task_doc).
|
||||
summarizations = defaultdict(list)
|
||||
for filename in src_files:
|
||||
dependencies = await src_file_repo.get_dependency(filename=filename)
|
||||
ctx = CodeSummarizeContext.loads(filenames=dependencies)
|
||||
dependencies = await self.project_repo.srcs.get_dependency(filename=filename)
|
||||
ctx = CodeSummarizeContext.loads(filenames=list(dependencies))
|
||||
summarizations[ctx].append(filename)
|
||||
for ctx, filenames in summarizations.items():
|
||||
ctx.codes_filenames = filenames
|
||||
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm))
|
||||
self.summarize_todos.append(SummarizeCode(i_context=ctx, context=self.context, llm=self.llm))
|
||||
if self.summarize_todos:
|
||||
self.rc.todo = self.summarize_todos[0]
|
||||
self.set_todo(self.summarize_todos[0])
|
||||
|
||||
async def _new_code_plan_and_change_action(self):
|
||||
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""
|
||||
files = self.project_repo.all_files
|
||||
requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME)
|
||||
requirement = requirement_doc.content if requirement_doc else ""
|
||||
code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, requirement=requirement)
|
||||
self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm)
|
||||
|
||||
@property
|
||||
def todo(self) -> str:
|
||||
def action_description(self) -> str:
|
||||
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
|
||||
return self.next_todo_action
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class InvoiceOCRAssistant(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([InvoiceOCR])
|
||||
self.set_actions([InvoiceOCR])
|
||||
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
|
|
@ -82,12 +82,12 @@ class InvoiceOCRAssistant(Role):
|
|||
resp = await todo.run(file_path)
|
||||
if len(resp) == 1:
|
||||
# Single file support for questioning based on OCR recognition results
|
||||
self._init_actions([GenerateTable, ReplyQuestion])
|
||||
self.set_actions([GenerateTable, ReplyQuestion])
|
||||
self.orc_data = resp[0]
|
||||
else:
|
||||
self._init_actions([GenerateTable])
|
||||
self.set_actions([GenerateTable])
|
||||
|
||||
self.rc.todo = None
|
||||
self.set_todo(None)
|
||||
content = INVOICE_OCR_SUCCESS
|
||||
resp = OCRResults(ocr_result=json.dumps(resp))
|
||||
msg = Message(content=content, instruct_content=resp)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@
|
|||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.utils.common import any_to_name
|
||||
|
||||
|
|
@ -34,24 +33,19 @@ class ProductManager(Role):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._init_actions([PrepareDocuments, WritePRD])
|
||||
self.set_actions([PrepareDocuments, WritePRD])
|
||||
self._watch([UserRequirement, PrepareDocuments])
|
||||
self.todo_action = any_to_name(PrepareDocuments)
|
||||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if CONFIG.git_repo and not CONFIG.git_reinit:
|
||||
if self.git_repo and not self.config.git_reinit:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
CONFIG.git_reinit = False
|
||||
self.config.git_reinit = False
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
return bool(self.rc.todo)
|
||||
|
||||
async def _observe(self, ignore_memory=False) -> int:
|
||||
return await super()._observe(ignore_memory=True)
|
||||
|
||||
@property
|
||||
def todo(self) -> str:
|
||||
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
|
||||
return self.todo_action
|
||||
|
|
|
|||
|
|
@ -33,5 +33,5 @@ class ProjectManager(Role):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._init_actions([WriteTasks])
|
||||
self.set_actions([WriteTasks])
|
||||
self._watch([WriteDesign])
|
||||
|
|
|
|||
|
|
@ -15,20 +15,13 @@
|
|||
of SummarizeCode.
|
||||
"""
|
||||
|
||||
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_TO_NONE,
|
||||
TEST_CODES_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_NONE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Document, Message, RunCodeContext, TestingContext
|
||||
from metagpt.utils.common import any_to_str_set, parse_recipient
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
class QaEngineer(Role):
|
||||
|
|
@ -36,7 +29,8 @@ class QaEngineer(Role):
|
|||
profile: str = "QaEngineer"
|
||||
goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs"
|
||||
constraints: str = (
|
||||
"The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain"
|
||||
"The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain."
|
||||
"Use same language as user requirement"
|
||||
)
|
||||
test_round_allowed: int = 5
|
||||
test_round: int = 0
|
||||
|
|
@ -46,34 +40,35 @@ class QaEngineer(Role):
|
|||
|
||||
# FIXME: a bit hack here, only init one action to circumvent _think() logic,
|
||||
# will overwrite _think() in future updates
|
||||
self._init_actions([WriteTest])
|
||||
self.set_actions([WriteTest])
|
||||
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
|
||||
self.test_round = 0
|
||||
|
||||
async def _write_test(self, message: Message) -> None:
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
|
||||
changed_files = set(src_file_repo.changed_files.keys())
|
||||
# Unit tests only.
|
||||
if CONFIG.reqa_file and CONFIG.reqa_file not in changed_files:
|
||||
changed_files.add(CONFIG.reqa_file)
|
||||
tests_file_repo = CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
|
||||
if self.config.reqa_file and self.config.reqa_file not in changed_files:
|
||||
changed_files.add(self.config.reqa_file)
|
||||
for filename in changed_files:
|
||||
# write tests
|
||||
if not filename or "test" in filename:
|
||||
continue
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
test_doc = await tests_file_repo.get("test_" + code_doc.filename)
|
||||
if not code_doc:
|
||||
continue
|
||||
if not code_doc.filename.endswith(".py"):
|
||||
continue
|
||||
test_doc = await self.project_repo.tests.get("test_" + code_doc.filename)
|
||||
if not test_doc:
|
||||
test_doc = Document(
|
||||
root_path=str(tests_file_repo.root_path), filename="test_" + code_doc.filename, content=""
|
||||
root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content=""
|
||||
)
|
||||
logger.info(f"Writing {test_doc.filename}..")
|
||||
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
|
||||
context = await WriteTest(context=context, llm=self.llm).run()
|
||||
await tests_file_repo.save(
|
||||
filename=context.test_doc.filename,
|
||||
content=context.test_doc.content,
|
||||
dependencies={context.code_doc.root_relative_path},
|
||||
context = await WriteTest(i_context=context, context=self.context, llm=self.llm).run()
|
||||
await self.project_repo.tests.save_doc(
|
||||
doc=context.test_doc, dependencies={context.code_doc.root_relative_path}
|
||||
)
|
||||
|
||||
# prepare context for run tests in next round
|
||||
|
|
@ -81,8 +76,8 @@ class QaEngineer(Role):
|
|||
command=["python", context.test_doc.root_relative_path],
|
||||
code_filename=context.code_doc.filename,
|
||||
test_filename=context.test_doc.filename,
|
||||
working_directory=str(CONFIG.git_repo.workdir),
|
||||
additional_python_paths=[str(CONFIG.src_workspace)],
|
||||
working_directory=str(self.project_repo.workdir),
|
||||
additional_python_paths=[str(self.context.src_workspace)],
|
||||
)
|
||||
self.publish_message(
|
||||
Message(
|
||||
|
|
@ -94,21 +89,23 @@ class QaEngineer(Role):
|
|||
)
|
||||
)
|
||||
|
||||
logger.info(f"Done {str(tests_file_repo.workdir)} generating.")
|
||||
logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.")
|
||||
|
||||
async def _run_code(self, msg):
|
||||
run_code_context = RunCodeContext.loads(msg.content)
|
||||
src_doc = await CONFIG.git_repo.new_file_repository(CONFIG.src_workspace).get(run_code_context.code_filename)
|
||||
src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
|
||||
run_code_context.code_filename
|
||||
)
|
||||
if not src_doc:
|
||||
return
|
||||
test_doc = await CONFIG.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(run_code_context.test_filename)
|
||||
test_doc = await self.project_repo.tests.get(run_code_context.test_filename)
|
||||
if not test_doc:
|
||||
return
|
||||
run_code_context.code = src_doc.content
|
||||
run_code_context.test_code = test_doc.content
|
||||
result = await RunCode(context=run_code_context, llm=self.llm).run()
|
||||
result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run()
|
||||
run_code_context.output_filename = run_code_context.test_filename + ".json"
|
||||
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
|
||||
await self.project_repo.test_outputs.save(
|
||||
filename=run_code_context.output_filename,
|
||||
content=result.model_dump_json(),
|
||||
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
|
||||
|
|
@ -130,10 +127,8 @@ class QaEngineer(Role):
|
|||
|
||||
async def _debug_error(self, msg):
|
||||
run_code_context = RunCodeContext.loads(msg.content)
|
||||
code = await DebugError(context=run_code_context, llm=self.llm).run()
|
||||
await FileRepository.save_file(
|
||||
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
|
||||
)
|
||||
code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run()
|
||||
await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code)
|
||||
run_code_context.output = None
|
||||
self.publish_message(
|
||||
Message(
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class Researcher(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions(
|
||||
self.set_actions(
|
||||
[CollectLinks(name=self.name), WebBrowseAndSummarize(name=self.name), ConductResearch(name=self.name)]
|
||||
)
|
||||
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
|
||||
|
|
@ -49,7 +49,7 @@ class Researcher(Role):
|
|||
if self.rc.state + 1 < len(self.states):
|
||||
self._set_state(self.rc.state + 1)
|
||||
else:
|
||||
self.rc.todo = None
|
||||
self.set_todo(None)
|
||||
return False
|
||||
|
||||
async def _act(self) -> Message:
|
||||
|
|
|
|||
|
|
@ -23,29 +23,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional, Set, Type
|
||||
from typing import Any, Iterable, Optional, Set, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.llm import LLM, HumanProvider
|
||||
from metagpt.context_mixin import ContextMixin
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.plan.planner import Planner
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin, Task, TaskResult
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
import_class,
|
||||
read_json_file,
|
||||
role_raise_decorator,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.provider import HumanProvider
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
||||
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
|
||||
|
|
@ -113,7 +105,7 @@ class RoleContext(BaseModel):
|
|||
max_react_loop: int = 1
|
||||
|
||||
def check(self, role_id: str):
|
||||
# if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
|
||||
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
|
||||
# self.long_term_memory.recover_memory(role_id, self)
|
||||
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
|
||||
pass
|
||||
|
|
@ -128,10 +120,10 @@ class RoleContext(BaseModel):
|
|||
return self.memory.get()
|
||||
|
||||
|
||||
class Role(SerializationMixin, is_polymorphic_base=True):
|
||||
class Role(SerializationMixin, ContextMixin, BaseModel):
|
||||
"""Role/Agent"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
|
||||
|
||||
name: str = ""
|
||||
profile: str = ""
|
||||
|
|
@ -140,12 +132,18 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
desc: str = ""
|
||||
is_human: bool = False
|
||||
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
|
||||
role_id: str = ""
|
||||
states: list[str] = []
|
||||
|
||||
# scenarios to set action system_prompt:
|
||||
# 1. `__init__` while using Role(actions=[...])
|
||||
# 2. add action to role while using `role.set_action(action)`
|
||||
# 3. set_todo while using `role.set_todo(action)`
|
||||
# 4. when role.system_prompt is being updated (e.g. by `role.system_prompt = "..."`)
|
||||
# Additional, if llm is not set, we will use role's llm
|
||||
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
|
||||
rc: RoleContext = Field(default_factory=RoleContext)
|
||||
subscription: set[str] = set()
|
||||
addresses: set[str] = set()
|
||||
planner: Planner = None
|
||||
|
||||
# builtin variables
|
||||
|
|
@ -154,27 +152,85 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
|
||||
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_subscription(self):
|
||||
if not self.subscription:
|
||||
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
|
||||
return self
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
# --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined #
|
||||
from metagpt.environment import Environment
|
||||
|
||||
Environment
|
||||
# ------
|
||||
Role.model_rebuild()
|
||||
self.pydantic_rebuild_model()
|
||||
super().__init__(**data)
|
||||
|
||||
if self.is_human:
|
||||
self.llm = HumanProvider()
|
||||
self.llm = HumanProvider(None)
|
||||
|
||||
self._check_actions()
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
self._watch(data.get("watch") or [UserRequirement])
|
||||
|
||||
if self.latest_observed_msg:
|
||||
self.recovered = True
|
||||
|
||||
@staticmethod
|
||||
def pydantic_rebuild_model():
|
||||
"""Rebuild model to avoid `RecursionError: maximum recursion depth exceeded in comparison`"""
|
||||
from metagpt.environment import Environment
|
||||
|
||||
Environment
|
||||
Role.model_rebuild()
|
||||
|
||||
@property
|
||||
def todo(self) -> Action:
|
||||
"""Get action to do"""
|
||||
return self.rc.todo
|
||||
|
||||
def set_todo(self, value: Optional[Action]):
|
||||
"""Set action to do and update context"""
|
||||
if value:
|
||||
value.context = self.context
|
||||
self.rc.todo = value
|
||||
|
||||
@property
|
||||
def git_repo(self):
|
||||
"""Git repo"""
|
||||
return self.context.git_repo
|
||||
|
||||
@git_repo.setter
|
||||
def git_repo(self, value):
|
||||
self.context.git_repo = value
|
||||
|
||||
@property
|
||||
def src_workspace(self):
|
||||
"""Source workspace under git repo"""
|
||||
return self.context.src_workspace
|
||||
|
||||
@src_workspace.setter
|
||||
def src_workspace(self, value):
|
||||
self.context.src_workspace = value
|
||||
|
||||
@property
|
||||
def project_repo(self) -> ProjectRepo:
|
||||
project_repo = ProjectRepo(self.context.git_repo)
|
||||
return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
"""Prompt schema: json/markdown"""
|
||||
return self.config.prompt_schema
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return self.config.project_name
|
||||
|
||||
@project_name.setter
|
||||
def project_name(self, value):
|
||||
self.config.project_name = value
|
||||
|
||||
@property
|
||||
def project_path(self):
|
||||
return self.config.project_path
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_addresses(self):
|
||||
if not self.addresses:
|
||||
self.addresses = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
|
||||
return self
|
||||
|
||||
def _reset(self):
|
||||
self.states = []
|
||||
self.actions = []
|
||||
|
|
@ -183,59 +239,32 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
def _setting(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = (
|
||||
SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}")
|
||||
if stg_path is None
|
||||
else stg_path
|
||||
)
|
||||
def _check_actions(self):
|
||||
"""Check actions and set llm and prefix for each action."""
|
||||
self.set_actions(self.actions)
|
||||
return self
|
||||
|
||||
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
|
||||
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
write_json_file(role_info_path, role_info)
|
||||
|
||||
self.rc.memory.serialize(stg_path) # serialize role's memory alone
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Role":
|
||||
"""stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
role_info = read_json_file(role_info_path)
|
||||
|
||||
role_class_str = role_info.pop("role_class")
|
||||
module_name = role_info.pop("module_name")
|
||||
role_class = import_class(class_name=role_class_str, module_name=module_name)
|
||||
|
||||
role = role_class(**role_info) # initiate particular Role
|
||||
role.set_recovered(True) # set True to make a tag
|
||||
|
||||
role_memory = Memory.deserialize(stg_path)
|
||||
role.set_memory(role_memory)
|
||||
|
||||
return role
|
||||
|
||||
def _init_action_system_message(self, action: Action):
|
||||
def _init_action(self, action: Action):
|
||||
if not action.private_config:
|
||||
action.set_llm(self.llm, override=True)
|
||||
else:
|
||||
action.set_llm(self.llm, override=False)
|
||||
action.set_prefix(self._get_prefix())
|
||||
|
||||
def refresh_system_message(self):
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
def set_action(self, action: Action):
|
||||
"""Add action to the role."""
|
||||
self.set_actions([action])
|
||||
|
||||
def set_recovered(self, recovered: bool = False):
|
||||
self.recovered = recovered
|
||||
def set_actions(self, actions: list[Union[Action, Type[Action]]]):
|
||||
"""Add actions to the role.
|
||||
|
||||
def set_memory(self, memory: Memory):
|
||||
self.rc.memory = memory
|
||||
|
||||
def init_actions(self, actions):
|
||||
self._init_actions(actions)
|
||||
|
||||
def _init_actions(self, actions):
|
||||
Args:
|
||||
actions: list of Action classes or instances
|
||||
"""
|
||||
self._reset()
|
||||
for idx, action in enumerate(actions):
|
||||
for action in actions:
|
||||
if not isinstance(action, Action):
|
||||
## 默认初始化
|
||||
i = action(name="", llm=self.llm)
|
||||
i = action(context=self.context)
|
||||
else:
|
||||
if self.is_human and not isinstance(action.llm, HumanProvider):
|
||||
logger.warning(
|
||||
|
|
@ -244,9 +273,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
f"try passing in Action classes instead of initialized instances"
|
||||
)
|
||||
i = action
|
||||
self._init_action_system_message(i)
|
||||
self._init_action(i)
|
||||
self.actions.append(i)
|
||||
self.states.append(f"{idx}. {action}")
|
||||
self.states.append(f"{len(self.actions)}. {action}")
|
||||
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
|
||||
"""Set strategy of the Role reacting to observed Message. Variation lies in how
|
||||
|
|
@ -284,33 +313,29 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
def is_watch(self, caused_by: str):
|
||||
return caused_by in self.rc.watch
|
||||
|
||||
def subscribe(self, tags: Set[str]):
|
||||
def set_addresses(self, addresses: Set[str]):
|
||||
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
|
||||
buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name
|
||||
or profile.
|
||||
"""
|
||||
self.subscription = tags
|
||||
self.addresses = addresses
|
||||
if self.rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
|
||||
self.rc.env.set_subscription(self, self.subscription)
|
||||
self.rc.env.set_addresses(self, self.addresses)
|
||||
|
||||
def _set_state(self, state: int):
|
||||
"""Update the current state."""
|
||||
self.rc.state = state
|
||||
logger.debug(f"actions={self.actions}, state={state}")
|
||||
self.rc.todo = self.actions[self.rc.state] if state >= 0 else None
|
||||
self.set_todo(self.actions[self.rc.state] if state >= 0 else None)
|
||||
|
||||
def set_env(self, env: "Environment"):
|
||||
"""Set the environment in which the role works. The role can talk to the environment and can also receive
|
||||
messages by observing."""
|
||||
self.rc.env = env
|
||||
if env:
|
||||
env.set_subscription(self, self.subscription)
|
||||
self.refresh_system_message() # add env message to system message
|
||||
|
||||
@property
|
||||
def action_count(self):
|
||||
"""Return number of action"""
|
||||
return len(self.actions)
|
||||
env.set_addresses(self, self.addresses)
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
self.set_actions(self.actions) # reset actions to update llm and prefix
|
||||
|
||||
def _get_prefix(self):
|
||||
"""Get the role prefix"""
|
||||
|
|
@ -323,7 +348,8 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})
|
||||
|
||||
if self.rc.env and self.rc.env.desc:
|
||||
other_role_names = ", ".join(self.rc.env.role_names())
|
||||
all_roles = self.rc.env.role_names()
|
||||
other_role_names = ", ".join([r for r in all_roles if r != self.name])
|
||||
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
|
||||
prefix += env_desc
|
||||
return prefix
|
||||
|
|
@ -338,7 +364,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
|
||||
if self.recovered and self.rc.state >= 0:
|
||||
self._set_state(self.rc.state) # action to run from recovered state
|
||||
self.set_recovered(False) # avoid max_react_loop out of work
|
||||
self.recovered = False # avoid max_react_loop out of work
|
||||
return True
|
||||
|
||||
prompt = self._get_prefix()
|
||||
|
|
@ -436,7 +462,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
break
|
||||
# act
|
||||
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
|
||||
rsp = await self._act() # 这个rsp是否需要publish_message?
|
||||
rsp = await self._act()
|
||||
actions_taken += 1
|
||||
return rsp # return output from the last action
|
||||
|
||||
|
|
@ -495,6 +521,8 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
rsp = await self._act_by_order()
|
||||
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
rsp = await self._plan_and_act()
|
||||
else:
|
||||
raise ValueError(f"Unsupported react mode: {self.rc.react_mode}")
|
||||
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
|
||||
return rsp
|
||||
|
||||
|
|
@ -516,7 +544,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
if not msg.cause_by:
|
||||
msg.cause_by = UserRequirement
|
||||
self.put_message(msg)
|
||||
|
||||
if not await self._observe():
|
||||
# If there is no new information, suspend and wait
|
||||
logger.debug(f"{self._setting}: no news. waiting.")
|
||||
|
|
@ -525,7 +552,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
rsp = await self.react()
|
||||
|
||||
# Reset the next action to be taken.
|
||||
self.rc.todo = None
|
||||
self.set_todo(None)
|
||||
# Send the response message to the Environment object to have it relay the message to the subscribers.
|
||||
self.publish_message(rsp)
|
||||
return rsp
|
||||
|
|
@ -536,18 +563,34 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
return not self.rc.news and not self.rc.todo and self.rc.msg_buffer.empty()
|
||||
|
||||
async def think(self) -> Action:
|
||||
"""The exported `think` function"""
|
||||
"""
|
||||
Export SDK API, used by AgentStore RPC.
|
||||
The exported `think` function
|
||||
"""
|
||||
await self._observe() # For compatibility with the old version of the Agent.
|
||||
await self._think()
|
||||
return self.rc.todo
|
||||
|
||||
async def act(self) -> ActionOutput:
|
||||
"""The exported `act` function"""
|
||||
"""
|
||||
Export SDK API, used by AgentStore RPC.
|
||||
The exported `act` function
|
||||
"""
|
||||
msg = await self._act()
|
||||
return ActionOutput(content=msg.content, instruct_content=msg.instruct_content)
|
||||
|
||||
@property
|
||||
def todo(self) -> str:
|
||||
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
|
||||
def action_description(self) -> str:
|
||||
"""
|
||||
Export SDK API, used by AgentStore RPC and Agent.
|
||||
AgentStore uses this attribute to display to the user what actions the current role should take.
|
||||
`Role` provides the default property, and this property should be overridden by children classes if necessary,
|
||||
as demonstrated by the `Engineer` class.
|
||||
"""
|
||||
if self.rc.todo:
|
||||
if self.rc.todo.desc:
|
||||
return self.rc.todo.desc
|
||||
return any_to_name(self.rc.todo)
|
||||
if self.actions:
|
||||
return any_to_name(self.actions[0])
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -38,5 +38,5 @@ class Sales(Role):
|
|||
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch)
|
||||
else:
|
||||
action = SearchAndSummarize()
|
||||
self._init_actions([action])
|
||||
self.set_actions([action])
|
||||
self._watch([UserRequirement])
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue